Skip to content
Snippets Groups Projects
Commit e115a712 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update linker encoding

parent 1f43915a
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -115,7 +115,7 @@ class Linker(Module): ...@@ -115,7 +115,7 @@ class Linker(Module):
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
sents_mask : mask from BERT tokenizer sents_mask : mask from BERT tokenizer
Returns: Returns:
link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
""" """
# atoms embedding # atoms embedding
...@@ -152,7 +152,10 @@ class Linker(Module): ...@@ -152,7 +152,10 @@ class Linker(Module):
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_weights.append(sinkhorn(weights, iters=3)) link_weights.append(sinkhorn(weights, iters=3))
return torch.stack(link_weights) total_link_weights = torch.stack(link_weights)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
return F.log_softmax(link_weights_per_batch, dim=3)
def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20, def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, validate=True): batch_size=32, checkpoint=True, validate=True):
...@@ -208,7 +211,7 @@ class Linker(Module): ...@@ -208,7 +211,7 @@ class Linker(Module):
# Run the kinker on the categories predictions # Run the kinker on the categories predictions
logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask) logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients. # Perform a backward pass to calculate the gradients.
epoch_loss += float(linker_loss) epoch_loss += float(linker_loss)
linker_loss.backward() linker_loss.backward()
...@@ -290,7 +293,7 @@ class Linker(Module): ...@@ -290,7 +293,7 @@ class Linker(Module):
link_weights.append(sinkhorn(weights, iters=3)) link_weights.append(sinkhorn(weights, iters=3))
logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3) logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3) axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
return axiom_links return axiom_links
def eval_batch(self, batch, cross_entropy_loss): def eval_batch(self, batch, cross_entropy_loss):
...@@ -303,8 +306,7 @@ class Linker(Module): ...@@ -303,8 +306,7 @@ class Linker(Module):
logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask) logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding, logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
batch_sentences_mask) batch_sentences_mask)
logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3) axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred) accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links) loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment