Commit b012fcf5 authored by Caroline DE POURTALES

update padding

parent 3296f9db
Part of 2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=512
+dim_cat_out=768
 dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
-sinkhorn_iters=3
+sinkhorn_iters=5
 
 [MODEL_TRAINING]
 batch_size=32
 epoch=25
 seed_val=42
-learning_rate=2e-4
+learning_rate=2e-3
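These INI values feed the Linker constructor below through a Configuration wrapper. A minimal sketch of the same lookup with the standard library; the wrapper's internals and the file path here are assumptions, not part of this commit:

# Minimal sketch of reading the settings above with configparser.
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")  # hypothetical filename

dim_cat_out = int(config["MODEL_LINKER"]["dim_cat_out"])          # 768 after this commit
sinkhorn_iters = int(config["MODEL_LINKER"]["sinkhorn_iters"])    # 5 after this commit
learning_rate = float(config["MODEL_TRAINING"]["learning_rate"])  # 2e-3 after this commit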
@@ -9,6 +9,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
@@ -57,11 +58,11 @@ class Linker(Module):
         dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])
         self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
-        dropout = float(Configuration.modelTrainingConfig['dropout'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
@@ -70,6 +71,7 @@ class Linker(Module):
         self.Supertagger.model.to(self.device)
 
         self.atom_map = atom_map
+        self.atom_map_redux = atom_map_redux
         self.sub_atoms_type_list = list(atom_map_redux.keys())
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
@@ -93,6 +95,7 @@ class Linker(Module):
         self.cross_entropy_loss = SinkhornLoss()
 
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
+        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
 
         self.to(self.device)
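The new StepLR(step_size=2, gamma=0.5) halves the learning rate every second call to scheduler.step(); since the training loop below steps it once per epoch, the 2e-3 starting rate drops to 1e-3 after two epochs, 5e-4 after four, and so on. A self-contained sketch of that decay (the dummy parameter is only illustrative):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))   # dummy parameter standing in for the Linker weights
optimizer = AdamW([param], lr=2e-3)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

for epoch in range(6):
    # ... one pass over the training batches with optimizer.step() would go here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())    # 2e-3, 1e-3, 1e-3, 5e-4, 5e-4, 2.5e-4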
@@ -166,7 +169,9 @@ class Linker(Module):
         atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
-        link_weights = []
+        batch_size, atom_vocan_size, _ = batch_pos_idx.shape
+        link_weights = torch.zeros(atom_vocan_size, batch_size, self.max_atoms_in_one_type // 2,
+                                   self.max_atoms_in_one_type // 2, device=self.device)
         for atom_type in self.sub_atoms_type_list:
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
             neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
@@ -175,11 +180,9 @@ class Linker(Module):
             neg_encoding = self.neg_transformation(neg_encoding)
 
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
-
-        total_link_weights = torch.stack(link_weights)
-
-        return F.log_softmax(total_link_weights, dim=3)
+            link_weights[self.atom_map_redux[atom_type]] = sinkhorn(weights, iters=self.sinkhorn_iters)
+
+        return F.log_softmax(link_weights, dim=3)
 
     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, tensorboard=False):
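The forward pass now pre-allocates one (atom_vocab, batch_size, max_atoms_in_one_type // 2, max_atoms_in_one_type // 2) score tensor and writes each atom type into the slot given by atom_map_redux, instead of appending per-type matrices to a list and stacking them. A toy equivalence check; the atom map and sizes below are made up, not the project's:

import torch

atom_map_redux = {"cl_r": 0, "pp": 1, "n": 2}           # hypothetical reduced atom map
batch_size, half = 4, 5                                 # half = max_atoms_in_one_type // 2

# Old style: collect per-type score matrices in a list, then stack.
per_type = [torch.randn(batch_size, half, half) for _ in atom_map_redux]
stacked = torch.stack(per_type)                         # (atom_vocab, batch, half, half)

# New style: pre-allocate and write each type into its fixed slot.
link_weights = torch.zeros(len(atom_map_redux), batch_size, half, half)
for atom_type, idx in atom_map_redux.items():
    link_weights[idx] = per_type[idx]

assert torch.equal(stacked, link_weights)
print(torch.log_softmax(link_weights, dim=3).shape)     # (3, 4, 5, 5), normalised over target atoms

One practical difference: an atom type that never occurs keeps an all-zero score block in the pre-allocated tensor instead of simply being absent from the stack.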
@@ -278,7 +281,9 @@ class Linker(Module):
             self.optimizer.step()
 
             pred_axiom_links = torch.argmax(logits_predictions, dim=3)
-            accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)
+            accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
+
+        self.scheduler.step()
 
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
@@ -297,18 +302,18 @@ class Linker(Module):
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
         logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
-                                  output['last_hidden_state'])
-        axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+                                  output['last_hidden_state'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
+        axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
-        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][:50])
-        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][:50])
+        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][2][:50])
+        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][2][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
-        print("Les prédictions : ", axiom_links_pred[1][2][:100])
+        print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
 
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
         loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
 
         return loss, accuracy
...
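As the new shape comments indicate, the logits come out as (atom_vocab, batch_size, max_atoms_in_one_type // 2, max_atoms_in_one_type // 2), so argmax over dim=3 gives, for every positive atom slot, the index of its predicted negative partner; that is also why the debug prints now index first by atom type and then by sentence. A tiny illustration with made-up scores:

import torch

# Toy: 1 atom type, 1 sentence, 3 positive and 3 negative atom slots.
log_scores = torch.log_softmax(torch.tensor(
    [[[[2.0, 0.1, 0.1],
       [0.1, 0.1, 2.0],
       [0.1, 2.0, 0.1]]]]), dim=3)        # shape (atom_vocab=1, batch=1, 3, 3)

axiom_links_pred = torch.argmax(log_scores, dim=3)
print(axiom_links_pred)                   # tensor([[[0, 2, 1]]]): +atom 0 -> -atom 0, +atom 1 -> -atom 2, ...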
@@ -8,21 +8,22 @@ class SinkhornLoss(Module):
         super(SinkhornLoss, self).__init__()
 
     def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
                    for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
-def mesure_accuracy(batch_true_links, axiom_links_pred):
+def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     r"""
     batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
+    padding = max_atoms_in_one_type // 2 - 1
     batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
-    correct_links[batch_true_links == -1] = 1
+    correct_links[batch_true_links == padding] = 1
     num_correct_links = correct_links.sum().item()
-    num_masked_atoms = len(batch_true_links[batch_true_links == -1])
+    num_masked_atoms = len(batch_true_links[batch_true_links == padding])
 
     # diviser par nombre de links
     return (num_correct_links - num_masked_atoms) / (axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
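With ignore_index=-1 gone from the loss, padded target links now carry the fixed value max_atoms_in_one_type // 2 - 1 and mesure_accuracy discounts them explicitly. A worked toy example of that bookkeeping (the numbers are assumptions, not repo data):

import torch

max_atoms_in_one_type = 10
padding = max_atoms_in_one_type // 2 - 1                 # 4: value used to pad the true links

# 1 atom type, 1 sentence, 5 link slots; the last two slots are padding.
true_links = torch.tensor([[[2, 0, 1, padding, padding]]])
pred_links = torch.tensor([[[2, 0, 3, 1, 0]]])

correct = torch.ones_like(pred_links, dtype=torch.float)
correct[pred_links != true_links] = 0
correct[true_links == padding] = 1                       # padded slots always count as correct first ...
num_masked = (true_links == padding).sum().item()
# ... and are then subtracted from both numerator and denominator.
accuracy = (correct.sum().item() - num_masked) / (true_links.numel() - num_masked)
print(accuracy)                                          # 2 of the 3 real links match -> 0.666...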
@@ -51,9 +51,11 @@ def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity,
                      range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
-            [torch.as_tensor([l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 for i, x in
-                              enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, padding_value=-1)
+            [torch.as_tensor(
+                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
+                 for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
+             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
+            padding_value=max_atoms_in_one_type // 2 - 1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
...
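The pad_sequence called here accepts max_len and padding_value, which torch.nn.utils.rnn.pad_sequence does not, so it is presumably a project-local helper that is outside this diff. A hypothetical stand-in with that signature, showing how unmatched and padded positions now both map to max_atoms_in_one_type // 2 - 1 instead of -1:

import torch

def pad_sequence(sequences, max_len, padding_value=0):
    # Assumed re-implementation: signature inferred from the call above, not taken from the repo.
    out = torch.full((len(sequences), max_len), padding_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = min(len(seq), max_len)
        out[i, :length] = seq[:length]
    return out

# With max_atoms_in_one_type = 10, padded slots share the fixed index 4 (= 10 // 2 - 1).
links = pad_sequence([torch.as_tensor([2, 0, 1])], max_len=5, padding_value=10 // 2 - 1)
print(links)   # tensor([[2, 0, 1, 4, 4]])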