diff --git a/Configuration/config.ini b/Configuration/config.ini
index 06c6f4b37289e03e384a0512ee1d90f327e71dcf..c3ccbc2d4759512eb96cb3ae2db0bef7a6d5a939 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -3,17 +3,22 @@ transformers = 4.16.2
 
 [DATASET_PARAMS]
 symbols_vocab_size=26
+atom_vocab_size=18
 max_len_sentence=290
-max_atoms_in_sentence=1250
-max_atoms_in_one_type=510
+max_atoms_in_sentence=874
+max_atoms_in_one_type=324
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-dim_cat_out=768
-dim_intermediate_FFN=256
-dim_pre_sinkhorn_transfo=32
+nhead=4
+dim_emb_atom = 256
+num_layers=2
+dim_cat_inter=512
+dim_cat_out=256
+dim_intermediate_FFN=128
+dim_pre_sinkhorn_transfo=64
 dropout=0.1
 sinkhorn_iters=5
 
@@ -21,4 +26,4 @@ sinkhorn_iters=5
 batch_size=32
 epoch=25
 seed_val=42
-learning_rate=2e-3
+learning_rate=2e-3
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index c1a97ba352e1f32a6c26658324079ed50cf30a27..ee8842514e241cfe0ec47b9cf68730cebac92050 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,3 +1,4 @@
+import math
 import os
 import re
 import sys
