diff --git a/Configuration/config.ini b/Configuration/config.ini
index 7bdebd8e1c3fda56bac736b5b5b84665d2e3b21f..06c6f4b37289e03e384a0512ee1d90f327e71dcf 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -11,14 +11,14 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=512
+dim_cat_out=768
 dim_intermediate_FFN=256
 dim_pre_sinkhorn_transfo=32
 dropout=0.1
-sinkhorn_iters=3
+sinkhorn_iters=5
 
 [MODEL_TRAINING]
 batch_size=32
 epoch=25
 seed_val=42
-learning_rate=2e-4
+learning_rate=2e-3
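
Note on the combined effect of this config change: the initial learning rate goes up from 2e-4 to 2e-3, but the StepLR scheduler added to Linker.py below halves it every two epochs, so it falls back below the old 2e-4 by around epoch 8. A minimal sketch of that decay, using a throwaway parameter only so AdamW has something to optimize:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))             # throwaway parameter
optimizer = AdamW([param], lr=2e-3)                     # learning_rate from config.ini
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)   # same settings as Linker.__init__

for epoch in range(1, 9):
    # ... one training epoch would run here ...
    scheduler.step()
    print(f"epoch {epoch}: lr = {optimizer.param_groups[0]['lr']:.2e}")
# lr is halved every two epochs: 2e-3, 1e-3, 5e-4, 2.5e-4, 1.25e-4, ...
```
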
diff --git a/Linker/Linker.py b/Linker/Linker.py
index b1f7dbfe4a8fd437e0ea14952af9dd0da3dc6273..c1a97ba352e1f32a6c26658324079ed50cf30a27 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -9,6 +9,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
@@ -57,11 +58,11 @@ class Linker(Module):
         dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])
         self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
-        dropout = float(Configuration.modelTrainingConfig['dropout'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
@@ -70,6 +71,7 @@ class Linker(Module):
         self.Supertagger.model.to(self.device)
 
         self.atom_map = atom_map
+        self.atom_map_redux = atom_map_redux
         self.sub_atoms_type_list = list(atom_map_redux.keys())
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
@@ -93,6 +95,7 @@ class Linker(Module):
         self.cross_entropy_loss = SinkhornLoss()
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
+        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
 
         self.to(self.device)
 
@@ -166,7 +169,9 @@ class Linker(Module):
         atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
-        link_weights = []
+        batch_size, atom_vocab_size, _ = batch_pos_idx.shape
+        link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
+                                   self.max_atoms_in_one_type // 2, device=self.device)
         for atom_type in self.sub_atoms_type_list:
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
             neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
@@ -175,11 +180,9 @@ class Linker(Module):
             neg_encoding = self.neg_transformation(neg_encoding)
 
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
+            link_weights[self.atom_map_redux[atom_type]] = sinkhorn(weights, iters=self.sinkhorn_iters)
 
-        total_link_weights = torch.stack(link_weights)
-
-        return F.log_softmax(total_link_weights, dim=3)
+        return F.log_softmax(link_weights, dim=3)
 
     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, tensorboard=False):
@@ -278,7 +281,9 @@ class Linker(Module):
                 self.optimizer.step()
 
                 pred_axiom_links = torch.argmax(logits_predictions, dim=3)
-                accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)
+                accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
+
+        self.scheduler.step()
 
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
@@ -297,18 +302,18 @@ class Linker(Module):
 
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
-                                  output['last_hidden_state'])
-        axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+                                  output['last_hidden_state'])  # (atom_vocab_size, batch_size, max_atoms_in_one_type//2, max_atoms_in_one_type//2)
+        axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # (atom_vocab_size, batch_size, max_atoms_in_one_type//2)
 
         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
-        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][:50])
-        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][:50])
+        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][2][:50])
+        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][2][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
-        print("Les prédictions : ", axiom_links_pred[1][2][:100])
+        print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
 
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
         loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
 
         return loss, accuracy
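
The forward pass now fills a pre-allocated link_weights tensor indexed through atom_map_redux instead of appending per-type Sinkhorn outputs to a list and stacking them. A toy sketch of the resulting layout; the atom types and sizes below are made up, only the indexing pattern matches the diff:

```python
import torch

atom_map_redux = {"cl_r": 0, "pp": 1, "n": 2, "np": 3, "s": 4}  # stand-in for the project's reduced atom map
batch_size, max_atoms_in_one_type = 2, 10
half = max_atoms_in_one_type // 2

# One (batch, half, half) slice of permutation scores per reduced atom type.
link_weights = torch.zeros(len(atom_map_redux), batch_size, half, half)

for atom_type, slot in atom_map_redux.items():
    weights = torch.randn(batch_size, half, half)  # stand-in for sinkhorn(bmm(pos, neg^T))
    link_weights[slot] = weights

# Same tensor the old torch.stack produced, but the slice order is now fixed
# by atom_map_redux rather than by list-append order.
print(link_weights.shape)  # torch.Size([5, 2, 5, 5])
```
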
diff --git a/Linker/eval.py b/Linker/eval.py
index e713120ce61d3a43619559bd2eaadf867a958931..c60bb007b1b2388e2aa2df46d9f4d46fa775f1fa 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -8,21 +8,22 @@ class SinkhornLoss(Module):
         super(SinkhornLoss, self).__init__()
 
     def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
                    for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
-def mesure_accuracy(batch_true_links, axiom_links_pred):
+def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     r"""
     batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the predicted index of the negative atoms
     """
+    padding = max_atoms_in_one_type // 2 - 1  # index assigned by get_axiom_links to unlinked and padded positive atoms
     batch_true_links=batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
-    correct_links[batch_true_links == -1] = 1
+    correct_links[batch_true_links == padding] = 1
     num_correct_links = correct_links.sum().item()
-    num_masked_atoms = len(batch_true_links[batch_true_links == -1])
+    num_masked_atoms = len(batch_true_links[batch_true_links == padding])
 
     # divide by the number of links
     return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
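
eval.py drops the -1 convention: padded or unlinked positions are now marked with max_atoms_in_one_type // 2 - 1, the same value get_axiom_links uses as padding_value, and they are excluded from the accuracy denominator. A toy sanity check of the updated mesure_accuracy with made-up tensors (one atom type, one sentence, max_atoms_in_one_type = 10, so the padding index is 4):

```python
import torch

def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
    # Copy of the updated eval.py logic, reproduced here only for the toy check.
    padding = max_atoms_in_one_type // 2 - 1
    batch_true_links = batch_true_links.permute(1, 0, 2)
    correct_links = torch.ones(axiom_links_pred.size())
    correct_links[axiom_links_pred != batch_true_links] = 0
    correct_links[batch_true_links == padding] = 1
    num_correct_links = correct_links.sum().item()
    num_masked_atoms = len(batch_true_links[batch_true_links == padding])
    return (num_correct_links - num_masked_atoms) / (
        axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)

true = torch.tensor([[[0, 2, 1, 4, 4]]])  # (batch, atom_vocab, max//2); the two 4s are padding
pred = torch.tensor([[[0, 2, 3, 0, 1]]])  # (atom_vocab, batch, max//2); one real link predicted wrong
print(mesure_accuracy(true, pred, 10))    # 2 of 3 real links correct -> 0.666...
```
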
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index c34c6bfc26b7090721f980ac55d25e8d04fda046..955d5571c10518c09f795c6cf19f61f95a91f15e 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -51,9 +51,11 @@ def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity,
                             range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
-            [torch.as_tensor([l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 for i, x in
-                              enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, padding_value=-1)
+            [torch.as_tensor(
+                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
+                 for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
+             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
+            padding_value=max_atoms_in_one_type // 2 - 1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
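
get_axiom_links applies the same convention when building the gold links: a positive atom with no matching negative counterpart, and every padded slot up to max_atoms_in_one_type // 2, now receive the index max_atoms_in_one_type // 2 - 1 instead of -1 (pad_sequence here is the project's padding helper with a max_len argument, not torch.nn.utils.rnn.pad_sequence). A hand-rolled illustration of the index construction for one sentence, with invented atom lists:

```python
import torch

max_atoms_in_one_type = 10
padding_value = max_atoms_in_one_type // 2 - 1  # 4, the index mesure_accuracy now masks out

# Made-up positive / negative atoms of one type in a single sentence.
l_polarity_plus = ["n_0", "n_3", "n_7"]
l_polarity_minus = ["n_3", "n_0"]

# For each positive atom, the index of the negative atom it links to,
# or padding_value when no negative counterpart exists ("n_7" here).
links = [l_polarity_minus.index(x) if x in l_polarity_minus else padding_value
         for x in l_polarity_plus]

# Pad the row up to max_atoms_in_one_type // 2 with the same value.
row = torch.full((max_atoms_in_one_type // 2,), padding_value, dtype=torch.long)
row[:len(links)] = torch.as_tensor(links, dtype=torch.long)
print(row)  # tensor([1, 0, 4, 4, 4])
```
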