From 4004d7c27c78870fd1f0cc754b7ee4b68e915bcd Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 28 Jun 2022 11:14:23 +0200
Subject: [PATCH] training

---
 Configuration/config.ini |  2 +-
 Linker/Linker.py         |  7 +---
 Linker/utils_linker.py   | 87 ++++++++++++++++++----------------------
 train.py                 |  2 +-
 4 files changed, 42 insertions(+), 56 deletions(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index e0d94a3..73f1743 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 2d79231..99524a2 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -147,11 +147,8 @@ class Linker(Module):
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
         print(sentences_tokens)
 
-        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
-
-        num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
+        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links["Z"])
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index c7c7a8d..1c0ab28 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -40,7 +40,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
         batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
@@ -76,12 +75,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
+            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
         return categories_to_atoms
 
 
@@ -94,15 +93,13 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
+                categories_to_atoms.append("[SEP]")
             elif category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
         batch.append(categories_to_atoms)
     return batch
 
@@ -132,12 +129,12 @@ def category_to_atoms(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, "")
+            categories_to_atoms += category_to_atoms(cat, [])
         return categories_to_atoms
 
 
@@ -150,17 +147,16 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let":
-                categories_to_atoms += category_to_atoms(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms(category, [])
+                categories_to_atoms.append("[SEP]")
         batch.append(categories_to_atoms)
     return batch
 
 
-print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
 
 
 # endregion
@@ -319,36 +315,35 @@ print(
 
 # region get atoms and polarities with GOAL
 
-def get_GOAL(max_atoms_in_sentence, categories_batch):
+def get_GOAL(max_len_sentence, max_atoms_in_sentence, categories_batch):
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
     for s_idx in range(len(atoms_batch)):
         for atom_type in list(atom_map_redux.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
+            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
+            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))]
             while len(list_minus) != len(list_plus):
                 if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(True)
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, True)
+                    num_atoms_batch[s_idx][0] += 1
                 else:
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(False)
-                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
+                    atoms_batch[s_idx].insert(0, atom_type)
+                    polarities[s_idx].insert(0, False)
+                    num_atoms_batch[s_idx][0] += 1
+                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch[s_idx][i] == atom_type]
 
     return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0)
+                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
 
 
-print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(5, 12, [["dr(0,s,np)", "s"]]))
 
 
 # endregion
@@ -356,13 +351,10 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
 # region get idx for pos and neg
 
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
@@ -370,24 +362,21 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
 
 
 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and not
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",atoms_batch[s_idx][i])) and not
                                               atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
+print(" test for cut into pos neg on ['dr(0,s,np)', 's']", get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
                                                                                      torch.as_tensor(
-                                                                                         [[False, True, False, False,
-                                                                                           False, False, True, True,
-                                                                                           False, True,
+                                                                                         [[True, True, False, False,
+                                                                                           True, False, False, False,
+                                                                                           False, False,
                                                                                            False, False]]), 10, 50))
 
 # endregion
diff --git a/train.py b/train.py
index 4721ed9..fdf3936 100644
--- a/train.py
+++ b/train.py
@@ -5,7 +5,7 @@ from utils import read_csv_pgbar
 
 torch.cuda.empty_cache()
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 4
+nb_sentences = batch_size * 800
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
-- 
GitLab