diff --git a/Linker/Linker.py b/Linker/Linker.py
index f65325eef04bc224f025f209f99f9d1a6f653207..ca3a8dcf4b6f00d62a315a2645ee4c59296828f8 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,16 +1,24 @@
+import os
+from datetime import datetime
+
 import torch
 from torch.nn import Sequential, LayerNorm, Dropout
 from torch.nn import Module
 import torch.nn.functional as F
 import sys
+
+from torch.optim import AdamW
+from torch.utils.data import TensorDataset, random_split
+from transformers import get_cosine_schedule_with_warmup
+
 from Configuration import Configuration
 from AtomEmbedding import AtomEmbedding
 from AtomTokenizer import AtomTokenizer
 from MHA import AttentionDecoderLayer
 from atom_map import atom_map
 from Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
-from eval import mesure_accuracy
+from utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links
+from eval import mesure_accuracy, SinkhornLoss
 from ..utils import pad_sequence
 
 
@@ -27,6 +35,10 @@ class Linker(Module):
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+        nb_sentences = batch_size * 10
+        self.epochs = int(Configuration.modelTrainingConfig['epoch'])
+        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
         self.device = ""
 
@@ -47,6 +59,41 @@ class Linker(Module):
             LayerNorm(self.dim_embedding_atoms, eps=1e-12)
         )
 
+        self.cross_entropy_loss = SinkhornLoss()
+        self.optimizer = AdamW(self.parameters(),
+                               weight_decay=1e-5,
+                               lr=learning_rate)
+        self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
+                                                         num_warmup_steps=0,
+                                                         num_training_steps=100)
+
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
+        atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
+        atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
+
+        torch.set_printoptions(edgeitems=20)
+        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
+                                            df_axiom_links["sub_tree"])
+        truth_links_batch = truth_links_batch.permute(1, 0, 2)
+
+        # Construction tensor dataset
+        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
+
+        if validation_rate > 0:
+            train_size = int(0.9 * len(dataset))
+            val_size = len(dataset) - train_size
+            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+        else:
+            validation_dataloader = None
+            train_dataset = dataset
+
+        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
+        return training_dataloader, validation_dataloader
+
     def make_decoder_mask(self, atoms_token):
         decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
         decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
@@ -101,6 +148,63 @@ class Linker(Module):
 
         return torch.stack(link_weights)
 
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32):
+
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
+        epochs = epochs - self.epochs
+        self.train()
+
+        for epoch_i in range(0, epochs):
+            epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader)
+
+    def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
+
+        # Reset the total loss for this epoch.
+        epoch_loss = 0
+
+        self.train()
+
+        # For each batch of training data...
+        for step, batch in enumerate(training_dataloader):
+            # Unpack this training batch from our dataloader
+            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+
+            self.optimizer.zero_grad()
+
+            # get sentence embedding from BERT which is already trained
+            # sentences_embedding = supertagger(batch_sentences)
+
+            # Run the kinker on the categories predictions
+            logits_predictions = self(batch_atoms, batch_polarity, [])
+
+            linker_loss = self.cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
+            # Perform a backward pass to calculate the gradients.
+            epoch_loss += float(linker_loss)
+            linker_loss.backward()
+
+            # This is to help prevent the "exploding gradients" problem.
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
+
+            # Update parameters and take a step using the computed gradient.
+            self.optimizer.step()
+            self.scheduler.step()
+
+        avg_train_loss = epoch_loss / len(training_dataloader)
+
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
+            self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
+        if validate:
+            self.eval()
+            with torch.no_grad():
+                accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
+
+        return accuracy, avg_train_loss
+
     def predict(self, categories, sents_embedding, sents_mask=None):
         r'''
         Parameters :
diff --git a/Utils/PostpreprocesTXT.py b/Utils/PostpreprocesTXT.py
index a1848e235f0d308b462c16ecdf2d90248e50a1f1..eaa9d30efb4d8bf0ae844cbc7b174c237ffc5f0c 100644
--- a/Utils/PostpreprocesTXT.py
+++ b/Utils/PostpreprocesTXT.py
@@ -25,7 +25,7 @@ def sub_tree_line(line_with_data: str):
     for word_with_data in line_list:
         w, t = sub_tree_word(word_with_data)
         sentence += ' ' + w
-        if t not in ["\\", "/", "let"] and len(t)>0:
+        if t not in ["\\", "/", "let"] and len(t) > 0:
             sub_trees.append([t])
         """if ('ppp' in list(itertools.chain(*sub_trees))):
             print(sentence)"""
@@ -35,17 +35,9 @@ def sub_tree_line(line_with_data: str):
 def Txt_to_csv(file_name: str):
     file = open(file_name, "r", encoding="utf8")
     text = file.readlines()
-
     sub = [sub_tree_line(data) for data in text]
-
     df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
-
     df.to_csv("../Datasets/" + file_name[:-4] + "_dataset_links.csv", index=False)
 
 
 Txt_to_csv("aa1_links.txt")
-
-"""trees = df['sub_tree']
-trees_flat = set(list(itertools.chain(*list(itertools.chain(*trees)))))
-fruit_dictionary = dict(zip(list(trees_flat), range(len(list(trees_flat)))))
-print(fruit_dictionary)"""