From d1c8b81386fcf1e578c6db288af0ee5f2b335990 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 23 Jun 2022 17:30:00 +0200
Subject: [PATCH] change padding handling

---
 Configuration/config.ini     |  6 ++--
 Linker/DataParallelLinker.py | 68 ++++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 Linker/DataParallelLinker.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 61872f4..b33d6df 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=290
+max_len_sentence=83
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
@@ -12,10 +12,10 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
+nhead=16
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
-num_layers=2
+num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
 dim_intermediate_FFN=128
diff --git a/Linker/DataParallelLinker.py b/Linker/DataParallelLinker.py
new file mode 100644
index 0000000..5885845
--- /dev/null
+++ b/Linker/DataParallelLinker.py
@@ -0,0 +1,68 @@
+import os
+from datetime import datetime
+
+from torch.nn import DataParallel, Module
+
+from Linker import *
+
+
+class DataParallelModel(Module):
+    """Wrap a pretrained Linker in DataParallel so batches are split across the available GPUs."""
+
+    def __init__(self):
+        super().__init__()
+        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
+
+    def forward(self, x):
+        return self.linker(x)
+
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, tensorboard=False):
+        r"""
+        Args:
+            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
+            validation_rate : float, fraction of the data held out for validation
+            epochs : int
+            batch_size : int
+            checkpoint : bool, save the model after every epoch
+            tensorboard : bool, log losses and accuracies for TensorBoard
+        Returns:
+            Final accuracy and final loss
+        """
+        # split the axiom links into training and validation dataloaders
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
+                                                                            validation_rate)
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = output_create_dir()
+
+        for epoch_i in range(epochs):
+            print("")
+            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
+            print('Training...')
+            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
+            if validation_rate > 0.0:
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
+                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
+
+            if checkpoint:
+                # save a timestamped snapshot of the model after each epoch
+                self.__checkpoint_save(
+                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+
+            if tensorboard:
+                writer.add_scalars('Accuracy', {
+                    'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars('Loss', {
+                    'Train': avg_train_loss}, epoch_i)
+                if validation_rate > 0.0:
+                    writer.add_scalars('Accuracy', {
+                        'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars('Loss', {
+                        'Validation': loss_test}, epoch_i)
+
+            print('\n')
--
GitLab
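
For context, a minimal driver sketch showing how the new wrapper might be used once the patch is applied. The dataset path and the use of pandas.read_csv are illustrative assumptions, not part of the patch; only DataParallelModel and the train_linker signature come from the diff above.

import pandas as pd

from Linker.DataParallelLinker import DataParallelModel

# axiom links annotated with _i, as expected by train_linker
# (the CSV path below is hypothetical, not defined by this patch)
df_axiom_links = pd.read_csv("Datasets/axiom_links.csv")

# wraps Linker("models/flaubert_super_98_V2_50e.pt") in DataParallel
model = DataParallelModel()
model.train_linker(df_axiom_links, validation_rate=0.1, epochs=20,
                   batch_size=32, checkpoint=True, tensorboard=True)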