From 1f43915a006c860577b7710e01d4706c3fc0cf1d Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 17 May 2022 17:14:18 +0200
Subject: [PATCH] update linker encoding

---
 Configuration/config.ini |  7 +--
 Linker/Linker.py         | 92 ++++++++++++++++++++++++++++------------
 Linker/__init__.py       |  1 +
 Linker/utils_linker.py   | 72 ++++++++++++++++---------------
 main.py                  |  4 +-
 train.py                 |  9 +---
 6 files changed, 108 insertions(+), 77 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 15547f6..c79def5 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -31,9 +31,4 @@ device=cpu
 batch_size=32
 epoch=20
 seed_val=42
-learning_rate=0.005
-use_checkpoint_SAVE=0
-output_path=Output
-use_checkpoint_LOAD=0
-input_path=Input
-model_to_load=model_check.pt
\ No newline at end of file
+learning_rate=0.005
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 7f7462d..d2a4d89 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -37,7 +37,7 @@ class Linker(Module):
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = ""
+        self.device = "cpu"
 
         self.Supertagger = supertagger
 
@@ -66,6 +66,16 @@ class Linker(Module):
             num_training_steps=100)
 
     def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+        r"""
+        Args:
+            batch_size : int
+            df_axiom_links : pandas DataFrame
+            sentences_tokens
+            sentences_mask
+            validation_rate
+        Returns:
+            the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
+        """
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
         atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
@@ -98,22 +108,21 @@ class Linker(Module):
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
 
     def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
+        r"""
+        Args:
+            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask : mask from BERT tokenizer
+        Returns:
+            link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+        """
 
         # atoms embedding
         atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
 
         # MHA ou LSTM avec sortie de BERT
-        batch_size, _, _ = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        sents_mask = sents_mask.unsqueeze(1).repeat(self.nhead, self.max_atoms_in_sentence, 1).to(torch.float64)
         atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                             self.make_decoder_mask(atoms_batch_tokenized))
 
@@ -147,15 +156,35 @@ class Linker(Module):
 
     def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
-
+        r"""
+        Args:
+            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
+            sentences_tokens : sentences tokenized by BERT
+            sentences_mask : mask of tokens
+            validation_rate : float
+            epochs : int
+            batch_size : int
+            checkpoint : boolean
+            validate : boolean
+        Returns:
+            Final accuracy and final loss
+        """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             sentences_tokens, sentences_mask,
                                                                             validation_rate)
-
        for epoch_i in range(0, epochs):
             epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
 
     def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+        r"""Train one epoch
+
+        Args:
+            training_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
+            validation_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
+        Returns:
+            accuracy on validation set
+            loss on train set
+        """
 
         # Reset the total loss for this epoch.
         epoch_loss = 0
@@ -195,8 +224,8 @@ class Linker(Module):
         print("Average Loss on train dataset : ", avg_train_loss)
 
         if checkpoint:
-            checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+            self.__checkpoint_save(
+                path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
 
         if validate:
             with torch.no_grad():
@@ -204,17 +233,20 @@ class Linker(Module):
                 print("Average Loss on test dataset : ", average_test_loss)
                 print("Average Accuracy on test dataset : ", accuracy)
+                print('\n')
+
+        return accuracy, avg_train_loss
 
     def predict(self, categories, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        categories : (batch_size, len_sentence)
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
-        '''
+        r"""Prediction of axiom links from the categories and the hidden states output by BERT
+
+        Args:
+            categories : (batch_size, len_sentence)
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask : mask from BERT tokenizer
+        Returns:
+            axiom_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat)
+        """
         self.eval()
 
         # get atoms
@@ -268,8 +300,9 @@ class Linker(Module):
             batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
 
-            logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, batch_sentences_tokens,
-                                                   batch_sentences_mask)
+            logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+            logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
+                                           batch_sentences_mask)
             logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
             axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
@@ -315,7 +348,10 @@ class Linker(Module):
         print("#" * 15)
 
     def __checkpoint_save(self, path='/linker.pt'):
-        self.linker.cpu()
+        """
+        @param path: file path of the checkpoint
+        """
+        self.cpu()
         torch.save({
             'args': dict(atom_map=self.atom_map,
                                          max_atoms_in_sentence=self.max_atoms_in_sentence),
@@ -325,4 +361,4 @@ class Linker(Module):
             'neg_transformation': self.neg_transformation.state_dict(),
             'optimizer': self.optimizer,
         }, path)
-        self.linker.to(self.device)
+        #self.to(self.device)
diff --git a/Linker/__init__.py b/Linker/__init__.py
index e69de29..c0df5b8 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -0,0 +1 @@
+from .Linker import Linker
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index da295de..13c63f4 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -32,14 +32,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
-    r'''
-    Parameters :
-    max_atoms_in_one_type : configuration
-    atoms_polarity : (batch_size, max_atoms_in_sentence)
-    batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
-    Returns :
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    '''
+    r"""
+    Args:
+        max_atoms_in_one_type : configuration
+        atoms_polarity : (batch_size, max_atoms_in_sentence)
+        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
+    Returns:
+        batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map.keys())[:-1]:
@@ -62,13 +62,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
 
 
 def category_to_atoms_axiom_links(category, categories_to_atoms):
-    r'''
-    Parameters :
-    category
-    categories_to_atoms : recursive list
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        categories_to_atoms : recursive list
     Returns :
-    List of atoms inside the category in prefix order
-    '''
+        List of atoms inside the category in prefix order
+    """
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -85,7 +85,8 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
 
 def get_atoms_links_batch(category_batch):
     r"""
-    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Args:
+        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns :
         (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     """
@@ -104,13 +105,13 @@
 
 
 def category_to_atoms(category, categories_to_atoms):
-    r'''
-    Parameters :
-    category
-    categories_to_atoms : recursive list
-    Returns :
-    List of atoms inside the category in prefix order
-    '''
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        categories_to_atoms : recursive list
+    Returns:
+        List of atoms inside the category in prefix order
+    """
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -129,8 +130,9 @@
 
 def get_atoms_batch(category_batch):
     r"""
-    category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns :
+    Args:
+        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns:
         (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     """
     batch = []
@@ -147,12 +149,13 @@
 #########################################################################################
 
 
 def category_to_atoms_polarity(category, polarity):
-    r'''
-    Parameters :
-    category : str of kind AtomCat | CategoryCat(dr or dl)
-    Returns :
-    Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
-    '''
+    r"""
+    Args:
+        category : str of kind AtomCat | CategoryCat(dr or dl)
+        polarity : current polarity carried through the recursion
+    Returns:
+        Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
+    """
     category_to_polarity = []
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
@@ -233,10 +236,11 @@ def category_to_atoms_polarity(category, polarity):
 
 def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
     r"""
-    max_atoms_in_sentence : configuration
-    atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns :
-    (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
+    Args:
+        max_atoms_in_sentence : configuration
+        atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
+    Returns:
+        (batch_size, max_atoms_in_sentence) flattened categories' polarities in prefix order
     """
     list_batch = []
     for sentence in atoms_batch:
diff --git a/main.py b/main.py
index 55e8c52..14d3fc0 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,8 @@
 import torch.nn.functional as F
 import torch
 from Configuration import Configuration
-from Linker.Linker import Linker
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from Linker import *
+from Supertagger import *
 
 max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
diff --git a/train.py b/train.py
index f83a951..bc2f785 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,10 @@
 import torch
-
 from Configuration import Configuration
-from Linker.Linker import Linker
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from Linker import *
+from Supertagger import *
 from utils import read_csv_pgbar
 
 torch.cuda.empty_cache()
-
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
 nb_sentences = batch_size * 10
 epochs = int(Configuration.modelTrainingConfig['epoch'])
@@ -15,14 +13,11 @@ file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 sentences_batch = df_axiom_links["Sentences"].tolist()
-
 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")
-
 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
 print("Linker")
 linker = Linker(supertagger)
-
 print("Linker Training")
 linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs,
                     batch_size=batch_size, checkpoint=True, validate=True)
-- 
GitLab
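
Note (not part of the patch): the core change in Linker.forward() replaces the random attention mask with a broadcast of the BERT padding mask. Below is a minimal, self-contained sketch of that shape arithmetic, using hypothetical sizes chosen only for illustration (they are not taken from config.ini); the variable names mirror the patch.

import torch

# Hypothetical sizes, for illustration only.
batch_size, nhead = 2, 8
max_atoms_in_sentence, len_sentence = 6, 10

# BERT-style padding mask: (batch_size, len_sentence), 1 for real tokens, 0 for padding.
sents_mask = torch.ones(batch_size, len_sentence)

# Same transformation as the updated forward(): broadcast the sentence mask
# over every atom position, once per attention head.
attn_mask = sents_mask.unsqueeze(1).repeat(nhead, max_atoms_in_sentence, 1).to(torch.float64)

print(attn_mask.shape)  # torch.Size([16, 6, 10]) -> (batch_size * nhead, max_atoms_in_sentence, len_sentence)

This yields one (max_atoms_in_sentence, len_sentence) mask per head and per sentence, the layout expected by a multi-head cross-attention over the BERT output, instead of the previous torch.randn placeholder.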