diff --git a/Linker/Linker.py b/Linker/Linker.py
index f3a4538d67c8f6f3a1300e9f80a793d9c4036e68..8429260558d0a192c6240a4e85f25da841df3f94 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -17,9 +17,9 @@ from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx
 from Supertagger import *
 from utils import pad_sequence
 
@@ -69,7 +69,7 @@ class Linker(Module):
         self.Supertagger = supertagger
 
         self.atom_map = atom_map
-        self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
+        self.sub_atoms_type_list = list(atom_map_redux.keys())
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
@@ -110,12 +110,15 @@ class Linker(Module):
         atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
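+        # Padded positions (indices) of the positive and negative atoms of each type,
+        # shape (batch_size, atom_vocab_size, max_atoms_in_one_type // 2), padded with -1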
+        pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
 
         # Construction tensor dataset
-        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
+        dataset = TensorDataset(atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
                                 sentences_mask)
 
         if validation_rate > 0.0:
@@ -136,11 +139,12 @@ class Linker(Module):
         decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
 
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+    def forward(self, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx, sents_embedding, sents_mask=None):
         r"""
         Args:
             atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+            batch_pos_idx : (batch_size, atom_vocab_size, max_atoms_in_one_type // 2) positions of the positive atoms for each atom type, padded with -1
+            batch_neg_idx : (batch_size, atom_vocab_size, max_atoms_in_one_type // 2) positions of the negative atoms for each atom type, padded with -1
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
             sents_mask : mask from BERT tokenizer
         Returns:
@@ -157,13 +161,8 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
-
-            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
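+            # Gather the contextual encodings of the positive / negative occurrences of this atom type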
+            pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
+            neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -246,10 +245,11 @@ class Linker(Module):
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_atoms = batch[0].to(self.device)
-                batch_polarity = batch[1].to(self.device)
-                batch_true_links = batch[2].to(self.device)
-                batch_sentences_tokens = batch[3].to(self.device)
-                batch_sentences_mask = batch[4].to(self.device)
+                batch_pos_idx = batch[1].to(self.device)
+                batch_neg_idx = batch[2].to(self.device)
+                batch_true_links = batch[3].to(self.device)
+                batch_sentences_tokens = batch[4].to(self.device)
+                batch_sentences_mask = batch[5].to(self.device)
 
                 self.optimizer.zero_grad()
 
@@ -257,7 +257,7 @@ class Linker(Module):
                 logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
                 # Run the linker on the categories predictions
-                logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
+                logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding, batch_sentences_mask)
 
                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
@@ -280,67 +280,24 @@ class Linker(Module):
 
         return avg_train_loss, avg_accuracy_train, training_time
 
-    def predict(self, categories, sents_embedding, sents_mask=None):
-        r"""Prediction from categories output by BERT and hidden_state from BERT
-
-        Args:
-            categories : (batch_size, len_sentence)
-            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-            sents_mask
-        Returns:
-            axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
-        """
-        self.eval()
-        with torch.no_grad():
-            # get atoms
-            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-
-            # atoms embedding
-            atoms_embedding = self.atoms_embedding(atoms_tokenized)
-
-            # MHA ou LSTM avec sortie de BERT
-            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
-                                                 self.make_decoder_mask(atoms_tokenized))
-
-            link_weights = []
-            for atom_type in self.sub_atoms_type_list:
-                pos_encoding = pad_sequence(
-                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-
-                neg_encoding = pad_sequence(
-                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-
-                pos_encoding = self.pos_transformation(pos_encoding)
-                neg_encoding = self.neg_transformation(neg_encoding)
-
-                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-                link_weights.append(sinkhorn(weights, iters=3))
-
-            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
-            return axiom_links
-
     def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
-        batch_polarity = batch[1].to(self.device)
-        batch_true_links = batch[2].to(self.device)
-        batch_sentences_tokens = batch[3].to(self.device)
-        batch_sentences_mask = batch[4].to(self.device)
+        batch_pos_idx = batch[1].to(self.device)
+        batch_neg_idx = batch[2].to(self.device)
+        batch_true_links = batch[3].to(self.device)
+        batch_sentences_tokens = batch[4].to(self.device)
+        batch_sentences_mask = batch[5].to(self.device)
 
         logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-        logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
+        logits_axiom_links_pred = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
 
         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
         print("Atoms dans la phrase : ", (batch_atoms[1][:50]))
-        print("Polarités des atoms de la phrase : ", batch_polarity[1][:50])
+        print("Indices des atoms positifs de la phrase : ", batch_pos_idx[1][:50])
+        print("Indices des atoms négatifs de la phrase : ", batch_neg_idx[1][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[1][2][:100])
         print('\n')
@@ -402,34 +359,16 @@ class Linker(Module):
         }, path)
         self.to(self.device)
 
-    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            atoms_polarity_batch[s_idx][i])]
-        if len(pos_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_pos_encoding = len(pos_encoding)
-            pos_encoding += [torch.zeros(self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
-                             range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
-            return torch.stack(pos_encoding)
-
-    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            not atoms_polarity_batch[s_idx][i])]
-        if len(neg_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_neg_encoding = len(neg_encoding)
-            neg_encoding += [torch.zeros(self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
-                             range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
-            return torch.stack(neg_encoding)
+    def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
+        """
+        :param bsd_tensor:
+            Tensor of shape (batch_size, sequence length, feature dimensionality).
+        :param positional_ids:
+            Tensor of shape (batch_size, atom_vocab_size, max_atoms_in_one_type // 2) giving, for each
+            sentence and each atom type, the positions of the atoms of that type, padded with -1.
+        :param atom_type:
+            The atom type whose encodings are gathered.
+        :return:
+            Tensor of shape (batch_size, max_atoms_in_one_type // 2, dim_embedding_atoms).
+        """
+        return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
+                                         if atom != -1 else torch.zeros(self.dim_embedding_atoms, device=self.device)
+                                         for atom in sentence])
+                            for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
\ No newline at end of file
diff --git a/Linker/atom_map.py b/Linker/atom_map.py
index d45c4b9709ee161960302e485f01a39e08c3fc76..4e0c45e4faed7171fb563685c85f172327dd4295 100644
--- a/Linker/atom_map.py
+++ b/Linker/atom_map.py
@@ -17,3 +17,13 @@ atom_map = \
      's_ppart': 15,
      '[PAD]': 16
      }
+
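+# Reduced atom map: one entry per base atom type, used to group the positional indices fed to the linker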
+atom_map_redux = {
+    'cl_r': 0,
+    'pp': 1,
+    'n': 2,
+    'np': 3,
+    'cl_y': 4,
+    'txt': 5,
+    's': 6
+}
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 0aa6dc25bfac924b64ce38d481787e81b93c980d..a5f0ff261ed94a2cb79908592051afc1e5c9ec27 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -3,7 +3,7 @@ import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from utils import pad_sequence
 
 
@@ -276,4 +276,19 @@ def get_GOAL(max_atoms_in_sentence, categories_batch):
 ################################ Prepare encoding ###############################################
 #########################################################################################
 
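+# get_pos_idx / get_neg_idx return, for every atom type in atom_map_redux, the positions of the positive
+# (resp. negative) atoms of each sentence, padded with -1 up to max_atoms_in_one_type // 2, as a tensor of
+# shape (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2).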
+def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence)
+                                              if bool(regex.match(atom_type + r"_?\w*", inverse_atom_map[int(x)]))
+                                              and atoms_polarity_batch[s_idx][i]], dtype=torch.long)
+                             for s_idx, sentence in enumerate(atoms_batch_tokenized)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
+               for atom_type in atom_map_redux.keys()]
 
+    return torch.stack(pos_idx).permute(1, 0, 2)
+
+
+def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence)
+                                              if bool(regex.match(atom_type + r"_?\w*", inverse_atom_map[int(x)]))
+                                              and not atoms_polarity_batch[s_idx][i]], dtype=torch.long)
+                             for s_idx, sentence in enumerate(atoms_batch_tokenized)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
+               for atom_type in atom_map_redux.keys()]
+
+    return torch.stack(neg_idx).permute(1, 0, 2)