diff --git a/Linker/Linker.py b/Linker/Linker.py
index d2a4d899238c155c3f5cb0bbf08a0ca4f58ff841..ad39aed0177bdde9e1bcec0f2b0c4ed0069fc0e5 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -115,7 +115,7 @@ class Linker(Module):
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
             sents_mask : mask from BERT tokenizer
         Returns:
-            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
         """
 
         # atoms embedding
@@ -152,7 +152,10 @@ class Linker(Module):
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
             link_weights.append(sinkhorn(weights, iters=3))
 
-        return torch.stack(link_weights)
+        total_link_weights = torch.stack(link_weights)
+        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
+
+        return F.log_softmax(link_weights_per_batch, dim=3)
 
     def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
@@ -208,7 +211,7 @@ class Linker(Module):
                 # Run the kinker on the categories predictions
                 logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
 
-                linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
+                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
                 epoch_loss += float(linker_loss)
                 linker_loss.backward()
@@ -290,7 +293,7 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=3))
 
         logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
         return axiom_links
 
     def eval_batch(self, batch, cross_entropy_loss):
@@ -303,8 +306,7 @@ class Linker(Module):
         logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
 
         logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+        axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
 
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
         loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
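
For reference, a minimal standalone sketch of the shape handling this patch introduces (not part of the patch itself; the dimension values are hypothetical and random weights stand in for the Sinkhorn-normalised scores):

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only
atom_vocab_size, batch_size, max_atoms_in_one_cat = 4, 2, 5

# One weight matrix per atom type, as accumulated in forward()
link_weights = [torch.randn(batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
                for _ in range(atom_vocab_size)]

# Stack and move the batch dimension first:
# (atom_vocab_size, batch_size, max, max) -> (batch_size, atom_vocab_size, max, max)
total_link_weights = torch.stack(link_weights)
link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)

# Log-probabilities over candidate negative atoms for each positive atom
log_probs = F.log_softmax(link_weights_per_batch, dim=3)

# Predicted axiom links: argmax over the last dimension, as now done directly in eval_batch()
axiom_links_pred = torch.argmax(log_probs, dim=3)
print(axiom_links_pred.shape)  # torch.Size([2, 4, 5])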