From 3fd2c96b01291605d955e071980df06d1e59fd06 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Mon, 23 May 2022 11:34:36 +0200
Subject: [PATCH] Update linker padding and atom encoding helpers

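Move get_pos_encoding_for_s_idx and get_neg_encoding_for_s_idx from
Linker/utils_linker.py into the Linker class and pad their output to
max_atoms_in_one_type // 2, so the per-type encodings can be stacked
directly. Add get_GOAL, which builds atoms and polarities together and
appends atoms of the missing polarity so that every atom type has as
many positive as negative occurrences. find_pos_neg_idexes now returns
plain lists, and eval_batch/eval_epoch use self.cross_entropy_loss.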
---
 Linker/Linker.py       | 90 +++++++++++++++++++++++++++---------------
 Linker/utils_linker.py | 83 +++++++++++++++++++++++++-------------
 2 files changed, 114 insertions(+), 59 deletions(-)

diff --git a/Linker/Linker.py b/Linker/Linker.py
index e7f2bac..f3a4538 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,10 +1,11 @@
 import os
+import re
 import sys
 import datetime
 
 import time
 
 import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
@@ -19,8 +19,7 @@ from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.atom_map import atom_map
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
-    get_neg_encoding_for_s_idx
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
 from Supertagger import *
 from utils import pad_sequence
 
@@ -108,11 +107,9 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch = get_atoms_batch(df_axiom_links["Z"])
+        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
-
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
@@ -160,17 +157,15 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = pad_sequence(
-                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
-
-            neg_encoding = pad_sequence(
-                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
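+            # each helper returns a fixed-size (max_atoms_in_one_type // 2, dim_embedding_atoms)
+            # tensor, so the per-sentence encodings can be stacked without pad_sequence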
+            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])
+
+            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -210,8 +203,9 @@ class Linker(Module):
             print("")
             print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
             if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
 
             if checkpoint:
@@ -236,7 +230,6 @@ class Linker(Module):
 
         Args:
             training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
-            validation_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
         Returns:
              accuracy on validation set
              loss on train set
@@ -300,12 +293,9 @@ class Linker(Module):
         self.eval()
         with torch.no_grad():
             # get atoms
-            atoms_batch = get_atoms_batch(categories)
+            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
             atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-            # get polarities
-            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
-
             # atoms embedding
             atoms_embedding = self.atoms_embedding(atoms_tokenized)
 
@@ -316,14 +306,12 @@ class Linker(Module):
             link_weights = []
             for atom_type in self.sub_atoms_type_list:
                 pos_encoding = pad_sequence(
-                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
                 neg_encoding = pad_sequence(
-                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
@@ -337,7 +325,7 @@ class Linker(Module):
             axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
             return axiom_links
 
-    def eval_batch(self, batch, cross_entropy_loss):
+    def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
         batch_polarity = batch[1].to(self.device)
         batch_true_links = batch[2].to(self.device)
@@ -349,12 +337,20 @@ class Linker(Module):
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
 
+        print('\n')
+        print("Sentence tokens: ", batch_sentences_tokens[1])
+        print("Atoms in the sentence: ", batch_atoms[1][:50])
+        print("Polarities of the sentence's atoms: ", batch_polarity[1][:50])
+        print("True links for category n: ", batch_true_links[1][2][:100])
+        print("Predictions: ", axiom_links_pred[1][2][:100])
+        print('\n')
+
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+        loss = self.cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
 
         return loss, accuracy
 
-    def eval_epoch(self, dataloader, cross_entropy_loss):
+    def eval_epoch(self, dataloader):
         r"""Average the evaluation of all the batch.
 
         Args:
@@ -365,7 +361,7 @@ class Linker(Module):
         loss_average = 0
         with torch.no_grad():
             for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                loss, accuracy = self.eval_batch(batch)
                 accuracy_average += accuracy
                 loss_average += float(loss)
 
@@ -405,3 +401,34 @@ class Linker(Module):
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
+
+    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        # encodings of the positive atoms of the given type in sentence s_idx,
+        # padded with zero vectors up to max_atoms_in_one_type // 2
+        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(rf"{atom_type}_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            atoms_polarity_batch[s_idx][i])]
+        if len(pos_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type // 2, self.dim_embedding_atoms, device=self.device)
+        pos_encoding += [torch.zeros(self.dim_embedding_atoms, device=self.device)
+                         for _ in range(self.max_atoms_in_one_type // 2 - len(pos_encoding))]
+        return torch.stack(pos_encoding)
+
+    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        # encodings of the negative atoms of the given type in sentence s_idx,
+        # padded with zero vectors up to max_atoms_in_one_type // 2
+        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(rf"{atom_type}_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            not atoms_polarity_batch[s_idx][i])]
+        if len(neg_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type // 2, self.dim_embedding_atoms, device=self.device)
+        neg_encoding += [torch.zeros(self.dim_embedding_atoms, device=self.device)
+                         for _ in range(self.max_atoms_in_one_type // 2 - len(neg_encoding))]
+        return torch.stack(neg_encoding)
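+    # Illustrative shapes (hypothetical values): with max_atoms_in_one_type = 10 and
+    # dim_embedding_atoms = 8, each helper returns a (5, 8) tensor for every sentence,
+    # so torch.stack over the batch yields (batch_size, 5, 8).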
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 0863f9e..0aa6dc2 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -165,7 +165,6 @@ def category_to_atoms_polarity(category, polarity):
     """
     category_to_polarity = []
     res = [(category == atom_type) for atom_type in atom_map.keys()]
-
     # final word
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -177,7 +176,7 @@ def category_to_atoms_polarity(category, polarity):
     elif category == "let":
         pass
     # the word has an atomic category
-    elif True in res or category.startswith("dia") or category.startswith("box"):
+    elif True in res:
         category_to_polarity.append(not polarity)
     # otherwise it is a compound formula
     else:
@@ -201,13 +200,34 @@ def category_to_atoms_polarity(category, polarity):
             # for the right side
             category_to_polarity += category_to_atoms_polarity(right_side, polarity)
 
+        # p: product category, split and recurse on both sides
+        elif category.startswith("p"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            left_side, right_side = category_cut[0], category_cut[1]
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
+            # for the right side
+            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
+
+        # box: unwrap the modality and recurse with the same polarity
+        elif category.startswith("box"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
+        # dia: unwrap the modality and recurse with the same polarity
+        elif category.startswith("dia"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
     return category_to_polarity
 
 
-def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
+def find_pos_neg_idexes(atoms_batch):
     r"""
     Args:
-        max_atoms_in_sentence : configuration
         atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns:
-        (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
+        list of per-sentence lists of the categories' polarities in prefix order (unpadded)
@@ -218,35 +238,39 @@ def find_pos_neg_idexes(atoms_batch):
         for category in sentence:
             for at in category_to_atoms_polarity(category, True):
                 list_atoms.append(at)
-        list_batch.append(torch.as_tensor(list_atoms))
-    return pad_sequence([list_batch[i] for i in range(len(list_batch))],
-                        max_len=max_atoms_in_sentence, padding_value=0)
+        list_batch.append(list_atoms)
+    return list_batch
 
 
 #########################################################################################
-################################ Prepare encoding ###############################################
+################################ GOAL ###############################################
 #########################################################################################
 
 
-def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        atoms_polarity_batch[s_idx][i])]
-    if len(pos_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(pos_encoding)
+def get_GOAL(max_atoms_in_sentence, categories_batch):
+    polarities = find_pos_neg_idexes(categories_batch)
+    atoms_batch = get_atoms_batch(categories_batch)
+    # balance every atom type: append dummy atoms of the under-represented
+    # polarity until each type has as many positive as negative occurrences
+    for s_idx in range(len(atoms_batch)):
+        for atom_type in atom_map.keys():
+            n_plus = sum(1 for i, atom in enumerate(atoms_batch[s_idx])
+                         if atom == atom_type and polarities[s_idx][i])
+            n_minus = sum(1 for i, atom in enumerate(atoms_batch[s_idx])
+                          if atom == atom_type and not polarities[s_idx][i])
+            missing = abs(n_plus - n_minus)
+            atoms_batch[s_idx] += [atom_type] * missing
+            polarities[s_idx] += [n_minus > n_plus] * missing
+
+    return atoms_batch, pad_sequence([torch.as_tensor(pol, dtype=torch.bool) for pol in polarities],
+                                     max_len=max_atoms_in_sentence, padding_value=0)
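+# Hypothetical example: a sentence with two positive "np" atoms and one negative
+# "np" gets one extra negative "np" appended, so the positive and negative blocks
+# of each atom type have equal size before padding.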
+
+
+#########################################################################################
+################################ Prepare encoding ###############################################
+#########################################################################################
 
 
-def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        not atoms_polarity_batch[s_idx][i])]
-    if len(neg_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(neg_encoding)
-- 
GitLab