Skip to content
Snippets Groups Projects
Commit e115a712 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update linker encoding

parent 1f43915a
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -115,7 +115,7 @@ class Linker(Module): ...@@ -115,7 +115,7 @@ class Linker(Module):
sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
sents_mask : mask from BERT tokenizer sents_mask : mask from BERT tokenizer
Returns: Returns:
link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
""" """
# atoms embedding # atoms embedding
...@@ -152,7 +152,10 @@ class Linker(Module): ...@@ -152,7 +152,10 @@ class Linker(Module):
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
link_weights.append(sinkhorn(weights, iters=3)) link_weights.append(sinkhorn(weights, iters=3))
return torch.stack(link_weights) total_link_weights = torch.stack(link_weights)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
return F.log_softmax(link_weights_per_batch, dim=3)
def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20, def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, validate=True): batch_size=32, checkpoint=True, validate=True):
...@@ -208,7 +211,7 @@ class Linker(Module): ...@@ -208,7 +211,7 @@ class Linker(Module):
# Run the kinker on the categories predictions # Run the kinker on the categories predictions
logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask) logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links) linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
# Perform a backward pass to calculate the gradients. # Perform a backward pass to calculate the gradients.
epoch_loss += float(linker_loss) epoch_loss += float(linker_loss)
linker_loss.backward() linker_loss.backward()
...@@ -290,7 +293,7 @@ class Linker(Module): ...@@ -290,7 +293,7 @@ class Linker(Module):
link_weights.append(sinkhorn(weights, iters=3)) link_weights.append(sinkhorn(weights, iters=3))
logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3) logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3) axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
return axiom_links return axiom_links
def eval_batch(self, batch, cross_entropy_loss): def eval_batch(self, batch, cross_entropy_loss):
...@@ -303,8 +306,7 @@ class Linker(Module): ...@@ -303,8 +306,7 @@ class Linker(Module):
logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask) logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding, logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
batch_sentences_mask) batch_sentences_mask)
logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3) axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred) accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links) loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment