
update linker padding

Merged Caroline de Pourtalès requested to merge version-linker-correction-padding into version-linker
2 files changed: +114 −59
File shown below: +59 −31
import os
import re
import sys
import datetime
import time
import torch
import torch.nn.functional as F
from torch.nn import Sequential, LayerNorm, Dropout
from torch.optim import AdamW
@@ -19,8 +19,7 @@ from Linker.MHA import AttentionDecoderLayer
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
-    get_neg_encoding_for_s_idx
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
from Supertagger import *
from utils import pad_sequence
@@ -108,11 +107,9 @@ class Linker(Module):
sentences_batch = df_axiom_links["X"].tolist()
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-atoms_batch = get_atoms_batch(df_axiom_links["Z"])
+atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
df_axiom_links["Y"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
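Note on the change above: atom extraction and polarity detection previously needed two separate passes over the category column, while get_GOAL (defined in Linker.utils_linker, internals not shown in this diff) returns both at once with the same arguments as the old find_pos_neg_idexes call. A hypothetical, simplified stand-in illustrating why a single traversal keeps the atom sequences and the polarity mask aligned by construction; the toy '-' prefix convention is not the project's real polarity rule, and the real helper presumably also handles the GOAL atom and padding to max_atoms_in_sentence:

# Hypothetical stand-in for get_GOAL, for illustration only.
def toy_get_goal(categories):
    atoms_batch, polarity_batch = [], []
    for sentence_categories in categories:
        atoms, polarities = [], []
        for cat in sentence_categories:
            # Toy convention: a leading '-' marks a negative atom (the real rule
            # comes from the position of the atom inside its category).
            polarities.append(not cat.startswith("-"))
            atoms.append(cat.lstrip("-"))
        atoms_batch.append(atoms)
        polarity_batch.append(polarities)
    # Both lists are built in the same loop, so index i of the polarity mask
    # always refers to index i of the atom sequence.
    return atoms_batch, polarity_batch

atoms, polarities = toy_get_goal([["np", "-np", "s"]])
assert atoms == [["np", "np", "s"]] and polarities == [[True, False, True]]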
@@ -160,17 +157,13 @@ class Linker(Module):
link_weights = []
for atom_type in self.sub_atoms_type_list:
-pos_encoding = pad_sequence(
-    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-     for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-    max_len=self.max_atoms_in_one_type // 2)
-neg_encoding = pad_sequence(
-    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
-                                atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
-     for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-    max_len=self.max_atoms_in_one_type // 2)
+pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                            atoms_polarity_batch, atom_type, s_idx)
+                            for s_idx in range(len(atoms_polarity_batch))])
+neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
+                                                            atoms_polarity_batch, atom_type, s_idx)
+                            for s_idx in range(len(atoms_polarity_batch))])
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
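Why torch.stack is enough here: the new per-sentence helpers added at the bottom of this diff always return a tensor of shape (max_atoms_in_one_type // 2, dim_embedding_atoms), zero-padded when a sentence has fewer matching atoms, so the old pad_sequence call with max_len is no longer needed to equalise lengths. A minimal shape sketch with toy sizes (not the project's values):

import torch

max_atoms_half, dim_embedding, batch_size = 4, 8, 3   # toy sizes

# Each helper call returns an already-padded, fixed-shape tensor per sentence ...
per_sentence_encodings = [torch.randn(max_atoms_half, dim_embedding) for _ in range(batch_size)]

# ... so stacking them directly gives the (batch, max_atoms_half, dim) block expected
# by pos_transformation / neg_transformation and the Sinkhorn matching.
pos_encoding = torch.stack(per_sentence_encodings)
assert pos_encoding.shape == (batch_size, max_atoms_half, dim_embedding)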
@@ -210,8 +203,9 @@ class Linker(Module):
print("")
print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
if validation_rate > 0.0:
-loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
if checkpoint:
@@ -236,7 +230,6 @@ class Linker(Module):
Args:
training_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
validation_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
Returns:
accuracy on validation set
loss on train set
@@ -300,12 +293,9 @@ class Linker(Module):
self.eval()
with torch.no_grad():
# get atoms
-atoms_batch = get_atoms_batch(categories)
+atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-# get polarities
-polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
# atoms embedding
atoms_embedding = self.atoms_embedding(atoms_tokenized)
@@ -316,14 +306,12 @@ class Linker(Module):
link_weights = []
for atom_type in self.sub_atoms_type_list:
pos_encoding = pad_sequence(
-[get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                            polarities, atom_type, self.inverse_map, s_idx)
+[self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
neg_encoding = pad_sequence(
-[get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                            polarities, atom_type, self.inverse_map, s_idx)
+[self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
for s_idx in range(len(polarities))], padding_value=0,
max_len=self.max_atoms_in_one_type // 2)
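In the prediction path the per-sentence encodings are still routed through the project's pad_sequence helper (imported from utils), which is harmless now that the helpers return fixed-shape tensors but no longer strictly required. A hedged sketch of what a pad_sequence with padding_value and max_len arguments typically does; the actual utils.pad_sequence may differ:

import torch

def pad_sequence_sketch(tensors, padding_value=0, max_len=None):
    # Toy stand-in: pad (or truncate) dimension 0 of every tensor to max_len,
    # then stack along a new batch dimension.
    max_len = max_len or max(t.shape[0] for t in tensors)
    padded = []
    for t in tensors:
        t = t[:max_len]
        if t.shape[0] < max_len:
            filler = t.new_full((max_len - t.shape[0], *t.shape[1:]), padding_value)
            t = torch.cat([t, filler], dim=0)
        padded.append(t)
    return torch.stack(padded)

# With the already fixed-shape outputs of the new helpers this reduces to a plain stack.
out = pad_sequence_sketch([torch.zeros(4, 8), torch.zeros(4, 8)], padding_value=0, max_len=4)
assert out.shape == (2, 4, 8)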
@@ -337,7 +325,7 @@ class Linker(Module):
axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
return axiom_links
-def eval_batch(self, batch, cross_entropy_loss):
+def eval_batch(self, batch):
batch_atoms = batch[0].to(self.device)
batch_polarity = batch[1].to(self.device)
batch_true_links = batch[2].to(self.device)
@@ -349,12 +337,20 @@ class Linker(Module):
batch_sentences_mask)
axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
+print('\n')
+print("Sentence tokens: ", batch_sentences_tokens[1])
+print("Atoms in the sentence: ", (batch_atoms[1][:50]))
+print("Polarities of the sentence's atoms: ", batch_polarity[1][:50])
+print("True links for category n: ", batch_true_links[1][2][:100])
+print("Predictions: ", axiom_links_pred[1][2][:100])
+print('\n')
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+loss = self.cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
return loss, accuracy
-def eval_epoch(self, dataloader, cross_entropy_loss):
+def eval_epoch(self, dataloader):
r"""Average the evaluation of all the batch.
Args:
@@ -365,7 +361,7 @@ class Linker(Module):
loss_average = 0
with torch.no_grad():
for step, batch in enumerate(dataloader):
-loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+loss, accuracy = self.eval_batch(batch)
accuracy_average += accuracy
loss_average += float(loss)
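The evaluation refactor running through this diff (eval_batch, eval_epoch and the call in the training loop) stops threading the loss function through as an argument and reads it from the module instead; the SinkhornLoss import above suggests self.cross_entropy_loss is created once in __init__, although the constructor is not part of this diff. A minimal sketch of the pattern with a stand-in loss:

import torch
from torch.nn import Module

class TinyLinker(Module):
    # Minimal stand-in for the refactor, not the real Linker.
    def __init__(self):
        super().__init__()
        # Assumed: the real __init__ stores its SinkhornLoss instance here.
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def eval_batch(self, batch):
        logits, targets = batch
        # The loss module is an attribute, so callers no longer pass it around.
        return self.cross_entropy_loss(logits, targets)

    def eval_epoch(self, dataloader):
        losses = [float(self.eval_batch(batch)) for batch in dataloader]
        return sum(losses) / len(losses)

fake_batches = [(torch.randn(4, 5), torch.randint(0, 5, (4,))) for _ in range(2)]
print(TinyLinker().eval_epoch(fake_batches))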
@@ -405,3 +401,35 @@ class Linker(Module):
'optimizer': self.optimizer,
}, path)
self.to(self.device)
+def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        bool(re.match(r"" + atom_type + "_?\w*",
+                                      self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                        atoms_polarity_batch[s_idx][i])]
+    if len(pos_encoding) == 0:
+        return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                           device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        len_pos_encoding = len(pos_encoding)
+        pos_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                     device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                         range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
+        return torch.stack(pos_encoding)
+
+def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
+    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        bool(re.match(r"" + atom_type + "_?\w*",
+                                      self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                        not atoms_polarity_batch[s_idx][i])]
+    if len(neg_encoding) == 0:
+        return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
+                           device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        len_neg_encoding = len(neg_encoding)
+        neg_encoding += [torch.zeros(self.dim_embedding_atoms,
+                                     device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
+                         range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
+        return torch.stack(neg_encoding)
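The two methods above differ only in the polarity test, so they could be merged into one parameterised helper. The sketch below is a hypothetical refactor, not part of this merge request: the same selection and zero-padding to max_atoms_in_one_type // 2 rows, written as a standalone function over a single sentence, with the atom_map lookup done by the caller, the regex precompiled (re.escape also guards atom types containing regex metacharacters), and the device taken from the input instead of querying torch.cuda.is_available() on every call.

import re
import torch

def padded_polarity_encoding(atoms_encoding_s, atom_ids_s, polarities_s,
                             atom_type, atom_type_id, inverse_map,
                             target_len, positive=True):
    # Hypothetical merge of get_pos/neg_encoding_for_s_idx for one sentence:
    # select embeddings whose tokenised atom matches atom_type and the requested
    # polarity, then zero-pad the selection to a fixed number of rows.
    pattern = re.compile(re.escape(atom_type) + r"_?\w*")
    dim = atoms_encoding_s.shape[-1]
    device = atoms_encoding_s.device
    selected = [x for i, x in enumerate(atoms_encoding_s)
                if (atom_type_id in atom_ids_s
                    and pattern.match(inverse_map[int(atom_ids_s[i])])
                    and bool(polarities_s[i]) == positive)]
    if not selected:
        return torch.zeros(target_len, dim, device=device)
    # Zero-pad to a fixed number of slots so every sentence stacks cleanly.
    selected += [torch.zeros(dim, device=device) for _ in range(target_len - len(selected))]
    return torch.stack(selected)

# Toy usage: 3 atoms with embedding dim 2; atoms 0 and 2 are 'np', only atom 0 is positive.
enc = torch.arange(6.0).reshape(3, 2)
out = padded_polarity_encoding(enc, torch.tensor([5, 7, 5]), torch.tensor([True, False, False]),
                               "np", 5, {5: "np", 7: "s"}, target_len=4, positive=True)
assert out.shape == (4, 2) and torch.equal(out[0], enc[0])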