diff --git a/Configuration/Configuration.py b/Configuration/Configuration.py
index 12a4b5f29755681ec929e73496d9bd3921fc9fc4..f622a57b821df167c83a0d8341570d3cf4d6d48a 100644
--- a/Configuration/Configuration.py
+++ b/Configuration/Configuration.py
@@ -1,19 +1,14 @@
 import os
 from configparser import ConfigParser
 
-# Read configuration file
-path_current_directory = os.path.dirname(__file__)
-path_config_file = os.path.join(path_current_directory, 'config.ini')
-config = ConfigParser()
-config.read(path_config_file)
 
-# region Get section
+def read_config():
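+    r"""
+    Read the Configuration/config.ini file located next to this module and return the parsed ConfigParser
+    """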
+    # Read configuration file
+    path_current_directory = os.path.dirname(__file__)
+    path_config_file = os.path.join(path_current_directory, 'config.ini')
+    config = ConfigParser()
+    config.read(path_config_file)
 
-version = config["VERSION"]
+    return config
 
-datasetConfig = config["DATASET_PARAMS"]
-modelEncoderConfig = config["MODEL_ENCODER"]
-modelLinkerConfig = config["MODEL_LINKER"]
-modelTrainingConfig = config["MODEL_TRAINING"]
 
-# endregion Get section
diff --git a/Configuration/config.ini b/Configuration/config.ini
index d4527a0c0881a9543d959b5b744c39cc27519dd6..e26d80f63bf4ff172cee6b5946384c06ca8c9786 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -2,31 +2,30 @@
 transformers = 4.16.2
 
 [DATASET_PARAMS]
-symbols_vocab_size=26
-atom_vocab_size=18
-max_len_sentence=83
-max_atoms_in_sentence=875
-max_atoms_in_one_type=324
+symbols_vocab_size = 26
+atom_vocab_size = 18
+max_len_sentence = 83
+max_atoms_in_sentence = 238
+max_atoms_in_one_type = 102
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
+nhead = 8
 dim_emb_atom = 256
-dim_feedforward_transformer = 768
-num_layers=3
-dim_cat_bert_out=1024
-dim_cat_inter=768
-dim_cat_out=512
-dim_intermediate_FFN=256
-dim_pre_sinkhorn_transfo=64
-dropout=0.15
-sinkhorn_iters=5
+dim_feedforward_transformer = 512
+num_layers = 3
+dim_cat_out = 256
+dim_intermediate_ffn = 128
+dim_pre_sinkhorn_transfo = 32
+dropout = 0.1
+sinkhorn_iters = 5
 
 [MODEL_TRAINING]
-batch_size=32
-pretrain_linker_epochs=1
-training_epoch=15
-seed_val=42
-learning_rate=5e-3
\ No newline at end of file
+batch_size = 32
+pretrain_linker_epochs = 1
+epoch = 1
+seed_val = 42
+learning_rate = 2e-3
+
diff --git a/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
index c72b73e3ea4720daca99b844098b65b7da1e0e74..1f5c1a1c95998b40390a6839485e680f5d79bacf 100644
--- a/Linker/AtomTokenizer.py
+++ b/Linker/AtomTokenizer.py
@@ -3,6 +3,9 @@ from utils import pad_sequence
 
 
 class AtomTokenizer(object):
+    r"""
+    Tokenizer for the atoms with padding
+    """
     def __init__(self, atom_map, max_atoms_in_sentence):
         self.atom_map = atom_map
         self.max_atoms_in_sentence = max_atoms_in_sentence
@@ -14,14 +17,34 @@ class AtomTokenizer(object):
         return len(self.atom_map)
 
     def convert_atoms_to_ids(self, atom):
+        r"""
+        Convert an atom to its id
+        :param atom: atom string
+        :return: atom id
+        """
         return self.atom_map[str(atom)]
 
     def convert_sents_to_ids(self, sentences):
+        r"""
+        Convert the atoms of a sentence to their ids
+        :param sentences: list of atoms in a sentence
+        :return: tensor of atom ids
+        """
         return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
 
     def convert_batchs_to_ids(self, batchs_sentences):
+        r"""
+        Convert a batch of sentences of atoms to their ids, padded to max_atoms_in_sentence
+        :param batchs_sentences: batch of sentences of atoms
+        :return: padded tensor of atom ids
+        """
         return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
                                             max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
 
     def convert_ids_to_atoms(self, ids):
+        r"""
+        Translate ids back to atoms
+        :param ids: list of atom ids
+        :return: list of atom strings
+        """
         return [self.inverse_atom_map[int(i)] for i in ids]
diff --git a/Linker/Linker.py b/Linker/Linker.py
index ff914d1b74b97fb1ad72c9d66b185d89c06b8cdb..1006b8f8f4fe3070e11f8139bc09526e99c29591 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,15 +1,13 @@
+import datetime
 import math
 import os
-import re
 import sys
-import datetime
-
 import time
 
 import torch
 import torch.nn.functional as F
 from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \
-    Embedding
+    Embedding, GELU
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import TensorDataset, random_split
@@ -17,13 +15,14 @@ from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
-from Linker.PositionalEncoding import PositionalEncoding
+from .AtomTokenizer import AtomTokenizer
+from .PositionalEncoding import PositionalEncoding
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
-from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
-from Supertagger import SuperTagger
+from Linker.eval import measure_accuracy, SinkhornLoss
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
+    find_pos_neg_idexes, get_num_atoms_batch
+from SuperTagger import SuperTagger
 from utils import pad_sequence
 
 
@@ -60,42 +59,41 @@ class Linker(Module):
     def __init__(self, supertagger_path_model):
         super(Linker, self).__init__()
 
+        config = Configuration.read_config()
+        datasetConfig = config["DATASET_PARAMS"]
+        modelEncoderConfig = config["MODEL_ENCODER"]
+        modelLinkerConfig = config["MODEL_LINKER"]
+        modelTrainingConfig = config["MODEL_TRAINING"]
+
         # region parameters
-        dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        dim_encoder = int(modelEncoderConfig['dim_encoder'])
         # atom settings
-        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-        # out bert
-        dim_cat_bert_out = int(Configuration.modelLinkerConfig['dim_cat_bert_out'])
+        atom_vocab_size = int(datasetConfig['atom_vocab_size'])
         # Transformer
-        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
-        self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
-        self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
-        self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
+        self.nhead = int(modelLinkerConfig['nhead'])
+        self.dim_emb_atom = int(modelLinkerConfig['dim_emb_atom'])
+        self.dim_feedforward_transformer = int(modelLinkerConfig['dim_feedforward_transformer'])
+        self.num_layers = int(modelLinkerConfig['num_layers'])
         # torch cat
-        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
-        self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
-        dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
-        dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
-        dropout = float(Configuration.modelLinkerConfig['dropout'])
+        dropout = float(modelLinkerConfig['dropout'])
+        self.dim_cat_out = int(modelLinkerConfig['dim_cat_out'])
+        dim_intermediate_FFN = int(modelLinkerConfig['dim_intermediate_FFN'])
+        dim_pre_sinkhorn_transfo = int(modelLinkerConfig['dim_pre_sinkhorn_transfo'])
         # sinkhorn
-        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.sinkhorn_iters = int(modelLinkerConfig['sinkhorn_iters'])
         # settings
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        self.max_len_sentence = int(datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
+        learning_rate = float(modelTrainingConfig['learning_rate'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # endregion
 
-        # Supertagger for categories
+        # SuperTagger for categories
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
         self.Supertagger.model.to(self.device)
-        self.word_cat_encoder = Sequential(
-            Linear(dim_encoder * 2, dim_cat_bert_out),
-            Dropout(dropout),
-            LayerNorm(dim_cat_bert_out, eps=1e-8))
 
         # Atoms embedding
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
@@ -110,19 +108,23 @@ class Linker(Module):
         self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
 
         # Concatenation with word embedding
-        dim_cat = dim_cat_bert_out + self.dim_emb_atom
+        dim_cat = dim_encoder + self.dim_emb_atom
         self.linker_encoder = Sequential(
             Linear(dim_cat, self.dim_cat_out),
+            GELU(),
             Dropout(dropout),
-            LayerNorm(self.dim_cat_out, eps=1e-8))
+            LayerNorm(self.dim_cat_out, eps=1e-8)
+        )
 
         # Division into positive and negative
         self.pos_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8))
+            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
+        )
         self.neg_transformation = Sequential(
             FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8))
+            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
+        )
 
         # Learning
         self.cross_entropy_loss = SinkhornLoss()
@@ -145,14 +147,14 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch)))
-
-        num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
+        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
+        atoms_polarity_batch = pad_sequence(
+            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+            max_len=self.max_atoms_in_sentence, padding_value=0)
+        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
@@ -175,12 +177,11 @@ class Linker(Module):
         print("End preprocess Data")
         return training_dataloader, validation_dataloader
 
-    def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding,
-                cats_embedding):
+    def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding):
         r"""
         Args:
             batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
-            batch_atoms : atoms tokenized
+                batch_atoms : (batch_size, max_atoms_in_sentence) tokenized atoms
             batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
@@ -191,11 +192,6 @@ class Linker(Module):
         sents_embedding_repeat = pad_sequence(
             [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
              for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
-        cats_embedding_repeat = pad_sequence(
-            [torch.repeat_interleave(input=cats_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
-             for i in range(len(cats_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
-        word_cat_encoding = torch.cat([sents_embedding_repeat, cats_embedding_repeat], dim=2)
-        word_cat_encoding = self.word_cat_encoder(word_cat_encoding)
 
         # atoms emebedding
         src_key_padding_mask = torch.eq(batch_atoms, self.padding_id)
@@ -208,12 +204,12 @@ class Linker(Module):
         atoms_embedding = atoms_embedding.permute(1, 0, 2)
 
         # cat
-        atoms_sentences_encoding = torch.cat([word_cat_encoding, atoms_embedding], dim=2)
+        atoms_sentences_encoding = torch.cat([sents_embedding_repeat, atoms_embedding], dim=2)
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
 
         # linking per atom type
-        batch_size, atom_vocan_size, _ = batch_pos_idx.shape
-        link_weights = torch.zeros(atom_vocan_size, batch_size, self.max_atoms_in_one_type // 2,
+        batch_size, atom_vocab_size, _ = batch_pos_idx.shape
+        link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
                                    self.max_atoms_in_one_type // 2, device=self.device)
         for atom_type in list(atom_map_redux.keys()):
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
@@ -261,7 +257,7 @@ class Linker(Module):
 
             if checkpoint:
                 self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+                    path=os.path.join("Output", 'linker' + datetime.datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
@@ -311,7 +307,7 @@ class Linker(Module):
 
                 # Run the Linker on the atoms
                 logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
-                                          output['word_embeding'], output['last_hidden_state'])
+                                          output['word_embeding'])
 
                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
                 # Perform a backward pass to calculate the gradients.
@@ -325,7 +321,7 @@ class Linker(Module):
                 self.optimizer.step()
 
                 pred_axiom_links = torch.argmax(logits_predictions, dim=3)
-                accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
+                accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
 
         self.scheduler.step()
 
@@ -348,8 +344,7 @@ class Linker(Module):
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
         logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
-            'word_embeding'], output[
-                                      'last_hidden_state'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
+            'word_embeding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
         axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
         print('\n')
@@ -357,7 +352,7 @@ class Linker(Module):
         print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
 
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
+        accuracy = measure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
         loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
 
         return loss, accuracy
@@ -379,19 +374,85 @@ class Linker(Module):
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)
 
+    def predict_with_categories(self, sentence, categories):
+        r""" Predict the links from a sentence and its categories
+
+        Args:
+            sentence : list of words composing the sentence
+            categories : list of categories (tags) of each word
+        """
+        self.eval()
+        with torch.no_grad():
+            self.cpu()
+            self.device = torch.device("cpu")
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
+            nb_sentence, len_sentence = sentences_tokens.shape
+
+            atoms = get_atoms_batch([categories])
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
+
+            polarities = find_pos_neg_idexes([categories])
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            num_atoms_per_word = get_num_atoms_batch([categories], len_sentence)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+
+            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return axiom_links_pred
+
+    def predict_without_categories(self, sentence):
+        r""" Predict the links from a sentence
+
+        Args:
+            sentence : list of words composing the sentence
+        """
+        self.eval()
+        with torch.no_grad():
+            self.cpu()
+            self.device = torch.device("cpu")
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
+            nb_sentence, len_sentence = sentences_tokens.shape
+
+            hidden_state, categories = self.Supertagger.predict(sentence)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+            atoms = get_atoms_batch(categories)
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
+
+            polarities = find_pos_neg_idexes(categories)
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+
+            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return axiom_links_pred
+
     def load_weights(self, model_file):
         print("#" * 15)
         try:
             params = torch.load(model_file, map_location=self.device)
-            args = params['args']
-            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
-            self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
             self.atom_encoder.load_state_dict(params['atom_encoder'])
             self.position_encoder.load_state_dict(params['position_encoder'])
             self.transformer.load_state_dict(params['transformer'])
             self.linker_encoder.load_state_dict(params['linker_encoder'])
             self.pos_transformation.load_state_dict(params['pos_transformation'])
             self.neg_transformation.load_state_dict(params['neg_transformation'])
+            self.cross_entropy_loss.load_state_dict(params['cross_entropy_loss'])
             self.optimizer.load_state_dict(params['optimizer'])
             print("\n The loading checkpoint was successful ! \n")
         except Exception as e:
@@ -408,10 +469,11 @@ class Linker(Module):
         torch.save({
             'atom_encoder': self.atom_encoder.state_dict(),
             'position_encoder': self.position_encoder,
-            'transformer': self.transformer,
+            'transformer': self.transformer.state_dict(),
             'linker_encoder': self.linker_encoder.state_dict(),
             'pos_transformation': self.pos_transformation.state_dict(),
             'neg_transformation': self.neg_transformation.state_dict(),
+            'cross_entropy_loss': self.cross_entropy_loss.state_dict(),
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
diff --git a/Linker/PositionalEncoding.py b/Linker/PositionalEncoding.py
index ba50a703d4bc0007dede0981a68930d2f1b6d6fa..19e1b96c0bd17b9867d9d24bda52a619e7559e4e 100644
--- a/Linker/PositionalEncoding.py
+++ b/Linker/PositionalEncoding.py
@@ -5,7 +5,7 @@ import math
 
 class PositionalEncoding(nn.Module):
 
-    def __init__(self, d_model: int, dropout: float = 0.15, max_len: int = 5000):
+    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
 
diff --git a/Linker/__init__.py b/Linker/__init__.py
index eea58e3d271e21cf9e32bf4e085170ef30e9ef8b..0983f0bce2a67ac9fea8478389d3fb706ce820a1 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1,4 +1,5 @@
 from .Linker import Linker
 from .atom_map import atom_map
 from .AtomTokenizer import AtomTokenizer
+from .PositionalEncoding import PositionalEncoding
 from .Sinkhorn import *
\ No newline at end of file
diff --git a/Linker/eval.py b/Linker/eval.py
index 05c096639ee2d12f9b6fa38f44833067b4169440..086f2a94afabb63d22966c36d8202404da800c0d 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -5,6 +5,9 @@ from Linker.atom_map import atom_map, atom_map_redux
 
 
 class SinkhornLoss(Module):
+    r"""
+    Loss for the linker
+    """
     def __init__(self):
         super(SinkhornLoss, self).__init__()
 
@@ -13,7 +16,7 @@ class SinkhornLoss(Module):
                    for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
-def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
+def measure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     r"""
     batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 8bb55d1673f82ce8863b8177ce537e26249d52cb..199351e586ff571c5a73ee1947d7067023a8c581 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -1,4 +1,6 @@
 import re
+
+import pandas as pd
 import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
@@ -40,7 +42,6 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
         batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
     atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    atoms_batch = list(map(lambda sentence: sentence.split(" "), atoms_batch))
     linking_plus_to_minus_all_types = []
     for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
@@ -76,12 +77,12 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms_axiom_links(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories_axiom_links, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, "")
+            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
         return categories_to_atoms
 
 
@@ -94,23 +95,22 @@ def get_atoms_links_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms_axiom_links(category, [])
+                categories_to_atoms.append("[SEP]")
             elif category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, "")
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms = category_to_atoms_axiom_links(category, []) + categories_to_atoms
         batch.append(categories_to_atoms)
     return batch
 
 
 print("test to create links ",
       get_axiom_links(20, torch.stack([torch.as_tensor(
-          [False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False,
-           False, True, False, True, False, False, True, False, False, False, True])]),
+          [True, False, True, False, False, False, True, False, True, False,
+           False, True, False, False, False, True, False, False, True, False,
+           True, False, False, True, False, False, False, False, False, False])]),
                       [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
                         'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
 
@@ -132,12 +132,12 @@ def category_to_atoms(category, categories_to_atoms):
         word, cat = category.split(':')
         return category_to_atoms(cat, categories_to_atoms)
     elif True in res:
-        return " " + category
+        return [category]
     else:
         category_cut = regex.match(regex_categories, category).groups()
         category_cut = [cat for cat in category_cut if cat is not None]
         for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, "")
+            categories_to_atoms += category_to_atoms(cat, [])
         return categories_to_atoms
 
 
@@ -150,17 +150,17 @@ def get_atoms_batch(category_batch):
     """
     batch = []
     for sentence in category_batch:
-        categories_to_atoms = ""
+        categories_to_atoms = []
         for category in sentence:
             if category != "let":
-                categories_to_atoms += category_to_atoms(category, "")
-                categories_to_atoms += " [SEP]"
-                categories_to_atoms = categories_to_atoms.lstrip()
+                categories_to_atoms += category_to_atoms(category, [])
+                categories_to_atoms.append("[SEP]")
         batch.append(categories_to_atoms)
     return batch
 
 
-print(" test for get atoms in categories on ['dr(0,s,np)', 'let']", get_atoms_batch([["dr(0,s,np)", "let"]]))
+print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']",
+      get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
 
 
 # endregion
@@ -201,7 +201,7 @@ def get_num_atoms_batch(category_batch, max_len_sentence):
     """
     batch = []
     for sentence in category_batch:
-        num_atoms_sentence = []
+        num_atoms_sentence = [0]
         for category in sentence:
             num_atoms_in_word = category_to_num_atoms(category, 0)
             # add 1 because for word we have SEP at the end
@@ -309,8 +309,7 @@ def find_pos_neg_idexes(atoms_batch):
     return list_batch
 
 
-print(
-    " test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
+print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n",
     find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
                           'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
 
@@ -319,75 +318,86 @@ print(
 
 # region get atoms and polarities with GOAL
 
-def get_GOAL(max_atoms_in_sentence, categories_batch):
+def get_GOAL(max_len_sentence, df_axiom_links):
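+    r"""
+    Build the atoms, their polarities and the number of atoms per word for each sentence,
+    prepending the GOAL atoms extracted from the axiom links
+    :param max_len_sentence: maximum sentence length used for padding
+    :param df_axiom_links: dataframe with the categories in column "Z" and the annotated links in column "Y"
+    :return: atoms_batch, polarities, num_atoms_batch
+    """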
+    categories_batch = df_axiom_links["Z"]
+    categories_with_goal = df_axiom_links["Y"]
     polarities = find_pos_neg_idexes(categories_batch)
     atoms_batch = get_atoms_batch(categories_batch)
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: [item for item in sentence.split(" ")], atoms_batch))
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
     for s_idx in range(len(atoms_batch)):
-        for atom_type in list(atom_map_redux.keys()):
-            list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                         and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                          and bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i]))]
-            while len(list_minus) != len(list_plus):
-                if len(list_minus) > len(list_plus):
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(True)
-                else:
-                    atoms_batch[s_idx] += " " + atom_type
-                    atoms_batch_for_polarities[s_idx].append(atom_type)
-                    polarities[s_idx].append(False)
-                list_plus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if polarities[s_idx][i]
-                             and atoms_batch_for_polarities[s_idx][i] == atom_type]
-                list_minus = [x for i, x in enumerate(atoms_batch_for_polarities[s_idx]) if not polarities[s_idx][i]
-                              and atoms_batch_for_polarities[s_idx][i] == atom_type]
-
-    return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0)
-
-
-print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)", "s"]]))
+        goal = categories_with_goal[s_idx][-1]
+        polarities_goal = category_to_atoms_polarity(goal, True)
+        goal = re.search(r"(\w+)_\d+", goal).groups()[0]
+        atoms = category_to_atoms(goal, [])
+
+        atoms_batch[s_idx] = atoms + atoms_batch[s_idx]  # + ["[SEP]"]
+        polarities[s_idx] = polarities_goal + polarities[s_idx]  # + False
+        num_atoms_batch[s_idx][0] += len(atoms)  # +1
+
+    return atoms_batch, polarities, num_atoms_batch
+
+
+df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
+                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
+                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
+                                      'GOAL:np_7']]})
+print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
+
+
+# endregion
+
+# region get atoms and polarities after tagger
+
+def get_info_for_tagger(max_len_sentence, pred_categories):
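+    r"""
+    Build the atoms, their polarities and the number of atoms per word from the categories predicted by the tagger
+    :param max_len_sentence: maximum sentence length used for padding
+    :param pred_categories: batch of predicted categories (one list of categories per sentence)
+    :return: atoms_batch, polarities, num_atoms_batch
+    """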
+    categories_batch = pred_categories
+    polarities = find_pos_neg_idexes(categories_batch)
+    atoms_batch = get_atoms_batch(categories_batch)
+    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
+
+    return atoms_batch, polarities, num_atoms_batch
+
+
+df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
+                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
+                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
+                                      'GOAL:np_7']]})
+print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
 
 
 # endregion
 
 # region get idx for pos and neg
 
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
-    atoms_batch_for_polarities = list(
-        map(lambda sentence: sentence.split(" "), atoms_batch))
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
-                                                            atoms_batch_for_polarities[s_idx][i])) and not
-                                              atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              not atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
                for atom_type in list(atom_map_redux.keys())]
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg_idx(['s np [SEP] s [SEP] np s s n n'],
-                                                                                     torch.as_tensor(
-                                                                                         [[False, True, False, False,
-                                                                                           False, False, True, True,
-                                                                                           False, True,
-                                                                                           False, False]]), 10, 50))
+print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
+      get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
+                  torch.as_tensor(
+                      [[True, True, False, False,
+                        True, False, False, False,
+                        False, False,
+                        False, False]]), 10))
 
-# endregion
+# endregion
\ No newline at end of file
diff --git a/NeuralProofNet/NeuralProofNet.py b/NeuralProofNet/NeuralProofNet.py
index be4dbb4774452a2c0611448946b43cd958c2bc21..bb12a552d5efcb8cf31f982ac7022bb74092efb5 100644
--- a/NeuralProofNet/NeuralProofNet.py
+++ b/NeuralProofNet/NeuralProofNet.py
@@ -13,8 +13,11 @@ from tqdm import tqdm
 
 from Configuration import Configuration
 from Linker import Linker
-from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
+from Linker.eval import measure_accuracy, SinkhornLoss
+from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx, \
+    get_info_for_tagger
+from find_config import configurate
+from utils import pad_sequence
 
 
 def format_time(elapsed):
@@ -44,15 +47,21 @@ def output_create_dir():
 class NeuralProofNet(Module):
     def __init__(self, supertagger_path_model, linker_path_model=None):
         super(NeuralProofNet, self).__init__()
+        config = Configuration.read_config()
+        datasetConfig = config["DATASET_PARAMS"]
+        modelEncoderConfig = config["MODEL_ENCODER"]
+        modelLinkerConfig = config["MODEL_LINKER"]
+        modelTrainingConfig = config["MODEL_TRAINING"]
+
         # pretrain settings
-        self.pretrain_linker_epochs = int(Configuration.modelTrainingConfig['pretrain_linker_epochs'])
+        self.pretrain_linker_epochs = int(modelTrainingConfig['pretrain_linker_epochs'])
         # settings
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
+        self.max_len_sentence = int(datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
+        learning_rate = float(modelTrainingConfig['learning_rate'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+        self.batch_size = int(modelTrainingConfig['batch_size'])
 
         linker = Linker(supertagger_path_model)
         if linker_path_model is not None:
@@ -63,16 +72,17 @@ class NeuralProofNet(Module):
         # Learning
         self.linker_loss = SinkhornLoss()
         self.linker_optimizer = AdamW(self.linker.parameters(),
-                               lr=learning_rate)
+                                      lr=learning_rate)
         self.linker_scheduler = StepLR(self.linker_optimizer, step_size=2, gamma=0.5)
 
         self.to(self.device)
 
-    def __pretrain_linker__(self, df_axiom_links):
+    def __pretrain_linker__(self, df_axiom_links, checkpoint=False, tensorboard=True):
         print("\nLinker Pre-Training\n")
-        self.linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=self.pretrain_linker_epochs, batch_size=self.batch_size,
-                                 checkpoint=False,
-                                 tensorboard=True)
+        self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=self.pretrain_linker_epochs,
+                                 batch_size=self.batch_size,
+                                 checkpoint=checkpoint,
+                                 tensorboard=tensorboard)
         print("\nEND Linker Pre-Training\n")
 
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -88,7 +98,10 @@ class NeuralProofNet(Module):
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        _, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
+        _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
+        atoms_polarity_batch = pad_sequence(
+            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+            max_len=self.max_atoms_in_sentence, padding_value=0)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
@@ -114,28 +127,26 @@ class NeuralProofNet(Module):
 
         # get sentence embedding from BERT which is already trained
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-        last_hidden_state = output['last_hidden_state']
+        last_hidden_state = output['logit']
         pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
         pred_categories = self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
 
         # get information from tagger predictions
-        batch_atoms, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, pred_categories)
-        batch_num_atoms_per_word = get_num_atoms_batch(pred_categories, self.max_len_sentence)
-        batch_pos_idx = get_pos_idx(batch_atoms, atoms_polarity_batch, self.max_atoms_in_one_type,
-                                    self.max_atoms_in_sentence)
-        batch_neg_idx = get_neg_idx(batch_atoms, atoms_polarity_batch, self.max_atoms_in_one_type,
-                                    self.max_atoms_in_sentence)
-
-        atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(
-            list(map(lambda sentence: [item for item in sentence.split(" ")], batch_atoms)))
+        atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
+        atoms_polarity_batch = pad_sequence(
+            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+            max_len=self.max_atoms_in_sentence, padding_value=0)
+        atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+        batch_pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        batch_neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+
         logits_links = self.linker(batch_num_atoms_per_word, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx,
-                                   output['word_embeding'],
-                                   last_hidden_state)
+                                   output['word_embeding'])
 
         return torch.log_softmax(logits_links, dim=3)
 
     def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, tensorboard=False):
+                             batch_size=32, checkpoint=True, tensorboard=False):
         r"""
         Args:
             df_axiom_links : pandas dataFrame containing the atoms anoted with _i
@@ -171,8 +182,7 @@ class NeuralProofNet(Module):
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
 
             if checkpoint:
-                self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+                self.__checkpoint_save(path='Output/linker.pt')
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
@@ -228,7 +238,7 @@ class NeuralProofNet(Module):
                 self.linker_optimizer.step()
 
                 pred_axiom_links = torch.argmax(logits_predictions_links, dim=3)
-                accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
+                accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
 
         self.linker_scheduler.step()
 
@@ -245,14 +255,15 @@ class NeuralProofNet(Module):
         batch_sentences_mask = batch[2].to(self.device)
 
         logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
-        axiom_links_pred = torch.argmax(logits_predictions_links, dim=3)  # atom_vocab, batch_size, max atoms in one type
+        axiom_links_pred = torch.argmax(logits_predictions_links,
+                                        dim=3)  # atom_vocab, batch_size, max atoms in one type
 
         print('\n')
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
 
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
+        accuracy = measure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
         linker_loss = self.linker_loss(logits_predictions_links, batch_true_links, self.max_atoms_in_one_type)
 
         return linker_loss, accuracy
@@ -272,4 +283,22 @@ class NeuralProofNet(Module):
                 accuracy_average += accuracy
                 loss_average += float(loss)
 
-        return loss_average / len(dataloader), accuracy_average / len(dataloader)
\ No newline at end of file
+        return loss_average / len(dataloader), accuracy_average / len(dataloader)
+
+    def __checkpoint_save(self, path='/linker.pt'):
+        """
+        @param path:
+        """
+        self.cpu()
+
+        torch.save({
+            'atom_encoder': self.linker.atom_encoder.state_dict(),
+            'position_encoder': self.linker.position_encoder.state_dict(),
+            'transformer': self.linker.transformer.state_dict(),
+            'linker_encoder': self.linker.linker_encoder.state_dict(),
+            'pos_transformation': self.linker.pos_transformation.state_dict(),
+            'neg_transformation': self.linker.neg_transformation.state_dict(),
+            'cross_entropy_loss': self.linker_loss.state_dict(),
+            'optimizer': self.linker_optimizer.state_dict(),
+        }, path)
+        self.to(self.device)
\ No newline at end of file
diff --git a/README.md b/README.md
index 5994f1455440e7055fec3c5dd2f7e9baaa7e0cd5..3348ca6cbfe8929a15c17c5101c8b6dfe2754856 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,52 @@
-# DeepGrail
+# DeepGrail Linker
 
+This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
+
+This code is designed to work with the [DeepGrail Tagger](https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger).
+For now, this repository only uses the word embeddings from the tagger and the tags from the dataset; the next step is to use the tagger's predictions for the linking step.
+ 
 ## Usage
 
 ### Installation
 Python 3.9.10 **(Warning don't use Python 3.10**+**)**
+Clone the project locally.
+
+### Libraries installation
+
+Run the init.sh script, which clones the tagger project into a folder named SuperTagger and installs the requirements, or do those two steps manually.
+
+### Dataset format
+
+The sentences should be in a column "X", the categories annotated with '_x' link indices in a column "Y", and the plain categories in a column "Z".
+For the links, each atom_x is paired with the one and only other atom_x in the sentence.
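+
+For one sentence, the "Y" and "Z" columns look like this (example taken from the small test in `Linker/utils_linker.py`; "X" holds the raw sentence itself):
+
+```
+Z: ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+    'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']
+Y: ['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
+    'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']
+```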
+
+## Training
+
+Launch train.py. Inside it you can point to another dataset file and another tagging model.
+
+In training, if you use `checkpoint=True`, the model is automatically saved after each epoch in a folder named
+Training_XX-XX_XX-XX. Use `tensorboard=True` to write logs in the same folder (run `tensorboard --logdir=logs` to view them).
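+
+A minimal run mirrors what `train.py` does (the dataset path and tagger model below are the ones used in this repository; `train.py` also calls `configurate` from `find_config.py` first to infer the dataset parameters):
+
+```
+from NeuralProofNet.NeuralProofNet import NeuralProofNet
+from utils import read_csv_pgbar
+
+df_axiom_links = read_csv_pgbar('Datasets/goldANDsilver_dataset_links.csv', 1000)
+proof_net = NeuralProofNet("models/flaubert_super_98_V2_50e.pt")
+proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=20,
+                               batch_size=32, checkpoint=True, tensorboard=True)
+```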
+
+## Predicting
+
+To predict on your own data, you need to load a model (saved with this code).
+
+```
+df = read_csv_pgbar(file_path,20)
+texts = df['X'].tolist()
+categories = df['Z'].tolist()
+
+linker = Linker(tagging_model)
+linker.load_weights("your/linker/path")
 
-Clone the project locally. In a clean python venv do `pip install -r requirements.txt`
+links = linker.predict_with_categories(texts[7], categories[7])
+print(links)
+```
 
-## How To use
+The file ```postprocessing.py``` lets you draw the prediction (keep the sentence short, otherwise the drawing becomes confusing).
 
-TODO ...
+You can also use the function ```predict_without_categories```, which only needs the sentence; the tagger then predicts the categories itself.
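+
+For example, reusing `texts` from the snippet above:
+
+```
+links = linker.predict_without_categories(texts[7])
+print(links)
+```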
 
-tensorboard --logdir=logs
+## Authors
 
+[de Pourtales Caroline](https://www.linkedin.com/in/caroline-de-pourtales/), [Rabault Julien](https://www.linkedin.com/in/julienrabault)
\ No newline at end of file
diff --git a/SuperTagger b/SuperTagger
new file mode 160000
index 0000000000000000000000000000000000000000..7b10151214babc2c3f1bc474eb9bec25458a8347
--- /dev/null
+++ b/SuperTagger
@@ -0,0 +1 @@
+Subproject commit 7b10151214babc2c3f1bc474eb9bec25458a8347
diff --git a/find_config.py b/find_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..53725288c4f3c1c2b29c7fe2f8d473fb0d0a4f4a
--- /dev/null
+++ b/find_config.py
@@ -0,0 +1,61 @@
+import configparser
+import re
+
+import torch
+
+from Linker.atom_map import atom_map_redux
+from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
+from SuperTagger.SuperTagger.SuperTagger import SuperTagger
+from utils import read_csv_pgbar, pad_sequence
+
+
+def configurate(dataset, model_tagger, nb_sentences=1000000000):
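+    r"""
+    Compute max_len_sentence, max_atoms_in_sentence and max_atoms_in_one_type from the dataset
+    and write them into Configuration/config.ini
+    :param dataset: path of the axiom links csv file
+    :param model_tagger: path of the trained SuperTagger model
+    :param nb_sentences: maximum number of sentences to read from the dataset
+    """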
+    print("#" * 20)
+    print("#" * 20)
+    print("Configuration with dataset\n")
+    config = configparser.ConfigParser()
+    config.read('Configuration/config.ini')
+
+    file_path_axiom_links = dataset
+    df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+    supertagger = SuperTagger()
+    supertagger.load_weights(model_tagger)
+    sentences_batch = df_axiom_links["X"].str.strip().tolist()
+    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+    max_len_sentence = 0
+    for sentence in sentences_tokens:
+        if len(sentence) > max_len_sentence:
+            max_len_sentence = len(sentence)
+    print("Configure parameter max len sentence to ", max_len_sentence)
+    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))
+
+    atoms_batch, polarities, num_batch = get_GOAL(max_len_sentence, df_axiom_links)
+    max_atoms_in_sentence = 0
+    for sentence in atoms_batch:
+        if len(sentence) > max_atoms_in_sentence:
+            max_atoms_in_sentence = len(sentence)
+    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
+    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))
+
+    atoms_polarity_batch = pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                                        max_len=max_atoms_in_sentence, padding_value=0)
+    pos_idx = [[torch.as_tensor([i for i, x in enumerate(sentence) if
+                 bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
+                 and atoms_polarity_batch[s_idx][i]])
+                for s_idx, sentence in enumerate(atoms_batch)]
+               for atom_type in list(atom_map_redux.keys())]
+    max_atoms_in_one_type = 0
+    for atoms_type_batch in pos_idx:
+        for sentence in atoms_type_batch:
+            length = sentence.size(0)
+            if length > max_atoms_in_one_type:
+                max_atoms_in_one_type = length
+    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_one_type)
+    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_one_type * 2 + 2))
+
+    with open('Configuration/config.ini', 'w') as configfile:  # save
+        config.write(configfile)
+
+    print("#" * 20)
+    print("#" * 20)
\ No newline at end of file
diff --git a/init.sh b/init.sh
new file mode 100644
index 0000000000000000000000000000000000000000..be8706dd03e8a0e984d4b2e27be793842bb5ea0f
--- /dev/null
+++ b/init.sh
@@ -0,0 +1,3 @@
+git clone https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger.git SuperTagger
+
+pip install -r requirements.txt
\ No newline at end of file
diff --git a/postprocessing.py b/postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2d43f03a99f877de7391c9e62bd91aea2996e30
--- /dev/null
+++ b/postprocessing.py
@@ -0,0 +1,119 @@
+import re
+
+import graphviz
+import numpy as np
+import regex
+from Linker.atom_map import atom_map, atom_map_redux
+
+regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
+
+
+def recursive_linking(links, dot, category, parent_id, word_idx, depth,
+                      polarity, compt_plus, compt_neg):
+    r"""
+    Recursive linking between the atoms inside a category
+    :param links: predicted axiom links per atom type (for each positive atom, the index of the linked negative atom)
+    :param dot: graphviz graph being built
+    :param category: current category (or atom) to expand
+    :param parent_id: id of the parent node in the graph
+    :param word_idx: index of the current word in the sentence
+    :param depth: recursion depth inside the category
+    :param polarity: polarity of the current sub-category
+    :param compt_plus: counters of positive atoms seen so far, per atom type
+    :param compt_neg: counters of negative atoms seen so far, per atom type
+    :return:
+    """
+    res = [(category == atom_type) for atom_type in atom_map.keys()]
+    if True in res:
+        polarity = not polarity
+        if polarity:
+            atoms_idx = compt_plus[category]
+            compt_plus[category] += 1
+        else:
+            idx_neg = compt_neg[category]
+            compt_neg[category] += 1
+            atoms_idx = np.where(links[atom_map_redux[category]] == idx_neg)[0][0]
+        atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx)
+        dot.node(atom_id, category + " " + str("+" if polarity else "-"))
+        dot.edge(parent_id, atom_id)
+    else:
+        category_id = category + "_" + str(word_idx) + "_" + str(depth)
+        dot.node(category_id, category + " " + str("+" if polarity else "-"))
+        dot.edge(parent_id, category_id)
+        parent_id = category_id
+
+        if category.startswith("dr"):
+            categories_inside = regex.match(regex_categories, category).groups()
+            categories_inside = [cat for cat in categories_inside if cat is not None]
+            categories_inside = [categories_inside[0], categories_inside[1]]
+            polarities_inside = [polarity, not polarity]
+
+        # dl / p
+        elif category.startswith("dl") or category.startswith("p"):
+            categories_inside = regex.match(regex_categories, category).groups()
+            categories_inside = [cat for cat in categories_inside if cat is not None]
+            categories_inside = [categories_inside[0], categories_inside[1]]
+            polarities_inside = [not polarity, polarity]
+
+        # box / dia
+        elif category.startswith("box") or category.startswith("dia"):
+            categories_inside = regex.match(regex_categories, category).groups()
+            categories_inside = [cat for cat in categories_inside if cat is not None]
+            categories_inside = [categories_inside[0]]
+            polarities_inside = [polarity]
+
+        else:
+            categories_inside = []
+            polarities_inside = []
+
+        for cat_id in range(len(categories_inside)):
+            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1,
+                              polarities_inside[cat_id], compt_plus,
+                              compt_neg)
+
+
+def draw_sentence_output(sentence, categories, links):
+    r"""
+    Drawing the prediction of a sentence when given categories and links predictions
+    :param sentence: list of words
+    :param categories: list of categories
+    :param links: links predicted
+    :return: dot source
+    """
+    dot = graphviz.Graph('linking', comment='Axiom linking')
+    dot.graph_attr['rankdir'] = 'BT'
+    dot.graph_attr['splines'] = 'ortho'
+    dot.graph_attr['ordering'] = 'in'
+
+    compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    last_word_id = ""
+    for word_idx in range(len(sentence)):
+        word = sentence[word_idx]
+        word_id = word + "_" + str(word_idx)
+        dot.node(word_id, word)
+        if word_idx > 0:
+            dot.edge(last_word_id, word_id, constraint="false", style="invis")
+
+        category = categories[word_idx]
+        polarity = True
+        parent_id = word_id
+        recursive_linking(links, dot, category, parent_id, word_idx, 0, polarity, compt_plus, compt_neg)
+        last_word_id = word_id
+
+    dot.attr('edge', color='red')
+    dot.attr('edge', style='dashed')
+    for atom_type in list(atom_map_redux.keys()):
+        for id in range(compt_plus[atom_type]):
+            atom_plus = atom_type + "_" + str(True) + "_" + str(id)
+            atom_moins = atom_type + "_" + str(False) + "_" + str(id)
+            dot.edge(atom_plus, atom_moins, constraint="false")
+
+    dot.render(format="svg", view=True)
+    return dot.source
+
+
+sentence = ["Le", "chat", "est", "noir", "bleu"]
+categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
+links = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 2, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+draw_sentence_output(sentence, categories, links)
diff --git a/requirements.txt b/requirements.txt
index e13c01e51333f299c437c5f59e5895e7b576096d..ce8002c1e34cda5a75cd5330dc1e4ab8df659555 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,11 @@
-numpy==1.22.2
 huggingface-hub==0.4.0
 pandas==1.4.1
-sentencepiece
-huggingface-hub==0.4.0
 Markdown==3.3.6
-numpy==1.22.3
 packaging==21.3
-pandas==1.4.2
 scikit-learn==1.0.2
 scipy==1.8.0
 sentencepiece==0.1.96
+tensorflow==2.9.1
 tensorboard==2.8.0
 torch==1.11.0
 tqdm==4.64.0
diff --git a/train.py b/train.py
index 974b3c5abf0debddeb36b6087a3bac02c0df11c4..1bdeb557da3ede0362de02ab05b4f75514d7cd06 100644
--- a/train.py
+++ b/train.py
@@ -1,15 +1,35 @@
 import torch
-from Configuration import Configuration
+
+from Linker import *
 from NeuralProofNet.NeuralProofNet import NeuralProofNet
 from utils import read_csv_pgbar
+from find_config import configurate
+from Configuration import Configuration
 
-torch.cuda.empty_cache()
-batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 4
-training_epoch = int(Configuration.modelTrainingConfig['training_epoch'])
 
+torch.cuda.empty_cache()
+nb_sentences = 100000000
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+model_tagger = "models/flaubert_super_98_V2_50e.pt"
+
+# region config
+configurate(file_path_axiom_links, model_tagger, nb_sentences=nb_sentences)
+config = Configuration.read_config()
+version = config["VERSION"]
+datasetConfig = config["DATASET_PARAMS"]
+modelEncoderConfig = config["MODEL_ENCODER"]
+modelLinkerConfig = config["MODEL_LINKER"]
+modelTrainingConfig = config["MODEL_TRAINING"]
+epochs = int(modelTrainingConfig['epoch'])
+batch_size = int(modelTrainingConfig['batch_size'])
+# endregion
+
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
-NeuralProofNet = NeuralProofNet("models/flaubert_super_98_V2_50e.pt")
-NeuralProofNet.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=training_epoch, batch_size=batch_size, checkpoint=False, tensorboard=True)
\ No newline at end of file
+print("#" * 20)
+print("#" * 20)
+neural_proof_net = NeuralProofNet(model_tagger)
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.05, epochs=epochs, batch_size=batch_size,
+                                      checkpoint=True, tensorboard=True)
+print("#" * 20)
+print("#" * 20)
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 0433510b2838731d38fd2e42e16a4a7b94ecf3b8..c4fae14e45ecebee5077044687bd9db9a5280936 100644
--- a/utils.py
+++ b/utils.py
@@ -6,6 +6,14 @@ from tqdm import tqdm
 
 
 def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    r"""
+    Pad a list of tensors to the same length, e.g. in preparation of a TensorDataset
+    :param sequences: list of tensors to pad
+    :param batch_first: whether the batch dimension comes first
+    :param padding_value: value used for padding
+    :param max_len: maximum length to pad to
+    :return: padded tensor
+    """
     max_size = sequences[0].size()
     trailing_dims = max_size[1:]
     if batch_first:
@@ -26,7 +34,13 @@ def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
 
 
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
-    print("\n" + "#" * 20)
+    r"""
+    Read a csv dataset with a progress bar
+    :param csv_path: path of the csv file
+    :param nrows: maximum number of rows to read
+    :param chunksize: number of rows read per chunk (for the progress bar)
+    :return: pandas DataFrame with the dataset
+    """
     print("Loading csv...")
 
     rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
@@ -42,7 +56,6 @@ def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
             bar.update(len(chunk))
 
     df = pd.concat((f for f in chunk_list), axis=0)
-    print("#" * 20)
 
     return df