Commit 39ab01fb authored by Caroline de Pourtalès

Merge branch 'version-linker-correction-padding' into 'version-linker'

update linker padding

See merge request !1
parents dbe58c65 3fd2c96b
Related merge requests: !6 Linker with transformer, !5 Linker with transformer, !1 update linker padding
Linker/Linker.py

 import os
+import re
 import sys
 import datetime
 import time
+import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.optim import AdamW
@@ -19,8 +19,7 @@ from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.atom_map import atom_map
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
-    get_neg_encoding_for_s_idx
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
 from Supertagger import *
 from utils import pad_sequence
@@ -108,11 +107,9 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

-        atoms_batch = get_atoms_batch(df_axiom_links["Z"])
+        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

-        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
-
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
@@ -160,17 +157,13 @@ class Linker(Module):
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = pad_sequence(
-                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
-
-            neg_encoding = pad_sequence(
-                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2)
+            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])
+
+            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                                        atoms_polarity_batch, atom_type, s_idx)
+                                        for s_idx in range(len(atoms_polarity_batch))])

             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
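
In the hunk above, the per-sentence positive and negative encodings are now padded to a fixed number of rows (max_atoms_in_one_type // 2) inside the get_pos_encoding_for_s_idx / get_neg_encoding_for_s_idx methods added at the bottom of this file, so every element of the list comprehension already has the same shape and a plain torch.stack replaces the earlier pad_sequence call. A minimal sketch of that shape contract, with made-up sizes (not repository code):

    import torch

    max_atoms_in_one_type = 10      # each helper pads to max_atoms_in_one_type // 2 rows
    dim_embedding_atoms = 8
    batch_size = 2

    # every per-sentence encoding already has shape (5, 8)
    per_sentence = [torch.randn(max_atoms_in_one_type // 2, dim_embedding_atoms)
                    for _ in range(batch_size)]
    pos_encoding = torch.stack(per_sentence)
    print(pos_encoding.shape)       # torch.Size([2, 5, 8]), no pad_sequence needed
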
@@ -210,8 +203,9 @@ class Linker(Module):
            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
            if validation_rate > 0.0:
-               loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+               loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
@@ -236,7 +230,6 @@ class Linker(Module):
         Args:
             training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
-            validation_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks

         Returns:
             accuracy on validation set
             loss on train set
@@ -300,12 +293,9 @@ class Linker(Module):
         self.eval()
         with torch.no_grad():
             # get atoms
-            atoms_batch = get_atoms_batch(categories)
+            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
             atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

-            # get polarities
-            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
-
             # atoms embedding
             atoms_embedding = self.atoms_embedding(atoms_tokenized)
@@ -316,14 +306,12 @@ class Linker(Module):
             link_weights = []
             for atom_type in self.sub_atoms_type_list:
                 pos_encoding = pad_sequence(
-                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
                 neg_encoding = pad_sequence(
-                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                                polarities, atom_type, self.inverse_map, s_idx)
+                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
                      for s_idx in range(len(polarities))], padding_value=0,
                     max_len=self.max_atoms_in_one_type // 2)
@@ -337,7 +325,7 @@ class Linker(Module):
             axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
             return axiom_links

-    def eval_batch(self, batch, cross_entropy_loss):
+    def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
         batch_polarity = batch[1].to(self.device)
         batch_true_links = batch[2].to(self.device)
@@ -349,12 +337,20 @@ class Linker(Module):
                                                      batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)

+        print('\n')
+        print("Tokens de la phrase : ", batch_sentences_tokens[1])
+        print("Atoms dans la phrase : ", (batch_atoms[1][:50]))
+        print("Polarités des atoms de la phrase : ", batch_polarity[1][:50])
+        print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
+        print("Les prédictions : ", axiom_links_pred[1][2][:100])
+        print('\n')
+
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+        loss = self.cross_entropy_loss(logits_axiom_links_pred, batch_true_links)

         return loss, accuracy

-    def eval_epoch(self, dataloader, cross_entropy_loss):
+    def eval_epoch(self, dataloader):
         r"""Average the evaluation of all the batch.

         Args:
@@ -365,7 +361,7 @@ class Linker(Module):
         loss_average = 0
         with torch.no_grad():
             for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                loss, accuracy = self.eval_batch(batch)
                 accuracy_average += accuracy
                 loss_average += float(loss)
@@ -405,3 +401,35 @@ class Linker(Module):
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
+
+    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(r"" + atom_type + "_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            atoms_polarity_batch[s_idx][i])]
+        if len(pos_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        else:
+            len_pos_encoding = len(pos_encoding)
+            pos_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                             range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
+            return torch.stack(pos_encoding)
+
+    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                            bool(re.match(r"" + atom_type + "_?\w*",
+                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                            not atoms_polarity_batch[s_idx][i])]
+        if len(neg_encoding) == 0:
+            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+        else:
+            len_neg_encoding = len(neg_encoding)
+            neg_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                             range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
+            return torch.stack(neg_encoding)
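
The two methods added above replace the old module-level helpers from Linker/utils_linker.py (removed further down): the selection logic is unchanged, but the result is now zero-padded on the spot to exactly max_atoms_in_one_type // 2 rows instead of staying variable-length and being padded later with pad_sequence. A stripped-down sketch of that select-then-pad step, using a hypothetical select_and_pad helper and illustrative sizes (not repository code):

    import torch

    def select_and_pad(encodings, keep_mask, target_len):
        # keep the rows flagged in keep_mask, then zero-pad up to target_len rows
        kept = [x for x, keep in zip(encodings, keep_mask) if keep]
        if len(kept) == 0:
            return torch.zeros(target_len, encodings.shape[1])
        kept += [torch.zeros(encodings.shape[1]) for _ in range(target_len - len(kept))]
        return torch.stack(kept)

    enc = torch.randn(6, 4)                           # 6 atom embeddings of dimension 4
    mask = [True, False, True, False, False, True]    # e.g. positive atoms of one type
    print(select_and_pad(enc, mask, target_len=5).shape)   # torch.Size([5, 4])
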
Linker/utils_linker.py

@@ -165,7 +165,6 @@ def category_to_atoms_polarity(category, polarity):
     """
     category_to_polarity = []
     res = [(category == atom_type) for atom_type in atom_map.keys()]
-
     # mot final
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
@@ -177,7 +176,7 @@ def category_to_atoms_polarity(category, polarity):
     elif category == "let":
         pass
     # le mot a une category atomique
-    elif True in res or category.startswith("dia") or category.startswith("box"):
+    elif True in res:
         category_to_polarity.append(not polarity)
     # sinon c'est une formule longue
     else:
@@ -201,13 +200,34 @@ def category_to_atoms_polarity(category, polarity):
             # for the right side
             category_to_polarity += category_to_atoms_polarity(right_side, polarity)

+        # p
+        elif category.startswith("p"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            left_side, right_side = category_cut[0], category_cut[1]
+
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
+
+            # for the right side
+            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
+
+        # box
+        elif category.startswith("box"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
+        # dia
+        elif category.startswith("dia"):
+            category_cut = regex.match(regex_categories, category).groups()
+            category_cut = [cat for cat in category_cut if cat is not None]
+            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
+
     return category_to_polarity


-def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
+def find_pos_neg_idexes(atoms_batch):
     r"""
     Args:
-        max_atoms_in_sentence : configuration
         atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
     Returns:
         (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order
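
The new branches above extend the polarity recursion to the p, box and dia connectives: box and dia are unary and pass the current polarity through to their argument, while p is binary and flips the polarity of its left component only. A toy re-implementation of just that rule on a simplified nested-tuple category representation (illustrative only; the repository parses category strings with regex_categories instead):

    def atoms_polarity(category, polarity=True):
        if isinstance(category, str):                  # atomic category
            return [not polarity]
        head, *args = category
        if head == "p":                                # binary: flip the left side, keep the right
            left, right = args
            return atoms_polarity(left, not polarity) + atoms_polarity(right, polarity)
        if head in ("box", "dia"):                     # unary: polarity passes through
            return atoms_polarity(args[0], polarity)
        raise ValueError("unhandled connective: " + head)

    print(atoms_polarity(("p", "np", "n")))            # [True, False]
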
@@ -218,35 +238,42 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
         for category in sentence:
             for at in category_to_atoms_polarity(category, True):
                 list_atoms.append(at)
-        list_batch.append(torch.as_tensor(list_atoms))
-    return pad_sequence([list_batch[i] for i in range(len(list_batch))],
-                        max_len=max_atoms_in_sentence, padding_value=0)
+        list_batch.append(list_atoms)
+    return list_batch


 #########################################################################################
-################################ Prepare encoding ###############################################
+################################ GOAL ###############################################
 #########################################################################################

+def get_GOAL(max_atoms_in_sentence, categories_batch):
+    polarities = find_pos_neg_idexes(categories_batch)
+    atoms_batch = get_atoms_batch(categories_batch)
+    for s_idx in range(len(atoms_batch)):
+        for atom_type in list(atom_map.keys()):
+            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                         and atoms_batch[s_idx][i] == atom_type]
+            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                          and atoms_batch[s_idx][i] == atom_type]
+            while len(list_minus) != len(list_plus):
+                if len(list_minus) > len(list_plus):
+                    atoms_batch[s_idx].append(atom_type)
+                    polarities[s_idx].append(True)
+                else:
+                    atoms_batch[s_idx].append(atom_type)
+                    polarities[s_idx].append(False)
+                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch[s_idx][i] == atom_type]
+    return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                                     max_len=max_atoms_in_sentence, padding_value=0)
+
+
+#########################################################################################
+################################ Prepare encoding ###############################################
+#########################################################################################

-def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        atoms_polarity_batch[s_idx][i])]
-    if len(pos_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(pos_encoding)
-
-
-def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
-                               atom_type, inverse_map, s_idx):
-    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                        bool(re.match(r'' + atom_type + '_?\w*', inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                        not atoms_polarity_batch[s_idx][i])]
-    if len(neg_encoding) == 0:
-        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    else:
-        return torch.stack(neg_encoding)
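
The new get_GOAL function carries the padding rework this merge request describes: for every atom type, a sentence's atom and polarity lists are extended until positive and negative occurrences of that type are equal in number, and the polarity lists are then padded to max_atoms_in_sentence as a boolean tensor. A stripped-down sketch of the per-type balancing step alone, with illustrative data and a hypothetical balance helper (not repository code):

    def balance(atoms, polarities, atom_types):
        for atom_type in atom_types:
            plus = sum(1 for a, p in zip(atoms, polarities) if a == atom_type and p)
            minus = sum(1 for a, p in zip(atoms, polarities) if a == atom_type and not p)
            while plus != minus:
                atoms.append(atom_type)
                polarities.append(minus > plus)   # append whichever polarity is missing
                if minus > plus:
                    plus += 1
                else:
                    minus += 1
        return atoms, polarities

    atoms = ["np", "np", "s"]
    polarities = [True, False, False]             # one unmatched negative "s"
    print(balance(atoms, polarities, ["np", "s"]))
    # (['np', 'np', 's', 's'], [True, False, False, True])
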