Commit e115a712 authored by Caroline DE POURTALES

update linker encoding

parent 1f43915a
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -115,7 +115,7 @@ class Linker(Module):
            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
            sents_mask : mask from BERT tokenizer
        Returns:
-           link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+           link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
        """
        # atoms embedding
@@ -152,7 +152,10 @@ class Linker(Module):
            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
            link_weights.append(sinkhorn(weights, iters=3))
-       return torch.stack(link_weights)
+       total_link_weights = torch.stack(link_weights)
+       link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
+       return F.log_softmax(link_weights_per_batch, dim=3)
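
The new return path is easiest to read as a shape trace. Below is a minimal sketch with random tensors; the dimension sizes are illustrative rather than the repository's actual values, and raw scores stand in for the Sinkhorn-normalised matrices, since only the stack/permute/log_softmax tail changes here.

import torch
import torch.nn.functional as F

atom_vocab_size, batch_size, max_atoms_in_one_cat = 4, 2, 5

# One (batch_size, max_atoms, max_atoms) weight matrix per atom type,
# standing in for the sinkhorn(...) outputs accumulated in the loop.
link_weights = [torch.randn(batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
                for _ in range(atom_vocab_size)]

total_link_weights = torch.stack(link_weights)                   # (atom_vocab, batch, atoms, atoms)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)  # (batch, atom_vocab, atoms, atoms)
log_probs = F.log_softmax(link_weights_per_batch, dim=3)         # normalise over candidate negative atoms

# Every positive atom now carries a log-distribution over negative atoms.
assert torch.allclose(log_probs.exp().sum(dim=3),
                      torch.ones(batch_size, atom_vocab_size, max_atoms_in_one_cat))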
def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, validate=True):
@@ -208,7 +211,7 @@ class Linker(Module):
                # Run the linker on the categories predictions
                logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
-               linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
+               linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                # Perform a backward pass to calculate the gradients.
                epoch_loss += float(linker_loss)
                linker_loss.backward()
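
The permute disappears from the loss call because the forward pass now hands back batch-first log-probabilities itself. A hypothetical stand-in for cross_entropy_loss is sketched below, assuming it reduces to a negative log-likelihood over the last axis; the helper name linker_nll and its shape contract are assumptions for illustration, not the repository's actual loss.

import torch
from torch.nn.functional import nll_loss

def linker_nll(log_probs, true_links):
    # log_probs: (batch, atom_vocab, atoms, atoms); true_links: (batch, atom_vocab, atoms)
    return nll_loss(log_probs.flatten(0, 2), true_links.flatten(), reduction='mean')

batch_size, atom_vocab_size, max_atoms = 2, 4, 5
log_probs = torch.log_softmax(torch.randn(batch_size, atom_vocab_size, max_atoms, max_atoms), dim=3)
true_links = torch.randint(max_atoms, (batch_size, atom_vocab_size, max_atoms))
loss = linker_nll(log_probs, true_links)   # accepts the forward output directly, no permute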
@@ -290,7 +293,7 @@ class Linker(Module):
            link_weights.append(sinkhorn(weights, iters=3))
        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-       axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+       axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
        return axiom_links
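
Note that softmax and log_softmax are both strictly monotonic, so swapping one for the other under argmax leaves the predicted links unchanged; the switch only keeps prediction consistent with the log-probability output of the forward pass. A quick check:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 5, 5)   # arbitrary link scores
assert torch.equal(torch.argmax(F.softmax(x, dim=3), dim=3),
                   torch.argmax(F.log_softmax(x, dim=3), dim=3))
# The normalisation could be dropped entirely for prediction:
assert torch.equal(torch.argmax(x, dim=3),
                   torch.argmax(F.log_softmax(x, dim=3), dim=3))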
def eval_batch(self, batch, cross_entropy_loss):
@@ -303,8 +306,7 @@ class Linker(Module):
        logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
        logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
                                       batch_sentences_mask)
-       logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-       axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+       axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
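
For reference, the accuracy computed here can be read as an exact-match count over predicted link indices. The sketch below is an illustrative stand-in for mesure_accuracy, whose real definition lives elsewhere in the repository; the helper name link_accuracy is hypothetical, and padding handling is omitted.

import torch

def link_accuracy(true_links, pred_links):
    # Fraction of positive atoms whose predicted negative atom matches the gold link.
    return (true_links == pred_links).sum().item() / true_links.numel()

true_links = torch.randint(5, (2, 4, 5))
pred_links = true_links.clone()
pred_links[0, 0, 0] = (pred_links[0, 0, 0] + 1) % 5   # introduce a single error
print(link_accuracy(true_links, pred_links))           # 0.975 (39 of 40 links correct)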