From 3fd2c96b01291605d955e071980df06d1e59fd06 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Mon, 23 May 2022 11:34:36 +0200
Subject: [PATCH] update linker padding

---
 Linker/Linker.py       | 90 +++++++++++++++++++++++++++---------------
 Linker/utils_linker.py | 83 +++++++++++++++++++++++++-------------
 2 files changed, 114 insertions(+), 59 deletions(-)

diff --git a/Linker/Linker.py b/Linker/Linker.py
index e7f2bac..f3a4538 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,10 +1,10 @@
 import os
+import re
 import sys
 import datetime
 import time
 
-import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
 
@@ -19,8 +19,7 @@
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.atom_map import atom_map
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
-    get_neg_encoding_for_s_idx
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
 from Supertagger import *
 from utils import pad_sequence
@@ -108,11 +107,9 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch = get_atoms_batch(df_axiom_links["Z"])
+        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
-
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
@@ -160,17 +157,13 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = pad_sequence(
-                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
-
-            neg_encoding = pad_sequence(
-                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
+            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])
+
+            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -210,8 +203,9 @@ class Linker(Module):
             print("")
             print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
             if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
 
             if checkpoint:
@@ -236,7 +230,6 @@ class Linker(Module):
 
         Args:
             training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
-            validation_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
         Returns:
            accuracy on validation set
            loss on train set
@@ -300,12 +293,9 @@ class Linker(Module):
         self.eval()
         with torch.no_grad():
             # get atoms
-            atoms_batch = get_atoms_batch(categories)
+            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
             atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-            # get polarities
-            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
-
             # atoms embedding
             atoms_embedding = self.atoms_embedding(atoms_tokenized)
 
@@ -316,14 +306,12 @@ class Linker(Module):
             link_weights = []
             for atom_type in self.sub_atoms_type_list:
                 pos_encoding = pad_sequence(
-                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
                 neg_encoding = pad_sequence(
-                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
 
@@ -337,7 +325,7 @@ class Linker(Module):
             axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
             return axiom_links
 
-    def eval_batch(self, batch, cross_entropy_loss):
+    def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
         batch_polarity = batch[1].to(self.device)
         batch_true_links = batch[2].to(self.device)
@@ -349,12 +337,20 @@ class Linker(Module):
                                                batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
 
+        print('\n')
+        print("Tokens de la phrase : ", batch_sentences_tokens[1])
+        print("Atoms dans la phrase : ", (batch_atoms[1][:50]))
+        print("Polarités des atoms de la phrase : ", batch_polarity[1][:50])
+        print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
+        print("Les prédictions : ", axiom_links_pred[1][2][:100])
+        print('\n')
+
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+        loss = self.cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
 
         return loss, accuracy
 
-    def eval_epoch(self, dataloader, cross_entropy_loss):
+    def eval_epoch(self, dataloader):
         r"""Average the evaluation of all the batch.
 
         Args:
@@ -365,7 +361,7 @@ class Linker(Module):
         loss_average = 0
         with torch.no_grad():
             for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                loss, accuracy = self.eval_batch(batch)
                 accuracy_average += accuracy
                 loss_average += float(loss)
 
@@ -405,3 +401,35 @@ class Linker(Module):
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
+
+    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(r"" + atom_type + "_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            atoms_polarity_batch[s_idx][i])]
+        if len(pos_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        else:
+            len_pos_encoding = len(pos_encoding)
+            pos_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                             range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
+            return torch.stack(pos_encoding)
+
+    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(r"" + atom_type + "_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            not atoms_polarity_batch[s_idx][i])]
+        if len(neg_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        else:
+            len_neg_encoding = len(neg_encoding)
+            neg_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                             range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
+            return torch.stack(neg_encoding)
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 0863f9e..0aa6dc2 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -165,7 +165,6 @@ def category_to_atoms_polarity(category, polarity):
     """
     category_to_polarity = []
     res = [(category == atom_type) for atom_type in atom_map.keys()]
-
     # mot final
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -177,7 +176,7 @@ def category_to_atoms_polarity(category, polarity):
     elif category == "let":
         pass
     # le mot a une category atomique
-    elif True in res or category.startswith("dia") or category.startswith("box"):
+    elif True in res:
         category_to_polarity.append(not polarity)
     # sinon c'est une formule longue
     else:
@@ -201,13 +200,34 @@ def category_to_atoms_polarity(category, polarity):
 
             # for the right side
             category_to_polarity += category_to_atoms_polarity(right_side, polarity)
+
+        # p
+        elif category.startswith("p"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            left_side, right_side = category_cut[0], category_cut[1]
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
+            # for the right side
+            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
+
+        # box
+        elif category.startswith("box"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
+        # dia
+        elif category.startswith("dia"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
     return category_to_polarity
 
 
-def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
+def find_pos_neg_idexes(atoms_batch):
     r"""
     Args:
-        max_atoms_in_sentence : configuration
         atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns:
         (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
@@ -218,35 +238,42 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
         for category in sentence:
             for at in category_to_atoms_polarity(category, True):
                 list_atoms.append(at)
-        list_batch.append(torch.as_tensor(list_atoms))
-    return pad_sequence([list_batch[i] for i in range(len(list_batch))],
-                        max_len=max_atoms_in_sentence, padding_value=0)
+        list_batch.append(list_atoms)
+    return list_batch
 
 
 #########################################################################################
-################################ Prepare encoding ###############################################
+################################ GOAL ###############################################
 #########################################################################################
 
 
-def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        atoms_polarity_batch[s_idx][i])]
-    if len(pos_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(pos_encoding)
+def get_GOAL(max_atoms_in_sentence, categories_batch):
+    polarities = find_pos_neg_idexes(categories_batch)
+    atoms_batch = get_atoms_batch(categories_batch)
+    for s_idx in range(len(atoms_batch)):
+        for atom_type in list(atom_map.keys()):
+            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                         and atoms_batch[s_idx][i] == atom_type]
+            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                          and atoms_batch[s_idx][i] == atom_type]
+            while len(list_minus) != len(list_plus):
+                if len(list_minus) > len(list_plus):
+                    atoms_batch[s_idx].append(atom_type)
+                    polarities[s_idx].append(True)
+                else:
+                    atoms_batch[s_idx].append(atom_type)
+                    polarities[s_idx].append(False)
+                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch[s_idx][i] == atom_type]
+
+    return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                                     max_len=max_atoms_in_sentence, padding_value=0)
+
+
+#########################################################################################
+################################ Prepare encoding ###############################################
+#########################################################################################
 
 
-def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        not atoms_polarity_batch[s_idx][i])]
-    if len(neg_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(neg_encoding)
-- 
GitLab
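
Note (outside the patch): the core of this change is that `get_GOAL` now balances the number of positive and negative occurrences of each atom type per sentence before tokenization. The sketch below illustrates that padding idea in isolation; it is a minimal standalone example on plain Python lists, and `balance_atoms` is a hypothetical name, not a function from the repository.

```python
# Hypothetical helper mirroring the padding idea in get_GOAL: for every atom
# type, append dummy atoms of the missing polarity until the counts of
# positive and negative occurrences are equal.
def balance_atoms(atoms, polarities, atom_types):
    for atom_type in atom_types:
        n_pos = sum(1 for a, p in zip(atoms, polarities) if a == atom_type and p)
        n_neg = sum(1 for a, p in zip(atoms, polarities) if a == atom_type and not p)
        while n_pos != n_neg:
            atoms.append(atom_type)
            if n_pos < n_neg:
                polarities.append(True)
                n_pos += 1
            else:
                polarities.append(False)
                n_neg += 1
    return atoms, polarities


if __name__ == "__main__":
    atoms = ["np", "np", "s", "np"]
    polarities = [True, False, True, True]
    print(balance_atoms(atoms, polarities, ["np", "s"]))
    # (['np', 'np', 's', 'np', 'np', 's'], [True, False, True, True, False, False])
```

Balancing per type is what lets the linker treat the positive and negative atoms of one type as two sets of equal size when it later builds the permutation between them.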
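A second note, also outside the patch: the new `get_pos_encoding_for_s_idx` / `get_neg_encoding_for_s_idx` methods return a fixed-size block of `max_atoms_in_one_type // 2` rows per sentence, zero-padded, which is why the forward pass can combine them with a plain `torch.stack` instead of `pad_sequence`. The snippet below is a standalone sketch of that padding pattern under assumed sizes; `pad_to_fixed_rows` and the dimensions are illustrative, not the project's code.

```python
import torch


# Standalone sketch: pad a variable-length list of atom embeddings to a fixed
# number of rows so every sentence yields a tensor of the same shape.
def pad_to_fixed_rows(vectors, n_rows, dim, device="cpu"):
    if len(vectors) == 0:
        # no atom of this type in the sentence: return an all-zero block
        return torch.zeros(n_rows, dim, device=device)
    vectors = vectors + [torch.zeros(dim, device=device)
                         for _ in range(n_rows - len(vectors))]
    return torch.stack(vectors)


if __name__ == "__main__":
    dim, n_rows = 4, 6
    selected = [torch.rand(dim) for _ in range(3)]      # 3 selected atoms
    per_sentence = pad_to_fixed_rows(selected, n_rows, dim)
    print(per_sentence.shape)                           # torch.Size([6, 4])
    batch = torch.stack([per_sentence, pad_to_fixed_rows([], n_rows, dim)])
    print(batch.shape)                                  # torch.Size([2, 6, 4])
```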