diff --git a/Configuration/config.ini b/Configuration/config.ini
index c79def55882180c36facd03f2d0d9593501a13c6..d6c860553e48b27f2c16d5d1c2228c86e51b7fc4 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,15 +12,15 @@ max_atoms_in_one_type=250
 dim_encoder = 768
 
 [MODEL_DECODER]
-dim_decoder = 8
+dim_decoder = 16
 num_rnn_layers=1
 dropout=0.1
 teacher_forcing=0.05
 
 [MODEL_LINKER]
-nhead=1
+nhead=4
 dim_feedforward=246
-dim_embedding_atoms=8
+dim_embedding_atoms=16
 dim_polarity_transfo=128
 layer_norm_eps=1e-5
 dropout=0.1
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 12f9534edfc8fe5a35fbad8b62dbc2fe7ee44d67..ce256406be8271ec9a951af3f08f7cd7995c7a5a 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -17,7 +17,8 @@ from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.atom_map import atom_map
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
+    get_neg_encoding_for_s_idx
 from Linker.eval import mesure_accuracy, SinkhornLoss
 from utils import pad_sequence
 
@@ -130,23 +131,18 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          not atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
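+            # pad the per-sentence encodings of this atom type, split by polarity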
+            pos_encoding = pad_sequence(
+                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
+                                            atoms_polarity_batch, atom_type, s_idx)
+                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+
+            neg_encoding = pad_sequence(
+                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
+                                            atoms_polarity_batch, atom_type, s_idx)
+                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -271,23 +266,18 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
-                                                          atoms_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
-                                         for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
-                                                          atoms_tokenized[s_idx][i] == self.atom_map[
-                                                              atom_type] and
-                                                          not polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
-                                         for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
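+            # pad the per-sentence encodings of this atom type, split by polarity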
+            pos_encoding = pad_sequence(
+                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                            polarities, atom_type, s_idx)
+                 for s_idx in range(len(polarities))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+
+            neg_encoding = pad_sequence(
+                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                            polarities, atom_type, s_idx)
+                 for s_idx in range(len(polarities))], padding_value=0,
+                max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
diff --git a/Linker/__init__.py b/Linker/__init__.py
index c0df5b8d2f6b10dc52709b2bd7b132eb1c1c2066..c2a9483d03b868e0c2b00cae7a54f7bb7b7bd4db 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1 +1,3 @@
-from .Linker import Linker
\ No newline at end of file
+from .Linker import Linker
+from .atom_map import atom_map
+from .AtomEmbedding import AtomEmbedding
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 13c63f47346f9aac35ac230734994cfb227d036b..0821f6196c55a3c0961ecc89c23aaacde8f53140 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -27,7 +27,7 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
 #########################################################################################
-################################ Liste des atoms avc _i########################################
+################################ List of atoms with _i ########################################
 #########################################################################################
 
 
@@ -72,7 +72,8 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
-        return [cat]
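+        # recurse on the category behind "GOAL:" instead of returning it verbatim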
+        return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
         return [category]
     else:
@@ -103,7 +103,6 @@ def get_atoms_links_batch(category_batch):
 ################################ Liste des atoms ########################################
 #########################################################################################
 
-
 def category_to_atoms(category, categories_to_atoms):
     r"""
     Args:
@@ -115,8 +114,8 @@ def category_to_atoms(category, categories_to_atoms):
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
-        category = re.match(r'([a-zA-Z|_]+)_\d+', cat).group(1)
-        return [category]
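+        # recurse so complex GOAL categories are decomposed like any other category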
+        return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
         category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
         return [category]
@@ -158,78 +156,42 @@ def category_to_atoms_polarity(category, polarity):
     """
     category_to_polarity = []
     res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
+
+    # final word (GOAL category)
     if category.startswith("GOAL:"):
-        category_to_polarity.append(True)
+        word, cat = category.split(':')
+        res = [bool(re.match(r'' + atom_type + "_\d+", cat)) for atom_type in atom_map.keys()]
+        if True in res:
+            category_to_polarity.append(True)
+        else:
+            category_to_polarity += category_to_atoms_polarity(cat, True)
+
+    # the word has an atomic category
     elif True in res or category.startswith("dia") or category.startswith("box"):
-        category_to_polarity.append(False)
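+        # an atomic (or dia/box-prefixed) category takes the opposite of the incoming polarity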
+        category_to_polarity.append(not polarity)
+
+    # otherwise it is a complex formula
     else:
         # dr = /
         if category.startswith("dr"):
             category_cut = regex.match(regex_categories, category).groups()
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
-
-            if polarity == True:
-                # for the left side : normal
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, True)
-                # for the right side : change polarity for next right formula
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, False)
-
-            else:
-                # for the left side
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, False)
-                # for the right side : change polarity for next right formula
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, True)
+            # for the left side
+            category_to_polarity += category_to_atoms_polarity(left_side, polarity)
+            # for the right side: the polarity flips
+            category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
 
         # dl = \
         elif category.startswith("dl"):
             category_cut = regex.match(regex_categories, category).groups()
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
-
-            if polarity == True:
-                # for the left side : change polarity
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, False)
-                # for the right side : normal
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, True)
-
-            else:
-                # for the left side
-                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                    category_to_polarity.append(False)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(left_side, True)
-                # for the right side
-                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                    category_to_polarity.append(True)
-                else:
-                    category_to_polarity += category_to_atoms_polarity(right_side, False)
+            # for the left side: the polarity flips
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
+            # for the right side
+            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
 
     return category_to_polarity
 
@@ -251,3 +212,36 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
         list_batch.append(torch.as_tensor(list_atoms))
     return pad_sequence([list_batch[i] for i in range(len(list_batch))],
                         max_len=max_atoms_in_sentence, padding_value=0)
+
+
+#########################################################################################
+################################ Prepare encoding ###############################################
+#########################################################################################
+
+
+def get_pos_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
+                               atom_type, s_idx):
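+    # gather the encodings of the positively-polarised atoms of atom_type in
+    # sentence s_idx; fall back to a single zero vector when there are none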
+    pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        atoms_polarity_batch[s_idx][i])]
+    if len(pos_encoding) == 0:
+        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        return torch.stack(pos_encoding)
+
+
+def get_neg_encoding_for_s_idx(dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch,
+                               atom_type, s_idx):
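+    # same as get_pos_encoding_for_s_idx but for negatively-polarised atoms;
+    # returns a single zero vector when the sentence has no such atom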
+    neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
+                    if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                        atoms_batch_tokenized[s_idx][i] == atom_map[atom_type] and
+                        not atoms_polarity_batch[s_idx][i])]
+    if len(neg_encoding) == 0:
+        return torch.zeros(1, dim_embedding_atoms, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    else:
+        return torch.stack(neg_encoding)
diff --git a/train.py b/train.py
index b2e73259a8ad242fb2103e22dbedf531d9e9c1d2..2c502001a16b1dbb13f19082f4c39ecc4dfbb4c5 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,7 @@ from utils import read_csv_pgbar
 
 torch.cuda.empty_cache()
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 10
+nb_sentences = batch_size * 200
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 
 file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
@@ -15,8 +15,6 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()
 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")
-
-
 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
 print("Linker")