diff --git a/Linker/Linker.py b/Linker/Linker.py
index 5002f9c7742b545b9dcf445558b44441ff182b8c..80fbfa160c4010c6d7cecb772786a3fb6fc37228 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -28,7 +28,7 @@ class Linker(Module):
     
     # region initialization
 
-    def __init__(self, supertagger_path_model):
+    def __init__(self):
         super(Linker, self).__init__()
 
         # region parameters
@@ -58,12 +58,6 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # endregion
 
-        # SuperTagger for categories
-        supertagger = SuperTagger()
-        supertagger.load_weights(supertagger_path_model)
-        self.Supertagger = supertagger
-        self.Supertagger.model.to(self.device)
-
         # Atoms embedding
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.atom_map_redux = atom_map_redux
@@ -118,53 +112,6 @@ class Linker(Module):
 
     #endregion
 
-    # region data
-
-    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
-        r"""
-        Args:
-            batch_size : int
-            df_axiom_links pandas DataFrame
-            validation_rate
-        Returns:
-            the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
-        """
-        print("Start preprocess Data")
-        sentences_batch = df_axiom_links["X"].str.strip().tolist()
-        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-
-        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
-        atoms_polarity_batch = pad_sequence(
-            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-            max_len=self.max_atoms_in_sentence, padding_value=0)
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-
-        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
-                                            df_axiom_links["Y"])
-        truth_links_batch = truth_links_batch.permute(1, 0, 2)
-
-        # Construction tensor dataset
-        dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
-                                sentences_tokens, sentences_mask)
-
-        if validation_rate > 0.0:
-            train_size = int(0.9 * len(dataset))
-            val_size = len(dataset) - train_size
-            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-        else:
-            validation_dataloader = None
-            train_dataset = dataset
-
-        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
-        print("End preprocess Data")
-        return training_dataloader, validation_dataloader
-
-    #endregion
-
     # region training
 
     def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
@@ -229,56 +176,7 @@ class Linker(Module):
 
         return F.log_softmax(link_weights, dim=3)
 
-    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, tensorboard=False):
-        r"""
-        Args:
-            df_axiom_links : pandas dataFrame containing the atoms anoted with _i
-            validation_rate : float
-            epochs : int
-            batch_size : int
-            checkpoint : boolean
-            tensorboard : boolean
-        Returns:
-            Final accuracy and final loss
-        """
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            validation_rate)
-        if checkpoint or tensorboard:
-            checkpoint_dir, writer = output_create_dir()
-
-        for epoch_i in range(epochs):
-            print("")
-            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-            print('Training...')
-            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
-
-            print("")
-            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
-            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
-
-            if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
-                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
-
-            if checkpoint:
-                self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker.pt'))
-
-            if tensorboard:
-                writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
-                if validation_rate > 0.0:
-                    writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
-                    writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
-
-            print('\n')
-
-    def train_epoch(self, training_dataloader):
+    def train_epoch(self, training_dataloader, Supertagger):
         r""" Train epoch
 
         Args:
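+            Supertagger : SuperTagger model used to compute the sentence embeddings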
@@ -309,12 +207,11 @@ class Linker(Module):
                 self.optimizer.zero_grad()
 
                 # get sentence embedding from BERT which is already trained
-                output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
+                output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
                 # Run the Linker on the atoms
                 logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
                                           output['word_embedding'])
-
                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
                 epoch_loss += float(linker_loss)
@@ -342,33 +239,7 @@ class Linker(Module):
 
     # region evaluation
 
-    def eval_batch(self, batch):
-        batch_num_atoms = batch[0].to(self.device)
-        batch_atoms_tok = batch[1].to(self.device)
-        batch_pos_idx = batch[2].to(self.device)
-        batch_neg_idx = batch[3].to(self.device)
-        batch_true_links = batch[4].to(self.device)
-        batch_sentences_tokens = batch[5].to(self.device)
-        batch_sentences_mask = batch[6].to(self.device)
-
-        output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-
-        logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
-            'word_embedding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
-        axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
-
-        print('\n')
-        print(batch_true_links)
-        print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100])
-        print("Les prédictions : ", axiom_links_pred[2][0][:100])
-        print('\n')
-
-        accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
-        loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
-
-        return loss, accuracy
-
-    def eval_epoch(self, dataloader):
+    def eval_epoch(self, dataloader, Supertagger):
         r"""Average the evaluation of all the batch.
 
         Args:
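+            Supertagger : SuperTagger model used to compute the sentence embeddings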
@@ -379,107 +250,25 @@ class Linker(Module):
         loss_average = 0
         with torch.no_grad():
             for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch)
-                accuracy_average += accuracy
-                loss_average += float(loss)
-
-        return loss_average / len(dataloader), accuracy_average / len(dataloader)
-
-    #endregion
-
-    #region prediction 
-
-    def predict_with_categories(self, sentence, categories):
-        r""" Predict the links from a sentence and its categories
-
-        Args :
-            sentence : list of words composing the sentence
-            categories : list of categories (tags) of each word
-
-        Return :
-            links : links prediction
-        """
-        self.eval()
-        with torch.no_grad():
-            self.cpu()
-            self.device = torch.device("cpu")
-            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
-            nb_sentence, len_sentence = sentences_tokens.shape
-
-            atoms = get_atoms_batch([categories])
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
-
-            polarities = find_pos_neg_idexes([categories])
-            polarities = pad_sequence(
-                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                max_len=self.max_atoms_in_sentence, padding_value=0)
-
-            num_atoms_per_word = get_num_atoms_batch([categories], len_sentence)
-
-            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
-            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
-
-            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
-
-            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
-            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
-
-        return axiom_links_pred
-
-    def predict_without_categories(self, sentence):
-        r""" Predict the links from a sentence
-
-        Args :
-            sentence : list of words composing the sentence
-
-        Return :
-            categories : the supertags predicted
-            links : links prediction
-        """
-        self.eval()
-        with torch.no_grad():
-            self.cpu()
-            self.device = torch.device("cpu")
-            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
-            nb_sentence, len_sentence = sentences_tokens.shape
-
-            hidden_state, categories = self.Supertagger.predict(sentence)
+                batch_num_atoms = batch[0].to(self.device)
+                batch_atoms_tok = batch[1].to(self.device)
+                batch_pos_idx = batch[2].to(self.device)
+                batch_neg_idx = batch[3].to(self.device)
+                batch_true_links = batch[4].to(self.device)
+                batch_sentences_tokens = batch[5].to(self.device)
+                batch_sentences_mask = batch[6].to(self.device)
 
-            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
-            atoms = get_atoms_batch(categories)
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
+                output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
-            polarities = find_pos_neg_idexes(categories)
-            polarities = pad_sequence(
-                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                max_len=self.max_atoms_in_sentence, padding_value=0)
+                logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output['word_embedding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
+                axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
-            num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
+                accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
+                loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
 
-            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
-            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+                accuracy_average += accuracy
+                loss_average += float(loss)
 
-            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
-            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+        return loss_average / len(dataloader), accuracy_average / len(dataloader)
 
-        return categories, axiom_links_pred
-    
     #endregion
-
-    def __checkpoint_save(self, path='/linker.pt'):
-        """
-        @param path:
-        """
-        self.cpu()
-
-        torch.save({
-            'atom_encoder': self.atom_encoder.state_dict(),
-            'position_encoder': self.position_encoder.state_dict(),
-            'transformer': self.transformer.state_dict(),
-            'linker_encoder': self.linker_encoder.state_dict(),
-            'pos_transformation': self.pos_transformation.state_dict(),
-            'neg_transformation': self.neg_transformation.state_dict(),
-            'cross_entropy_loss': self.cross_entropy_loss.state_dict(),
-            'optimizer': self.optimizer,
-        }, path)
-        self.to(self.device)
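
# Illustrative sketch (outside the patch): after this refactor the Linker no longer owns a
# SuperTagger; the caller constructs both and passes the tagger into train_epoch / eval_epoch.
# `build_dataloaders` is a hypothetical stand-in for the seven-tensor dataloaders built by
# NeuralProofNet.__preprocess_data further down in this diff; the weight path is the one
# used in train_neuralproofnet.py.
from SuperTagger import SuperTagger
from Linker import Linker

supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")

linker = Linker()                     # no supertagger argument anymore
supertagger.model.to(linker.device)   # tagger and linker must live on the same device

training_dataloader, validation_dataloader = build_dataloaders()  # hypothetical helper, see above
avg_loss, avg_acc, elapsed = linker.train_epoch(training_dataloader, supertagger)
val_loss, val_acc = linker.eval_epoch(validation_dataloader, supertagger)
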
diff --git a/Linker/eval.py b/Linker/eval.py
index b252e5e5b07fbfe1dd8cacf7e1486b0b5850c34b..fe7f347ba27c82d70b3b64219d7b60f873bd9be3 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -1,6 +1,9 @@
 import torch
 from torch.nn import Module
 from torch.nn.functional import nll_loss
+
 from Linker.atom_map import atom_map, atom_map_redux
 
 
@@ -12,8 +15,16 @@ class SinkhornLoss(Module):
         super(SinkhornLoss, self).__init__()
 
     def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
+        total_loss = 0
+        # iterate over the atom categories (txt, np, ...)
+        for link, perm in zip(predictions, truths.permute(1, 0, 2)):
+            # only count categories that actually have true links in this batch
+            if 0 in perm.flatten():
+                # mean NLL loss for this category, computed over the whole batch
+                category_loss = nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean',
+                                         ignore_index=-1)
+                total_loss += category_loss
+        return total_loss
 
 
 def measure_accuracy(batch_true_links, axiom_links_pred):
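
# Toy check (outside the patch) of the per-category loss above: two atom categories, a batch of
# one sentence, two atom slots per type. Category 0 carries real links while category 1 is all
# padding (-1), so only category 0 contributes to the total. Shapes follow the ones assumed by
# SinkhornLoss.forward.
import torch
from torch.nn.functional import nll_loss

log_probs = torch.log_softmax(torch.randn(2, 1, 2, 2), dim=3)  # [atom_vocab, batch, max_atoms, max_atoms]
truths = torch.tensor([[[1, 0], [-1, -1]]])                    # [batch, atom_vocab, max_atoms]

total_loss = 0
for link, perm in zip(log_probs, truths.permute(1, 0, 2)):
    if 0 in perm.flatten():  # skip categories with no true links
        total_loss += nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
print(total_loss)  # only the first category is counted
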
diff --git a/NeuralProofNet/NeuralProofNet.py b/NeuralProofNet/NeuralProofNet.py
index ab12693a359aade06103a47a8a6483e38907c5bf..41ee516775c9d9ab19cdb098d6b735d3823b748c 100644
--- a/NeuralProofNet/NeuralProofNet.py
+++ b/NeuralProofNet/NeuralProofNet.py
@@ -8,13 +8,14 @@ from torch.utils.data import TensorDataset, random_split
 from tqdm import tqdm
 
 from Configuration import Configuration
+from NeuralProofNet.utils_proofnet import get_info_for_tagger
+from SuperTagger import SuperTagger
 from Linker import Linker
 from Linker.eval import measure_accuracy, SinkhornLoss
-from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
-from NeuralProofNet.utils_proofnet import get_info_for_tagger
+from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
+    find_pos_neg_idexes, get_num_atoms_batch
 from utils import pad_sequence, format_time, output_create_dir
 
-
 class NeuralProofNet(Module):
     
     def __init__(self, supertagger_path_model, linker_path_model=None):
@@ -28,7 +29,13 @@ class NeuralProofNet(Module):
         self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        linker = Linker(supertagger_path_model)
+        # SuperTagger for categories
+        supertagger = SuperTagger()
+        supertagger.load_weights(supertagger_path_model)
+        self.Supertagger = supertagger
+        self.Supertagger.model.to(self.device)
+
+        linker = Linker()
         if linker_path_model is not None:
             linker.load_weights(linker_path_model)
         self.linker = linker
@@ -41,12 +48,6 @@ class NeuralProofNet(Module):
 
         self.to(self.device)
 
-    def __pretrain_linker__(self, df_axiom_links, pretrain_linker_epochs, batch_size, checkpoint=False, tensorboard=True):
-        print("\nLinker Pre-Training\n")
-        self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=pretrain_linker_epochs,
-                                 batch_size=batch_size, checkpoint=checkpoint, tensorboard=tensorboard)
-        print("\nEND Linker Pre-Training\n")
-
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
         r"""
         Args:
@@ -54,26 +55,31 @@ class NeuralProofNet(Module):
             df_axiom_links pandas DataFrame
             validation_rate
         Returns:
-            the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
+            the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
         """
         print("Start preprocess Data")
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
-        sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
+        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
         atoms_polarity_batch = pad_sequence(
             [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
             max_len=self.max_atoms_in_sentence, padding_value=0)
+        atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type)
+
+        truth_links_batch = get_axiom_links(self.linker.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
 
         # Construction tensor dataset
-        dataset = TensorDataset(truth_links_batch, sentences_tokens, sentences_mask)
+        dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
+                                sentences_tokens, sentences_mask)
 
         if validation_rate > 0.0:
-            train_size = int(0.9 * len(dataset))
+            train_size = int((1-validation_rate) * len(dataset))
             val_size = len(dataset) - train_size
             train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
             validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
@@ -85,13 +91,15 @@ class NeuralProofNet(Module):
         print("End preprocess Data")
         return training_dataloader, validation_dataloader
 
+    # region training
+
     def forward(self, batch_sentences_tokens, batch_sentences_mask):
 
         # get sentence embedding from BERT which is already trained
-        output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
+        output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         last_hidden_state = output['logit']
         pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
-        pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
+        pred_categories = self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
 
         # get information from tagger predictions
         atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
@@ -112,6 +120,49 @@ class NeuralProofNet(Module):
 
         return torch.log_softmax(logits_links, dim=3)
 
+    def pretrain_linker(self, training_dataloader, validation_dataloader, pretrain_linker_epochs, checkpoint=None, writer=None):
+        r"""
+        Args:
+            df_axiom_links : pandas dataFrame containing the atoms anoted with _i
+            validation_rate : float
+            epochs : int
+            batch_size : int
+            checkpoint : boolean
+            tensorboard : boolean
+        Returns:
+            Final accuracy and final loss
+        """
+
+        for epoch_i in range(pretrain_linker_epochs):
+            print("")
+            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, pretrain_linker_epochs))
+            print('Training...')
+            avg_train_loss, avg_accuracy_train, training_time = self.linker.train_epoch(training_dataloader, self.Supertagger)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
+            if validation_dataloader:
+                loss_test, accuracy_test = self.linker.eval_epoch(validation_dataloader, self.Supertagger)
+                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
+
+            if checkpoint:
+                self.__checkpoint_save(path='Output/linker.pt')
+
+            if writer:
+                writer.add_scalars(f'Accuracy', {
+                    'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars(f'Loss', {
+                    'Train': avg_train_loss}, epoch_i)
+                if validation_dataloader:
+                    writer.add_scalars(f'Accuracy', {
+                        'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars(f'Loss', {
+                        'Validation': loss_test}, epoch_i)
+
+            print('\n')
+
     def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0, 
                              batch_size=32, checkpoint=True, tensorboard=False):
         r"""
@@ -125,15 +176,20 @@ class NeuralProofNet(Module):
         Returns:
             Final accuracy and final loss
         """
-        # Pretrain the linker
-        self.__pretrain_linker__(df_axiom_links, pretrain_linker_epochs, batch_size)
-
         # Start learning with output from tagger
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             validation_rate)
         if checkpoint or tensorboard:
             checkpoint_dir, writer = output_create_dir()
 
+        # Pretrain the linker on the gold categories
+        if pretrain_linker_epochs > 0:
+            print("\nLinker Pre-Training\n")
+            self.pretrain_linker(training_dataloader, validation_dataloader, pretrain_linker_epochs,
+                                 checkpoint, writer if tensorboard else None)
+            print("\nEND Linker Pre-Training\n")
+
+        # Train Linker with predicted categories from supertagger
         for epoch_i in range(epochs):
             print("")
             print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
@@ -153,14 +209,14 @@ class NeuralProofNet(Module):
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
+                    'Train': avg_accuracy_train}, pretrain_linker_epochs + epoch_i)
                 writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
+                    'Train': avg_train_loss}, pretrain_linker_epochs + epoch_i)
                 if validation_rate > 0.0:
                     writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
+                        'Validation': accuracy_test}, pretrain_linker_epochs + epoch_i)
                     writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
+                        'Validation': loss_test}, pretrain_linker_epochs + epoch_i)
 
             print('\n')
 
@@ -184,9 +240,9 @@ class NeuralProofNet(Module):
         with tqdm(training_dataloader, unit="batch") as tepoch:
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
-                batch_true_links = batch[0].to(self.device)
-                batch_sentences_tokens = batch[1].to(self.device)
-                batch_sentences_mask = batch[2].to(self.device)
+                batch_true_links = batch[4].to(self.device)
+                batch_sentences_tokens = batch[5].to(self.device)
+                batch_sentences_mask = batch[6].to(self.device)
 
                 self.linker_optimizer.zero_grad()
 
@@ -215,25 +271,10 @@ class NeuralProofNet(Module):
         avg_accuracy_train = accuracy_train / len(training_dataloader)
 
         return avg_train_loss, avg_accuracy_train, training_time
+    
+    #endregion
 
-    def eval_batch(self, batch):
-        batch_true_links = batch[0].to(self.device)
-        batch_sentences_tokens = batch[1].to(self.device)
-        batch_sentences_mask = batch[2].to(self.device)
-
-        logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
-        axiom_links_pred = torch.argmax(logits_predictions_links,
-                                        dim=3)  # atom_vocab, batch_size, max atoms in one type
-
-        print('\n')
-        print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100])
-        print("Les prédictions : ", axiom_links_pred[2][0][:100])
-        print('\n')
-
-        accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
-        linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
-
-        return linker_loss, accuracy
+    # region evaluation
 
     def eval_epoch(self, dataloader):
         r"""Average the evaluation of all the batch.
@@ -246,12 +287,24 @@ class NeuralProofNet(Module):
         loss_average = 0
         with torch.no_grad():
             for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch)
+                batch_true_links = batch[4].to(self.device)
+                batch_sentences_tokens = batch[5].to(self.device)
+                batch_sentences_mask = batch[6].to(self.device)
+
+                logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
+                axiom_links_pred = torch.argmax(logits_predictions_links,
+                                                dim=3)  # atom_vocab, batch_size, max atoms in one type
+
+                accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
+                linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
+
                 accuracy_average += accuracy
-                loss_average += float(loss)
+                loss_average += float(linker_loss)
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)
 
+    #endregion
+
     def __checkpoint_save(self, path='/linker.pt'):
         """
         @param path:
@@ -268,4 +321,83 @@ class NeuralProofNet(Module):
             'cross_entropy_loss': self.linker_loss.state_dict(),
             'optimizer': self.linker_optimizer,
         }, path)
-        self.to(self.device)
\ No newline at end of file
+        self.to(self.device)
+
+    #region prediction 
+
+    def predict_with_categories(self, sentence, categories):
+        r""" Predict the links from a sentence and its categories
+
+        Args :
+            sentence : list of words composing the sentence
+            categories : list of categories (tags) of each word
+
+        Return :
+            links : links prediction
+        """
+        self.eval()
+        with torch.no_grad():
+            self.cpu()
+            self.device = torch.device("cpu")
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
+            nb_sentence, len_sentence = sentences_tokens.shape
+
+            atoms = get_atoms_batch(categories)
+            atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms)
+
+            polarities = find_pos_neg_idexes(categories)
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+
+            logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return axiom_links_pred
+
+    def predict_without_categories(self, sentence):
+        r""" Predict the links from a sentence
+
+        Args :
+            sentence : list of words composing the sentence
+
+        Return :
+            categories : the supertags predicted
+            links : links prediction
+        """
+        self.eval()
+        with torch.no_grad():
+            self.cpu()
+            self.device = torch.device("cpu")
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence)
+            nb_sentence, len_sentence = sentences_tokens.shape
+
+            hidden_state, categories = self.Supertagger.predict(sentence)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+            atoms = get_atoms_batch(categories)
+            atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms)
+
+            polarities = find_pos_neg_idexes(categories)
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+
+            logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return categories, axiom_links_pred
+    
+    #endregion
\ No newline at end of file
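
# Sketch (outside the patch): the seven-tensor batch layout built by NeuralProofNet.__preprocess_data
# and now shared by both training loops. The Linker epochs consume every slot, while
# NeuralProofNet.train_epoch / eval_epoch only need the last three, since atoms and polarities are
# re-derived from the tagger's own predictions.
def unpack_batch(batch):
    (num_atoms_per_word,     # batch[0] - Linker epochs only
     atoms_tokenized,        # batch[1] - Linker epochs only
     pos_idx,                # batch[2] - Linker epochs only
     neg_idx,                # batch[3] - Linker epochs only
     true_links,             # batch[4] - both loops
     sentences_tokens,       # batch[5] - both loops
     sentences_mask) = batch  # batch[6] - both loops
    return true_links, sentences_tokens, sentences_mask
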
diff --git a/predict_links.py b/predict_links.py
index 4e87a1f5fae1892530e929dbf778e648531fa8dc..b52982ad736d0afadaf4e9a61252c9940ebee742 100644
--- a/predict_links.py
+++ b/predict_links.py
@@ -4,22 +4,21 @@ from postprocessing import draw_sentence_output
 if __name__== '__main__':
       # region data
       a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."]
-      tags_s = ['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)',
+      tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)',
             'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)',
             'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np',
-            'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']
+            'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']]
       # endregion
 
-
       # region model
       model_tagger = "models/flaubert_super_98_V2_50e.pt"
       neuralproofnet = NeuralProofNet(model_tagger)
-      model = "Output/linker.pt"
+      model = "Output/saved_linker.pt"
       neuralproofnet.linker.load_weights(model)
       # endregion
 
-      linker = neuralproofnet.linker
-      categories, links = linker.predict_without_categories(a_s)
-      #links = linker.predict_with_categories(a_s, tags_s)
+      #categories, links = neuralproofnet.predict_without_categories(a_s)
+      links = neuralproofnet.predict_with_categories(a_s, tags_s)
+      
       idx=0
-      draw_sentence_output(a_s[idx].split(" "), categories[idx], links[:,idx,:].numpy())
+      draw_sentence_output(a_s[idx].split(" "), tags_s[idx], links[:,idx,:].numpy())
diff --git a/train_neuralproofnet.py b/train_neuralproofnet.py
index ce393adfa39c6bfa8b5b52e3d3888453ec892e52..f3c2284aa528b4bba9c470458af099e1cc027780 100644
--- a/train_neuralproofnet.py
+++ b/train_neuralproofnet.py
@@ -6,8 +6,8 @@ torch.cuda.empty_cache()
 
 
 # region data
-file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
-df_axiom_links = read_links_csv(file_path_axiom_links)
+file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
+df_axiom_links = read_links_csv(file_path_axiom_links)[:32]
 # endregion
 
 
@@ -16,7 +16,7 @@ print("#" * 20)
 print("#" * 20)
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
 neural_proof_net = NeuralProofNet(model_tagger)
-neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=20, batch_size=16,
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0, epochs=5, pretrain_linker_epochs=5, batch_size=16,
                                       checkpoint=True, tensorboard=True)
 print("#" * 20)
 print("#" * 20)