@@ -7,7 +8,8 @@ import time
 
 import torch
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout
+from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \
+    Embedding
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import TensorDataset, random_split
@@ -15,15 +17,15 @@ from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
-from Linker.AtomTokenizer import AtomTokenizer
-from Linker.PositionEncoding import PositionalEncoding
+from Linker.PositionalEncoding import PositionalEncoding
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_num_atoms_batch
 from Supertagger import SuperTagger
 from utils import pad_sequence
 
 
 def format_time(elapsed):
     '''
@@ -49,49 +53,73 @@ def output_create_dir():
     return training_dir, writer
 
 
+def generate_square_subsequent_mask(sz):
+    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
+    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
+
+
 class Linker(Module):
     def __init__(self, supertagger_path_model):
         super(Linker, self).__init__()
 
+        # region parameters
         dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        # atom settings
+        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        # Transformer
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
+        self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
+        # torch cat
+        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
         self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
-        dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
+        dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
+        # sinkhorn
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
-        dropout = float(Configuration.modelLinkerConfig['dropout'])
+        # settings
+        self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
         self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # endregion
 
+        # Supertagger for categories
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
         self.Supertagger.model.to(self.device)
 
-        self.atom_map = atom_map
+        # Atoms embedding
+        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.atom_map_redux = atom_map_redux
         self.sub_atoms_type_list = list(atom_map_redux.keys())
-        self.padding_id = self.atom_map['[PAD]']
-        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-
-        self.position_encoding = PositionalEncoding(dim_encoder, max_len=self.max_atoms_in_sentence)
-
-        dim_cat = dim_encoder * 2
-        self.linker_encoder = Linear(dim_cat, self.dim_cat_out, bias=False)
-        self.dropout = Dropout(dropout)
+        self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=atom_map["[PAD]"])
+        self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
+        self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
+        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead)
+        self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
+
+        # Concatenation with word embedding
+        dim_cat = dim_encoder + self.dim_emb_atom
+        self.linker_encoder = Sequential(
+            FFN(dim_cat, self.dim_cat_inter, 0.1, d_out=self.dim_cat_out),
+            LayerNorm(self.dim_cat_out, eps=1e-8)
+        )
 
+        # Division into positive and negative
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
+            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
         )
         self.neg_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
+            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
         )
 
+        # Learning
         self.cross_entropy_loss = SinkhornLoss()
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
@@ -113,20 +141,21 @@ class Linker(Module):
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
         atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
+            [sentence.split(" ") for sentence in atoms_batch])
 
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
 
-        pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
-        neg_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
 
-        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
+        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
 
         # Construction tensor dataset
-        dataset = TensorDataset(num_atoms_per_word, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
-                                sentences_mask)
+        dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
+                                sentences_tokens, sentences_mask)
 
         if validation_rate > 0.0:
             train_size = int(0.9 * len(dataset))
@@ -141,38 +170,38 @@ class Linker(Module):
         print("End preprocess Data")
         return training_dataloader, validation_dataloader
 
-    def forward(self, batch_num_atoms_per_word, batch_pos_idx, batch_neg_idx, sents_embedding, cat_embedding):
+    def forward(self, batch_num_atoms_per_word, batch_atoms, src_mask, batch_pos_idx, batch_neg_idx, sents_embedding):
         r"""
         Args:
             batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
+            batch_atoms : (batch_size, max_atoms_in_sentence) tokenized atoms of the sentences
+            src_mask : (max_atoms_in_sentence, max_atoms_in_sentence) attention mask for the atom transformer
             batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-            cat_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for cat embedding
         Returns:
             link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
         """
-        # repeat embedding word for each atom in word
+        # repeat each word embedding once per atom in the word (the +1 accounts for the [SEP] token)
         sents_embedding_repeat = pad_sequence(
             [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
              for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
-        cat_embedding_repeat = pad_sequence(
-            [torch.repeat_interleave(input=cat_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
-             for i in range(len(cat_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
 
-        # positional encoding of atoms and cat embedding to form the atom embedding
-        position_encoding = self.position_encoding(cat_embedding_repeat)
+        atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
+        atoms_embedding = self.position_encoder(atoms_embedding)
+        atoms_embedding = atoms_embedding.permute(1, 0, 2)
+        atoms_embedding = self.transformer(atoms_embedding, src_mask)
+        atoms_embedding = atoms_embedding.permute(1, 0, 2)
 
         # cat
-        atoms_sentences_encoding = torch.cat([sents_embedding_repeat, position_encoding], dim=2)
+        atoms_sentences_encoding = torch.cat([sents_embedding_repeat, atoms_embedding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
-        atoms_encoding = self.dropout(atoms_encoding)
 
         # linking per atom type
         batch_size, atom_vocan_size, _ = batch_pos_idx.shape
         link_weights = torch.zeros(atom_vocan_size, batch_size, self.max_atoms_in_one_type // 2,
                                    self.max_atoms_in_one_type // 2, device=self.device)
-        for atom_type in self.sub_atoms_type_list:
+        for atom_type in list(atom_map_redux.keys()):
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
             neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
 
@@ -251,23 +280,25 @@ class Linker(Module):
 
         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
+            src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_num_atoms = batch[0].to(self.device)
-                batch_pos_idx = batch[1].to(self.device)
-                batch_neg_idx = batch[2].to(self.device)
-                batch_true_links = batch[3].to(self.device)
-                batch_sentences_tokens = batch[4].to(self.device)
-                batch_sentences_mask = batch[5].to(self.device)
+                batch_atoms_tok = batch[1].to(self.device)
+                batch_pos_idx = batch[2].to(self.device)
+                batch_neg_idx = batch[3].to(self.device)
+                batch_true_links = batch[4].to(self.device)
+                batch_sentences_tokens = batch[5].to(self.device)
+                batch_sentences_mask = batch[6].to(self.device)
 
                 self.optimizer.zero_grad()
 
                 # get sentence embedding from BERT which is already trained
                 output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
-                # Run the kinker on the categories predictions
-                logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
-                                          output['last_hidden_state'])
+                # Run the Linker on the atoms
+                logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
+                                          output['word_embeding'])
 
                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
@@ -294,21 +325,22 @@ class Linker(Module):
 
     def eval_batch(self, batch):
         batch_num_atoms = batch[0].to(self.device)
-        batch_pos_idx = batch[1].to(self.device)
-        batch_neg_idx = batch[2].to(self.device)
-        batch_true_links = batch[3].to(self.device)
-        batch_sentences_tokens = batch[4].to(self.device)
-        batch_sentences_mask = batch[5].to(self.device)
+        batch_atoms_tok = batch[1].to(self.device)
+        batch_pos_idx = batch[2].to(self.device)
+        batch_neg_idx = batch[3].to(self.device)
+        batch_true_links = batch[4].to(self.device)
+        batch_sentences_tokens = batch[5].to(self.device)
+        batch_sentences_mask = batch[6].to(self.device)
 
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-        logits_predictions = self(batch_num_atoms, batch_pos_idx, batch_neg_idx, output['word_embeding'],
-                                  output['last_hidden_state'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
+
+        src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
+        logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
+                                  output['word_embeding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
         axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
-        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][2][:50])
-        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][2][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
@@ -340,10 +372,11 @@ class Linker(Module):
         try:
             params = torch.load(model_file, map_location=self.device)
             args = params['args']
-            self.atom_map = args['atom_map']
             self.max_atoms_in_sentence = args['max_atoms_in_sentence']
-            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
-            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
+            self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+            self.atom_encoder.load_state_dict(params['atom_encoder'])
+            self.position_encoder.load_state_dict(params['position_encoder'])
+            self.transformer.load_state_dict(params['transformer'])
             self.linker_encoder.load_state_dict(params['linker_encoder'])
             self.pos_transformation.load_state_dict(params['pos_transformation'])
             self.neg_transformation.load_state_dict(params['neg_transformation'])
@@ -361,8 +394,10 @@
         self.cpu()
 
         torch.save({
-            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
-            'atoms_embedding': self.atoms_embedding.state_dict(),
+            'args': dict(max_atoms_in_sentence=self.max_atoms_in_sentence),
+            'atom_encoder': self.atom_encoder.state_dict(),
+            'position_encoder': self.position_encoder.state_dict(),
+            'transformer': self.transformer.state_dict(),
             'linker_encoder': self.linker_encoder.state_dict(),
             'pos_transformation': self.pos_transformation.state_dict(),
             'neg_transformation': self.neg_transformation.state_dict(),
@@ -384,4 +418,4 @@ class Linker(Module):
         return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
                                          if atom != -1 else torch.zeros(self.dim_cat_out, device=self.device)
                                          for atom in sentence])
-                            for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
+                            for i, sentence in enumerate(positional_ids[:, self.atom_map_redux[atom_type], :])])
diff --git a/Linker/PositionEncoding.py b/Linker/PositionalEncoding.py
similarity index 81%
rename from Linker/PositionEncoding.py
rename to Linker/PositionalEncoding.py
index 5389a7a17bf6cb8c2866b3ba80f6e9bd5eff63ce..19e1b96c0bd17b9867d9d24bda52a619e7559e4e 100644
--- a/Linker/PositionEncoding.py
+++ b/Linker/PositionalEncoding.py
@@ -5,7 +5,7 @@ import math
 
 class PositionalEncoding(nn.Module):
 
-    def __init__(self, d_model, dropout=0.1, max_len=5000):
+    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
 
@@ -19,7 +19,7 @@ class PositionalEncoding(nn.Module):
     def forward(self, x):
         """
         Args:
-            x: Tensor, shape [batch_size,seq_len, embedding_dim]
+            x: Tensor, shape [batch_size, seq_len, embedding_dim]
         """
         x = x + self.pe[:, :x.size(1)]
         return self.dropout(x)
diff --git a/Linker/__init__.py b/Linker/__init__.py
index 3dee6a7ab702ee2ae8df692ab8d544a6a12afe8f..eea58e3d271e21cf9e32bf4e085170ef30e9ef8b 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1,4 +1,4 @@
 from .Linker import Linker
 from .atom_map import atom_map
 from .AtomTokenizer import AtomTokenizer
-from .PositionEncoding import PositionalEncoding
\ No newline at end of file
+from .Sinkhorn import *
\ No newline at end of file
diff --git a/Linker/atom_map.py b/Linker/atom_map.py
index 4e0c45e4faed7171fb563685c85f172327dd4295..0df2646a03e4a228eb9223a3eb5f167c4de2ca14 100644
--- a/Linker/atom_map.py
+++ b/Linker/atom_map.py
@@ -15,7 +15,8 @@ atom_map = \
      'txt': 13,
      's': 14,
      's_ppart': 15,
-     '[PAD]': 16
+     "[SEP]":16,
+     '[PAD]': 17
      }
 
 atom_map_redux = {
diff --git a/Linker/eval.py b/Linker/eval.py
index c60bb007b1b2388e2aa2df46d9f4d46fa775f1fa..2c8c578687bec168d04fd1ee81e0357ec2f1dac2 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -17,8 +17,8 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
-    padding = max_atoms_in_one_type // 2 -1
-    batch_true_links=batch_true_links.permute(1, 0, 2)
+    padding = max_atoms_in_one_type // 2 - 1
+    batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == padding] = 1
@@ -26,4 +26,5 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     num_masked_atoms = len(batch_true_links[batch_true_links == padding])
 
     # diviser par nombre de links
-    return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
+    return (num_correct_links - num_masked_atoms) / (
+                axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 955d5571c10518c09f795c6cf19f61f95a91f15e..f2f418ff079dc8a17d5fb18094535566c75d58ec 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -3,6 +3,7 @@ import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
+
 from Linker.atom_map import atom_map, atom_map_redux
 from utils import pad_sequence
 
@@ -28,34 +29,34 @@ regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)
 regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
-################################ Liste des atoms avec _i ########################################
-def get_axiom_links(max_atoms_in_one_type, sub_atoms_type_list, atoms_polarity, batch_axiom_links):
+# region get true axiom links
+def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     r"""
     Args:
         max_atoms_in_one_type : configuration
-        sub_atoms_type_list : list of atom type to match
         atoms_polarity : (batch_size, max_atoms_in_sentence)
         batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
     Returns:
         batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
+    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
-    for atom_type in sub_atoms_type_list:
+    for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
+                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
                            range(len(atoms_batch))]
         l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
+                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
                             range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
             [torch.as_tensor(
-                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 -1 for
-                 i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
-            padding_value=max_atoms_in_one_type // 2 -1)
+                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
+                 for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
+                for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
+            padding_value=max_atoms_in_one_type // 2 - 1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
 
@@ -74,15 +75,13 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
-    elif category == "let":
-        return []
     elif True in res:
-        return [category]
+        return " " + category
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
+            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
         return categories_to_atoms
 
 
@@ -95,14 +94,26 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = []
+        categories_to_atoms = ""
         for category in sentence:
-            categories_to_atoms += category_to_atoms_axiom_links(category, [])
+            if category != "let" and not category.startswith("GOAL:"):
+                categories_to_atoms += category_to_atoms_axiom_links(category, "")
+                categories_to_atoms += " [SEP]"
+                categories_to_atoms = categories_to_atoms.lstrip()
+            elif category.startswith("GOAL:"):
+                categories_to_atoms += category_to_atoms_axiom_links(category, "")
+                categories_to_atoms = categories_to_atoms.lstrip()
         batch.append(categories_to_atoms)
     return batch
 
 
-################################ Liste des atoms ########################################
+print("test to create links ",
+      get_axiom_links(20, torch.stack([torch.as_tensor([False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, False, True, False, False, False, True])]),
+                      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
+
+# endregion
+
+# region get atoms in sentence
 
 def category_to_atoms(category, categories_to_atoms):
     r"""
@@ -116,15 +127,13 @@ def category_to_atoms(category, categories_to_atoms):
     if category.startswith("GOAL:"):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
-    elif category == "let":
-        return []
     elif True in res:
-        return [category]
+        return " " + category
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, [])
+            categories_to_atoms += category_to_atoms(cat, "")
         return categories_to_atoms
 
 
@@ -137,14 +146,22 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = []
+        categories_to_atoms = ""
         for category in sentence:
-            categories_to_atoms += category_to_atoms(category, [])
+            if category != "let":
+                categories_to_atoms += category_to_atoms(category, "")
+                categories_to_atoms += " [SEP]"
+                categories_to_atoms = categories_to_atoms.lstrip()
         batch.append(categories_to_atoms)
     return batch
 
 
-################################ Liste des atoms ########################################
+print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+
+
+# endregion
+
+# region calculate num atoms per category
 
 def category_to_num_atoms(category, categories_to_atoms):
     r"""
@@ -182,12 +199,22 @@ def get_num_atoms_batch(category_batch, max_len_sentence):
     for sentence in category_batch:
         num_atoms_sentence = []
         for category in sentence:
-            num_atoms_sentence.append(category_to_num_atoms(category, 0))
+            num_atoms_in_word = category_to_num_atoms(category, 0)
+            # add 1 because the atoms of each word end with a [SEP]
+            if category != "let":
+                num_atoms_in_word += 1
+            num_atoms_sentence.append(num_atoms_in_word)
         batch.append(torch.as_tensor(num_atoms_sentence))
     return pad_sequence(batch, max_len=max_len_sentence, padding_value=0)
 
 
-################################ Polarity ###############################################
+print(" test for get number of atoms in categories on ['dr(0,s,np)', 'let']",
+      get_num_atoms_batch([["dr(0,s,np)", "let"]], 10))
+
+
+# endregion
+
+# region get polarity
 
 def category_to_atoms_polarity(category, polarity):
     r"""
@@ -207,8 +234,6 @@ def category_to_atoms_polarity(category, polarity):
             category_to_polarity.append(True)
         else:
             category_to_polarity += category_to_atoms_polarity(cat, True)
-    elif category == "let":
-        pass
     # le mot a une category atomique
     elif True in res:
         category_to_polarity.append(not polarity)
@@ -270,58 +295,91 @@ def find_pos_neg_idexes(atoms_batch):
     for sentence in atoms_batch:
         list_atoms = []
         for category in sentence:
-            for at in category_to_atoms_polarity(category, True):
-                list_atoms.append(at)
+            if category == "let":
+                pass
+            else:
+                for at in category_to_atoms_polarity(category, True):
+                    list_atoms.append(at)
+                list_atoms.append(False)
         list_batch.append(list_atoms)
     return list_batch
 
 
-################################ GOAL ###############################################
+print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
+      find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
+
+
+# endregion
+
+# region get atoms and polarities with GOAL
 
 def get_GOAL(max_atoms_in_sentence, categories_batch):
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
+    atoms_batch_for_polarities = list(
+        map(lambda sentence: sentence.split(" "), atoms_batch))
     for s_idx in range(len(atoms_batch)):
-        for atom_type in list(atom_map.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
-                         and atoms_batch[s_idx][i] == atom_type]
-            list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
-                          and atoms_batch[s_idx][i] == atom_type]
+        for atom_type in list(atom_map_redux.keys()):
+            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
+                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
+            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
+                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
             while len(list_minus) != len(list_plus):
                 if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx].append(atom_type)
+                    atoms_batch[s_idx] += " " + atom_type
+                    atoms_batch_for_polarities[s_idx].append(atom_type)
                     polarities[s_idx].append(True)
                 else:
-                    atoms_batch[s_idx].append(atom_type)
+                    atoms_batch[s_idx] += " " + atom_type
+                    atoms_batch_for_polarities[s_idx].append(atom_type)
                     polarities[s_idx].append(False)
-                list_plus = [x for i, x in enumerate(atoms_batch[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch[s_idx][i] == atom_type]
+                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
+                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
+                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
+                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
 
     return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
                                      max_len=max_atoms_in_sentence, padding_value=0)
 
 
-################################ Prepare encoding ###############################################
+print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+
+
+# endregion
+
+# region get idx for pos and neg
 
-def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
-    inverse_atom_map = {v: k for k, v in atom_map.items()}
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+    atoms_batch_for_polarities = list(
+        map(lambda sentence: sentence.split(" "), atoms_batch))
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                                              atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in
-                             enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type // 2, padding_value=-1)
+        re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and
+                                              atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
-    inverse_atom_map = {v: k for k, v in atom_map.items()}
-    neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                                              not atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in
-                             enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type // 2, padding_value=-1)
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+    atoms_batch_for_polarities = list(
+        map(lambda sentence: sentence.split(" "), atoms_batch))
+    neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
+        re.match(atom_type + r"(_\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and not
+                                              atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
-    return torch.stack(neg_idx).permute(1, 0, 2)
+    return torch.stack(neg_idx).permute(1, 0, 2)
+
+
+print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
+                                                                                     torch.as_tensor(
+                                                                                         [[False, True, False, False,
+                                                                                           False, False, True, True,
+                                                                                           False, True,
+                                                                                           False, False]]), 10))
+
+# endregion
diff --git a/bash_GPU.sh b/bash_GPU.sh
index d98ae661d3d2d2e7a7f37f4f5592ff4a2a0ff0ee..500c7326f767af683a3c0a31e2e0026f3f2e74d3 100644
--- a/bash_GPU.sh
+++ b/bash_GPU.sh
@@ -1,6 +1,6 @@
 #!/bin/sh
-#SBATCH --job-name=Deepgrail_Linker_9000
-#SBATCH --partition=RTX6000Node
+#SBATCH --job-name=Deepgrail_Linker
+#SBATCH --partition=GPUNodes
 #SBATCH --gres=gpu:1
 #SBATCH --mem=32000
 #SBATCH --gres-flags=enforce-binding
diff --git a/command_line.txt b/command_line.txt
index 7abcb351f1aa129d614ca8f9e6bb0bfea94def85..31b0fc48391da24d8765df9b88505d761357daa3 100644
--- a/command_line.txt
+++ b/command_line.txt
@@ -1,4 +1,4 @@
-scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrail2/deepgrail_RNN_with_linker/TensorBoard/Tranning_19-05_09-49/logs /home/cdepourt/Bureau/deepgrail_RNN_with_linker/logs
+scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrailGPU1/deepgrail_RNN_with_linker/TensorBoard/ /home/cdepourt/Bureau/deepgrail_RNN_with_linker/TensorBoard
 
 rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea"  -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2
 
diff --git a/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1 b/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1
new file mode 100644
index 0000000000000000000000000000000000000000..09582fa3c93ec31068f014c18ba3bd65ffe59935
Binary files /dev/null and b/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1 differ
diff --git a/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3 b/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3
new file mode 100644
index 0000000000000000000000000000000000000000..5442af218aa6bb356f2080bc5dd84cdf9ecfb6cc
Binary files /dev/null and b/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3 differ
diff --git a/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2 b/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2
new file mode 100644
index 0000000000000000000000000000000000000000..c9b5e3bf3131b8fa559236aa17551dc3f0055839
Binary files /dev/null and b/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2 differ
diff --git a/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4 b/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4
new file mode 100644
index 0000000000000000000000000000000000000000..c3af9b258413271b78351acffec0f41824f7cec3
Binary files /dev/null and b/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4 differ
diff --git a/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0 b/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0
new file mode 100644
index 0000000000000000000000000000000000000000..d8f3c0aa95cc484017700deac105cacffc3cb025
Binary files /dev/null and b/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0 differ