diff --git a/Linker/Linker.py b/Linker/Linker.py
index d2a4d899238c155c3f5cb0bbf08a0ca4f58ff841..ad39aed0177bdde9e1bcec0f2b0c4ed0069fc0e5 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -115,7 +115,7 @@ class Linker(Module):
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
             sents_mask : mask from BERT tokenizer
         Returns:
-            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
         """
 
         # atoms embedding
@@ -152,7 +152,10 @@ class Linker(Module):
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
             link_weights.append(sinkhorn(weights, iters=3))
 
-        return torch.stack(link_weights)
+        total_link_weights = torch.stack(link_weights)
+        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
+
+        return F.log_softmax(link_weights_per_batch, dim=3)
 
     def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
@@ -208,7 +211,7 @@ class Linker(Module):
                 # Run the kinker on the categories predictions
                 logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
 
-                linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
+                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
                 epoch_loss += float(linker_loss)
                 linker_loss.backward()
@@ -290,7 +293,7 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=3))
 
         logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
         return axiom_links
 
     def eval_batch(self, batch, cross_entropy_loss):
@@ -303,8 +306,7 @@ class Linker(Module):
         logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
 
         logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+        axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
 
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
         loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
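
For reference, a minimal standalone sketch of the shape handling this patch introduces (not part of the patch itself; the dimension values are hypothetical and random weights stand in for the Sinkhorn-normalised scores):

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only
atom_vocab_size, batch_size, max_atoms_in_one_cat = 4, 2, 5

# One weight matrix per atom type, as accumulated in forward()
link_weights = [torch.randn(batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
                for _ in range(atom_vocab_size)]

# Stack and move the batch dimension first:
# (atom_vocab_size, batch_size, max, max) -> (batch_size, atom_vocab_size, max, max)
total_link_weights = torch.stack(link_weights)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)

# Log-probabilities over candidate negative atoms for each positive atom
log_probs = F.log_softmax(link_weights_per_batch, dim=3)

# Predicted axiom links: argmax over the last dimension, as now done directly in eval_batch()
axiom_links_pred = torch.argmax(log_probs, dim=3)
print(axiom_links_pred.shape)  # torch.Size([2, 4, 5])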