From d1c8b81386fcf1e578c6db288af0ee5f2b335990 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 23 Jun 2022 17:30:00 +0200
Subject: [PATCH] change padding handling

---
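The new Linker/DataParallelLinker.py wraps Linker in torch.nn.DataParallel
so each training batch is split across the visible GPUs. A minimal usage
sketch for the wrapper, assuming the axiom-links dataset is loaded with
pandas (the CSV path below is a placeholder, not a file added by this
patch):

    import pandas as pd

    from Linker.DataParallelLinker import DataParallelModel

    # Axiom links annotated with _i, as expected by train_linker
    # (placeholder path).
    df_axiom_links = pd.read_csv("Datasets/axiom_links.csv")

    model = DataParallelModel()
    model.train_linker(df_axiom_links, validation_rate=0.1, epochs=20,
                       batch_size=32, checkpoint=True, tensorboard=True)
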
 Configuration/config.ini     |  7 +++--
 Linker/DataParallelLinker.py | 68 ++++++++++++++++++++++++++++++++++++
 2 files changed, 72 insertions(+), 3 deletions(-)
 create mode 100644 Linker/DataParallelLinker.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 61872f4..b33d6df 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,8 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=290
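+# maximum (padded) length of an input sentence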
+max_len_sentence=83
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
@@ -12,10 +13,10 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
+nhead=16
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
-num_layers=2
+num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
 dim_intermediate_FFN=128
diff --git a/Linker/DataParallelLinker.py b/Linker/DataParallelLinker.py
new file mode 100644
index 0000000..5885845
--- /dev/null
+++ b/Linker/DataParallelLinker.py
@@ -0,0 +1,68 @@
+import os
+from datetime import datetime
+
+from torch.nn import DataParallel, Module
+from Linker import *
+
+
+class DataParallelModel(Module):
+
+    def __init__(self):
+        super().__init__()
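+        # DataParallel replicates the wrapped Linker on each visible GPU
+        # and scatters every input batch along dim 0 during forward.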
+        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
+
+    def forward(self, x):
+        x = self.linker(x)
+        return x
+
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, tensorboard=False):
+        r"""
+        Args:
+            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
+            validation_rate : float
+            epochs : int
+            batch_size : int
+            checkpoint : boolean
+            tensorboard : boolean
+        Returns:
+            Final accuracy and final loss
+        """
+        # The wrapped Linker does the real work; its private helpers are
+        # reached through their name-mangled attributes.
+        linker = self.linker.module
+        training_dataloader, validation_dataloader = linker._Linker__preprocess_data(batch_size, df_axiom_links,
+                                                                                     validation_rate)
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = output_create_dir()
+
+        for epoch_i in range(epochs):
+            print("")
+            print(f'======== Epoch {epoch_i + 1} / {epochs} ========')
+            print('Training...')
+            avg_train_loss, avg_accuracy_train, training_time = linker.train_epoch(training_dataloader)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
+            if validation_rate > 0.0:
+                loss_test, accuracy_test = linker.eval_epoch(validation_dataloader)
+                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
+
+            if checkpoint:
+                linker._Linker__checkpoint_save(
+                    path=os.path.join(checkpoint_dir, 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+
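+            # add_scalars groups entries under 'Accuracy/...' and 'Loss/...',
+            # so train and validation curves share one TensorBoard chart.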
+            if tensorboard:
+                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
+                if validation_rate > 0.0:
+                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)
+
+            print('\n')
\ No newline at end of file
-- 
GitLab