diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py deleted file mode 100644 index 93e96a69dbc13e6129b22301007252f2b486eef1..0000000000000000000000000000000000000000 --- a/SuperTagger/Decoder/RNNDecoderLayer.py +++ /dev/null @@ -1,180 +0,0 @@ -import random - -import torch -import torch.nn.functional as F -from torch.nn import (Module, Dropout, Linear, LSTM) - -from Configuration import Configuration -from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding - - -class RNNDecoderLayer(Module): - def __init__(self, symbols_map): - super(RNNDecoderLayer, self).__init__() - - # init params - self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) - self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder']) - dropout = float(Configuration.modelDecoderConfig['dropout']) - self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers']) - self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing']) - self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) - self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size']) - - self.bidirectional = False - self.use_attention = True - self.symbols_map = symbols_map - self.symbols_padding_id = self.symbols_map["[PAD]"] - self.symbols_sep_id = self.symbols_map["[SEP]"] - self.symbols_start_id = self.symbols_map["[START]"] - self.symbols_sos_id = self.symbols_map["[SOS]"] - - # Different layers - # Symbols Embedding - self.symbols_embedder = SymbolEmbedding(self.dim_decoder, self.symbols_vocab_size, - padding_idx=self.symbols_padding_id) - # For hidden_state - self.dropout = Dropout(dropout) - # rnn Layer - if self.use_attention: - self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers, - dropout=dropout, - bidirectional=self.bidirectional, batch_first=True) - else: - self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers, - dropout=dropout, - bidirectional=self.bidirectional, batch_first=True) - - # Projection on vocab_size - if self.bidirectional: - self.proj = Linear(self.dim_encoder * 2, self.symbols_vocab_size) - else: - self.proj = Linear(self.dim_encoder, self.symbols_vocab_size) - - self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence) - self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder) - - def sos_mask(self, y): - return torch.eq(y, self.symbols_sos_id) - - def forward(self, symbols_tokenized_batch, last_hidden_state, pooler_output): - r"""Training the translation from encoded sentences to symbols - - Args: - symbols_tokenized_batch: [batch_size, max_len_sentence] the true symbols for each sentence. - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model. 
- pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - batch_size, sequence_length, hidden_size = last_hidden_state.shape - - # y_hat[batch_size, max_len_sentence, vocab_size] init with probability pad =1 - y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - y_hat[:, :, self.symbols_padding_id] = 1 - - decoded_i = torch.ones(batch_size, 1, dtype=torch.long, - device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id - - sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu") - - # hidden_state goes through multiple linear layers - hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1) - - c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - - use_teacher_forcing = True if random.random() < self.teacher_forcing else False - - # for each symbol - for i in range(self.max_len_sentence): - # teacher-forcing training : we pass the target value in the embedding, not a created vector - symbols_embedding = self.symbols_embedder(decoded_i) - symbols_embedding = self.dropout(symbols_embedding) - - output = symbols_embedding - if self.use_attention: - attn_weights = F.softmax( - self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2) - attn_applied = torch.bmm(attn_weights, last_hidden_state) - - output = torch.cat((symbols_embedding, attn_applied), 2) - output = self.attn_combine(output) - output = F.relu(output) - - # rnn layer - output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state)) - - # Projection of the output of the rnn omitting the last probability (which is pad) so we dont predict PAD - proj = self.proj(output)[:, :, :-2] - - if use_teacher_forcing: - decoded_i = symbols_tokenized_batch[:, i].unsqueeze(1) - else: - decoded_i = torch.argmax(F.softmax(proj, dim=2), dim=2) - - # Calculate sos and pad - sos_mask_i = self.sos_mask(torch.argmax(F.softmax(proj, dim=2), dim=2)[:, -1]) - y_hat[~sos_mask, i, self.symbols_padding_id] = 0 - y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :] - sos_mask = sos_mask_i | sos_mask - - # Stop if every sentence says padding or if we are full - if not torch.any(~sos_mask): - break - - return y_hat - - def predict_rnn(self, last_hidden_state, pooler_output): - r"""Predicts the symbols from the output of the encoder. 
- - Args: - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] the output of the encoder - pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - batch_size, sequence_length, hidden_size = last_hidden_state.shape - - # contains the predictions - y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - y_hat[:, :, self.symbols_padding_id] = 1 - # input of the embedder, a created vector that replace the true value - decoded_i = torch.ones(batch_size, 1, dtype=torch.long, - device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id - - sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu") - - hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1) - - c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - - for i in range(self.max_len_sentence): - symbols_embedding = self.symbols_embedder(decoded_i) - symbols_embedding = self.dropout(symbols_embedding) - - output = symbols_embedding - if self.use_attention: - attn_weights = F.softmax( - self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2) - attn_applied = torch.bmm(attn_weights, last_hidden_state) - - output = torch.cat((symbols_embedding, attn_applied), 2) - output = self.attn_combine(output) - output = F.relu(output) - - output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state)) - - proj_softmax = F.softmax(self.proj(output)[:, :, :-2], dim=2) - decoded_i = torch.argmax(proj_softmax, dim=2) - - # Set sos and pad - sos_mask_i = self.sos_mask(decoded_i[:, -1]) - y_hat[~sos_mask, i, self.symbols_padding_id] = 0 - y_hat[~sos_mask, i, :-2] = proj_softmax[~sos_mask, -1, :] - sos_mask = sos_mask_i | sos_mask - - # Stop if every sentence says padding or if we are full - if not torch.any(~sos_mask): - break - - return y_hat diff --git a/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc b/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc deleted file mode 100644 index cd9f43c2b2c435f545d90c5f77a9e9e4ce5480f2..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Encoder/EncoderInput.py b/SuperTagger/Encoder/EncoderInput.py deleted file mode 100644 index e19da7d0d28e27e7b191d4333659f58c27e59f09..0000000000000000000000000000000000000000 --- a/SuperTagger/Encoder/EncoderInput.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - - -class EncoderInput(): - - def __init__(self, tokenizer): - """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """ - self.tokenizer = tokenizer - - def fit_transform(self, sents): - return self.tokenizer(sents, padding=True,) - - def fit_transform_tensors(self, sents): - temp = self.tokenizer(sents, padding=True, return_tensors='pt', ) - return temp['input_ids'], temp['attention_mask'] - - def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False): - return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens) diff --git a/SuperTagger/Encoder/EncoderLayer.py 
b/SuperTagger/Encoder/EncoderLayer.py deleted file mode 100644 index c954584f332ff6207371cda0bc93aae8fe6edfea..0000000000000000000000000000000000000000 --- a/SuperTagger/Encoder/EncoderLayer.py +++ /dev/null @@ -1,67 +0,0 @@ -import sys - -import torch -from torch import nn - -from Configuration import Configuration - - -class EncoderLayer(nn.Module): - """Encoder class, imput of supertagger""" - - def __init__(self, model): - super(EncoderLayer, self).__init__() - self.name = "Encoder" - - self.bert = model - - self.hidden_size = self.bert.config.hidden_size - - def forward(self, batch): - r""" - Args : - batch: list[str,mask], list of sentences (NOTE: untokenized, continuous sentences) - Returns : - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model. - pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - b_input_ids = batch[0] - b_input_mask = batch[1] - - outputs = self.bert( - input_ids=b_input_ids, attention_mask=b_input_mask) - - return outputs[0], outputs[1] - - @staticmethod - def load(model_path: str): - r""" Load the model from a file. - Args : - model_path (str): path to model - Returns : - model (nn.Module): model with saved parameters - """ - params = torch.load( - model_path, map_location=lambda storage, loc: storage) - args = params['args'] - model = EncoderLayer(**args) - model.load_state_dict(params['state_dict']) - - return model - - def save(self, path: str): - r""" Save the model to a file. - Args : - path (str): path to the model - """ - print('save model parameters to [%s]' % path, file=sys.stderr) - - params = { - 'args': dict(bert_config=self.bert.config, dropout_rate=self.dropout_rate), - 'state_dict': self.state_dict() - } - - torch.save(params, path) - - def to_dict(self): - return {} diff --git a/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc deleted file mode 100644 index 03717155496997dc3ef4713b269d7115fda65fd7..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc deleted file mode 100644 index f685148ba6d4ea28a6f741836205060a3d60123d..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py index e400d4ef28a90fda4e8e1d5f13276720c1de9fe2..6df55edc77013fe6d1ccdcac2d34df146a172db4 100644 --- a/SuperTagger/Linker/AtomTokenizer.py +++ b/SuperTagger/Linker/AtomTokenizer.py @@ -1,5 +1,7 @@ import torch +from SuperTagger.utils import pad_sequence + class AtomTokenizer(object): def __init__(self, atom_map, max_atoms_in_sentence): @@ -28,24 +30,3 @@ class AtomTokenizer(object): def convert_ids_to_atoms(self, ids): return [self.inverse_atom_map[int(i)] for i in ids] - - -def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): - max_size = sequences[0].size() - trailing_dims = max_size[1:] - if batch_first: - out_dims = (len(sequences), max_len) + trailing_dims - else: - out_dims = (max_len, len(sequences)) + trailing_dims - - out_tensor = 
sequences[0].data.new(*out_dims).fill_(padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - if batch_first: - out_tensor[i, :length, ...] = tensor - else: - out_tensor[:length, i, ...] = tensor - - return out_tensor - diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index 6b5c6f1a4de8ccddf2057c124582c533b3bd5c30..6a39ae7658faf456301399cb811d634c8d9f0d38 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -2,6 +2,7 @@ from itertools import chain import torch from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU +from torch.nn import Module from Configuration import Configuration from SuperTagger.Linker.AtomEmbedding import AtomEmbedding @@ -10,11 +11,12 @@ from SuperTagger.Linker.atom_map import atom_map from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer +from SuperTagger.utils import pad_sequence - -class Linker: +class Linker(Module): def __init__(self): + super(Linker, self).__init__() self.__init__() self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) @@ -71,20 +73,25 @@ class Linker: atoms_polarity = find_pos_neg_idexes(category_batch) link_weights = [] - for sentence_idx in range(len(atoms_polarity)): - for atom_type in self.atom_map.keys(): - pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - x and atoms_batch[sentence_idx][i] == atom_type] - neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - not x and atoms_batch[sentence_idx][i] == atom_type] - - pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :] - neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :] - - pos_encoding = self.pos_transformation(pos_encoding) - neg_encoding = self.neg_transformation(neg_encoding) - - weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0)) - link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters)) + for atom_type in self.atom_map.keys(): + pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + not x and atoms_batch[s_idx][i] == atom_type] for s_idx in + range(len(atoms_polarity))] + + # to do select with list of list + pos_encoding = pad_sequence( + [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + neg_encoding = pad_sequence( + [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + + # pos_encoding = self.pos_transformation(pos_encoding) + # neg_encoding = self.neg_transformation(neg_encoding) + + weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) + link_weights.append(sinkhorn(weights, iters=3)) return link_weights diff --git a/SuperTagger/Linker/Sinkhorn.py b/SuperTagger/Linker/Sinkhorn.py index 912abb4a0a070c7eae8af7dd4dd1cf3aafbc3a65..9cf9b45607800c1f35efa98801c86e3326726a19 100644 --- a/SuperTagger/Linker/Sinkhorn.py +++ 
b/SuperTagger/Linker/Sinkhorn.py @@ -1,4 +1,3 @@ - from torch import logsumexp diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc index afbe4b3ab22416e9fb3fa5b7e422587b47fe3c95..26d18c0959a2b1114a5564ff8b1ec4304f81acf3 100644 Binary files a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc and b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc differ diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index ddb8cb582d60625fb82b651663dc0a732bf6fb7a..f2e72e1b7f0ab7a95e3c098d043707c4693f2fc4 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -92,4 +92,3 @@ def find_pos_neg_idexes(batch_symbols): list_symbols.append(cut_category_in_symbols(category)) list_batch.append(list_symbols) return list_batch - diff --git a/SuperTagger/Symbol/SymbolEmbedding.py b/SuperTagger/Symbol/SymbolEmbedding.py deleted file mode 100644 index b982ef07f084578ac550fd28b5e58b3a2acef3b8..0000000000000000000000000000000000000000 --- a/SuperTagger/Symbol/SymbolEmbedding.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch.nn import Module, Embedding - - -class SymbolEmbedding(Module): - def __init__(self, dim_decoder, atom_vocab_size, padding_idx): - super(SymbolEmbedding, self).__init__() - self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_decoder, padding_idx=padding_idx, - scale_grad_by_freq=True) - - def forward(self, x): - return self.emb(x) diff --git a/SuperTagger/Symbol/SymbolTokenizer.py b/SuperTagger/Symbol/SymbolTokenizer.py deleted file mode 100644 index cded840a71b3c2ca461524d5c5f06d5e12e69f01..0000000000000000000000000000000000000000 --- a/SuperTagger/Symbol/SymbolTokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ - -import torch - - -class SymbolTokenizer(object): - def __init__(self, symbol_map, max_symbols_in_sentence, max_len_sentence): - self.symbol_map = symbol_map - self.max_symbols_in_sentence = max_symbols_in_sentence - self.max_len_sentence = max_len_sentence - self.inverse_symbol_map = {v: k for k, v in self.symbol_map.items()} - self.sep_token = '[SEP]' - self.pad_token = '[PAD]' - self.sos_token = '[SOS]' - self.sep_token_id = self.symbol_map[self.sep_token] - self.pad_token_id = self.symbol_map[self.pad_token] - self.sos_token_id = self.symbol_map[self.sos_token] - - def __len__(self): - return len(self.symbol_map) - - def convert_symbols_to_ids(self, symbol): - return self.symbol_map[str(symbol)] - - def convert_sents_to_ids(self, sentences): - return torch.as_tensor([self.convert_symbols_to_ids(symbol) for symbol in sentences]) - - def convert_batchs_to_ids(self, batchs_sentences): - return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences], - max_len=self.max_symbols_in_sentence, padding_value=self.pad_token_id)) - - def convert_ids_to_symbols(self, ids): - return [self.inverse_symbol_map[int(i)] for i in ids] - - -def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): - max_size = sequences[0].size() - trailing_dims = max_size[1:] - if batch_first: - out_dims = (len(sequences), max_len) + trailing_dims - else: - out_dims = (max_len, len(sequences)) + trailing_dims - - out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - if batch_first: - out_tensor[i, :length, ...] = tensor - else: - out_tensor[:length, i, ...] 
= tensor - - return out_tensor - diff --git a/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc deleted file mode 100644 index 030ce696540244b363763b0b592d2a84a8c19fd9..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc deleted file mode 100644 index 7c631aa22d6bb2cee2970ecccb0a089a347c9f95..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc deleted file mode 100644 index 1e1195f2ba885473f1d89b05684939ab6a2385b1..0000000000000000000000000000000000000000 Binary files a/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc and /dev/null differ diff --git a/SuperTagger/Symbol/symbol_map.py b/SuperTagger/Symbol/symbol_map.py deleted file mode 100644 index c16b8fdc18d5ec805bd5558699c52ace1597174d..0000000000000000000000000000000000000000 --- a/SuperTagger/Symbol/symbol_map.py +++ /dev/null @@ -1,28 +0,0 @@ -symbol_map = \ - {'cl_r': 0, - '\\': 1, - 'n': 2, - 'p': 3, - 's_ppres': 4, - 'dia': 5, - 's_whq': 6, - 'let': 7, - '/': 8, - 's_inf': 9, - 's_pass': 10, - 'pp_a': 11, - 'pp_par': 12, - 'pp_de': 13, - 'cl_y': 14, - 'box': 15, - 'txt': 16, - 's': 17, - 's_ppart': 18, - 's_q': 19, - 'np': 20, - 'pp': 21, - '[SEP]': 22, - '[SOS]': 23, - '[START]': 24, - '[PAD]': 25 - } diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py index 3017c7f52dc53a1ce9cc8ba21d75c63446af1b83..372e68cab3a4cf4af8da655a7073f372de39cb5e 100644 --- a/SuperTagger/eval.py +++ b/SuperTagger/eval.py @@ -1,8 +1,7 @@ import torch from torch import Tensor from torch.nn import Module -from torch.nn.functional import cross_entropy - +from torch.nn.functional import nll_loss, cross_entropy # Another from Kokos function to calculate the accuracy of our predictions vs labels def measure_supertagging_accuracy(pred, truth, ignore_idx=0): @@ -42,3 +41,12 @@ class NormCrossEntropy(Module): def forward(self, predictions, truths): return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights, reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id) + + +class SinkhornLoss(Module): + def __init__(self): + super(SinkhornLoss, self).__init__() + + def forward(self, predictions, truths): + return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean') + for link, perm in zip(predictions, truths)) \ No newline at end of file diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py index 8712cca081d5897d460d451b05814dfa46bfc538..cfacf2503ba924d1fd3ba07b340e27a2b13a2002 100644 --- a/SuperTagger/utils.py +++ b/SuperTagger/utils.py @@ -5,6 +5,26 @@ import torch from tqdm import tqdm +def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): + max_size = sequences[0].size() + trailing_dims = max_size[1:] + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + + out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the tensor + if 
batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor + + def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500): print("\n" + "#" * 20) print("Loading csv...") diff --git a/test.py b/test.py index f208027894f01b95d1509ccd2fafb58b12c2ac44..9e14d08a3992b2c2c6366af2ef3943a2737f3ef8 100644 --- a/test.py +++ b/test.py @@ -1,27 +1,54 @@ from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn import torch -atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"], - ["np", "np", "v", "v","np", "np", "v", "v"]] -atoms_polarity = [[False, True, True, False,False, True, True, False], - [True, False, True, False,True, False, True, False]] +def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): + max_size = sequences[0].size() + trailing_dims = max_size[1:] + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims -atoms_encoding = torch.randn((2, 8, 24)) + out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor -matches = [] -for sentence_idx in range(len(atoms_polarity)): - for atom_type in ["np", "v"]: - pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - x and atoms_batch[sentence_idx][i] == atom_type] - neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - not x and atoms_batch[sentence_idx][i] == atom_type] +atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"], + ["np", "np", "v", "v"]] - pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :] - neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :] +atoms_polarity = [[False, True, True, False, False, True, True, False], + [True, False, True, False]] - weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0)) - matches.append(sinkhorn(weights, iters=3)) +atoms_encoding = torch.randn((2, 8, 24)) + +matches = [] +for atom_type in ["np", "v"]: + pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + + # to do select with list of list + pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0) + neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0) + + print(neg_encoding.shape) + + weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) + print(weights.shape) + print("sinkhorn") + print(sinkhorn(weights, iters=3).shape) + matches.append(sinkhorn(weights, iters=3)) print(matches) diff --git a/train.py b/train.py index 58ebe4523d2004d52862905df2c09d88aff9dd81..25154db1b58abb38c2e7a6ffb94f89fcfb4d4b0d 100644 --- a/train.py +++ b/train.py @@ -1,134 +1,51 @@ import os +import pickle import time from datetime import 
datetime import numpy as np import torch import torch.nn.functional as F -import transformers from torch.optim import SGD, Adam, AdamW from torch.utils.data import Dataset, TensorDataset, random_split -from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup) -from transformers import (CamembertModel) +from transformers import get_cosine_schedule_with_warmup from Configuration import Configuration -from SuperTagger.Encoder.EncoderInput import EncoderInput -from SuperTagger.EncoderDecoder import EncoderDecoder -from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer -from SuperTagger.Symbol.symbol_map import symbol_map -from SuperTagger.eval import NormCrossEntropy +from SuperTagger.Linker.Linker import Linker +from SuperTagger.Linker.atom_map import atom_map +from SuperTagger.eval import NormCrossEntropy, SinkhornLoss from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load from torch.utils.tensorboard import SummaryWriter -transformers.TOKENIZERS_PARALLELISM = True torch.cuda.empty_cache() # region ParamsModel -max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence']) -symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size']) -num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers']) +max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) +atom_vocab_size = int(Configuration.datasetConfig['atoms_vocab_size']) # endregion ParamsModel # region ParamsTraining -file_path = 'Datasets/m2_dataset.csv' batch_size = int(Configuration.modelTrainingConfig['batch_size']) nb_sentences = batch_size * 40 epochs = int(Configuration.modelTrainingConfig['epoch']) seed_val = int(Configuration.modelTrainingConfig['seed_val']) learning_rate = float(Configuration.modelTrainingConfig['learning_rate']) -loss_scaled_by_freq = True # endregion ParamsTraining -# region OutputTraining - -outpout_path = str(Configuration.modelTrainingConfig['output_path']) - -training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) -logs_dir = os.path.join(training_dir, 'logs') - -checkpoint_dir = training_dir -writer = SummaryWriter(log_dir=logs_dir) - -use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE')) - -# endregion OutputTraining - -# region InputTraining - -input_path = str(Configuration.modelTrainingConfig['input_path']) -model_to_load = str(Configuration.modelTrainingConfig['model_to_load']) -model_to_load_path = os.path.join(input_path, model_to_load) -use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD')) - -# endregion InputTraining - -# region Print config - -print("##" * 15 + "\nConfiguration : \n") - -print("ParamsModel\n") - -print("\tsymbol_vocab_size :", symbol_vocab_size) -print("\tbidirectional : ", False) -print("\tnum_gru_layers : ", num_gru_layers) - -print("\n ParamsTraining\n") - -print("\tDataset :", file_path) -print("\tb_sentences :", nb_sentences) -print("\tbatch_size :", batch_size) -print("\tepochs :", epochs) -print("\tseed_val :", seed_val) - -print("\n Output\n") -print("\tuse checkpoint save :", use_checkpoint_SAVE) -print("\tcheckpoint_dir :", checkpoint_dir) -print("\tlogs_dir :", logs_dir) - -print("\n Input\n") -print("\tModel to load :", model_to_load_path) -print("\tLoad checkpoint :", use_checkpoint_LOAD) - -print("\nLoss and optimizer : ") - -print("\tlearning_rate :", learning_rate) -print("\twith loss scaled by freq :", 
loss_scaled_by_freq) - -print("\n Device\n") -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("\t", device) - -print() -print("##" * 15) - -# endregion Print config - -# region Model +# region Data loader file_path = 'Datasets/m2_dataset.csv' -BASE_TOKENIZER = AutoTokenizer.from_pretrained( - 'camembert-base', - do_lower_case=True) -BASE_MODEL = CamembertModel.from_pretrained("camembert-base") -symbols_tokenizer = SymbolTokenizer(symbol_map, max_len_sentence, max_len_sentence) -sents_tokenizer = EncoderInput(BASE_TOKENIZER) -model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map) -model = model.to("cuda" if torch.cuda.is_available() else "cpu") +file_path_axiom_links = 'Datasets/axiom_links.csv' -# endregion Model - -# region Data loader df = read_csv_pgbar(file_path, nb_sentences) +df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) -symbols_tokenized = symbols_tokenizer.convert_batchs_to_ids(df['sub_tree']) -sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentences'].tolist()) - -dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized) +dataset = TensorDataset(df, df, df_axiom_links) # Calculate the number of samples to include in each set. train_size = int(0.9 * len(dataset)) @@ -137,46 +54,34 @@ val_size = len(dataset) - train_size # Divide the dataset by randomly selecting samples. train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) -print('{:>5,} training samples'.format(train_size)) -print('{:>5,} validation samples'.format(val_size)) - training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True) # endregion Data loader + +# region Models + +supertagger_path = "" +supertagger = pickle.load(supertagger_path) +linker = Linker() + +# endregion Models + # region Fit tunning # Optimizer -optimizer_encoder = AdamW(model.encoder.parameters(), - weight_decay=1e-5, - lr=5e-5) -optimizer_decoder = AdamW(model.decoder.parameters(), - weight_decay=1e-5, - lr=learning_rate) - -# Total number of training steps is [number of batches] x [number of epochs]. -# (Note that this is not the same as the number of training samples). -total_steps = len(training_dataloader) * epochs +optimizer_linker = AdamW(linker.parameters(), + weight_decay=1e-5, + lr=learning_rate) # Create the learning rate scheduler. 
-scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder, - num_warmup_steps=0, - num_training_steps=5) -scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder, - num_warmup_steps=0, - num_training_steps=total_steps) +scheduler_linker = get_cosine_schedule_with_warmup(optimizer_linker, + num_warmup_steps=0, + num_training_steps=100) # Loss -if loss_scaled_by_freq: - weights = torch.as_tensor( - [6.9952, 1.0763, 1.0317, 43.274, 16.5276, 11.8821, 28.2416, 2.7548, 1.0728, 3.1847, 8.4521, 6.77, 11.1887, - 6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0], - device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) - cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id, - weights=weights) -else: - cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id) +cross_entropy_loss = SinkhornLoss() np.random.seed(seed_val) torch.manual_seed(seed_val) @@ -192,10 +97,6 @@ total_t0 = time.time() validate = True -if use_checkpoint_LOAD: - model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path) - epochs = epochs - last_epoch - def run_epochs(epochs): # For each epoch... @@ -216,60 +117,38 @@ def run_epochs(epochs): # Reset the total loss for this epoch. total_train_loss = 0 - model.train() + linker.train() # For each batch of training data... for step, batch in enumerate(training_dataloader): + # Unpack this training batch from our dataloader. + batch_categories = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") + batch_sentences = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") + batch_axiom_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") + + optimizer_linker.zero_grad() + + # Find the prediction of categories to feed the linker and the sentences embedding + category_logits_pred, sents_embedding, sents_mask = supertagger(batch_categories, batch_sentences) + + # Predict the categories from prediction with argmax and softmax + category_batch = torch.argmax(torch.nn.functional.softmax(category_logits_pred, dim=2), dim=2) - # if epoch_i == 0 and step == 0: - # writer.add_graph(model, input_to_model=batch[0], verbose=False) - - # Progress update every 40 batches. - if step % 40 == 0 and not step == 0: - # Calculate elapsed time in minutes. - elapsed = format_time(time.time() - t0) - # Report progress. - print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(training_dataloader), elapsed)) - - # Unpack this training batch from our dataloader. 
- b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") - b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") - b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") - - optimizer_encoder.zero_grad() - optimizer_decoder.zero_grad() - - logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized) - - predict_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in - torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]] - true_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in b_symbols_tokenized[0]] - l = len([i for i in true_trad if i != '[PAD]']) - if step % 40 == 0 and not step == 0: - writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str( - [token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str( - len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str( - [token for token in predict_trad[:l] if token != '[PAD]'])) - - loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized) + # Run the linker on the categories predictions + logits_predictions = linker(category_batch, sents_embedding, sents_mask) + + linker_loss = cross_entropy_loss(logits_predictions, batch_axiom_links) # Perform a backward pass to calculate the gradients. - total_train_loss += float(loss) - loss.backward() + total_train_loss += float(linker_loss) + linker_loss.backward() # This is to help prevent the "exploding gradients" problem. # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2) # Update parameters and take a step using the computed gradient. - optimizer_encoder.step() - optimizer_decoder.step() - - scheduler_encoder.step() - scheduler_decoder.step() - - # checkpoint + optimizer_linker.step() - if use_checkpoint_SAVE: - checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss) + scheduler_linker.step() avg_train_loss = total_train_loss / len(training_dataloader) @@ -277,27 +156,18 @@ def run_epochs(epochs): training_time = format_time(time.time() - t0) if validate: - model.eval() + linker.eval() with torch.no_grad(): print("Start eval") - accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss) + accuracy_sents, accuracy_atom, v_loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss) print("") print(" Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents)) - print(" Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol)) - writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1) - writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1) + print(" Average accuracy atom on epoch: {0:.2f}".format(accuracy_atom)) print("") print(" Average training loss: {0:.2f}".format(avg_train_loss)) print(" Training epoch took: {:}".format(training_time)) - # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1) - - writer.add_scalars('Training vs. Validation Loss', - {'Training': avg_train_loss, 'Validation': v_loss}, - epoch_i + 1) - writer.flush() - run_epochs(epochs) # endregion Train
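
For reference, a minimal self-contained sketch of the per-atom-type matching that this change moves into Linker.forward and exercises in test.py: positive and negative occurrences of each atom type are gathered per sentence, padded to a common length with the shared pad_sequence helper, compared with a batched matrix product, then normalised with Sinkhorn iterations. The sample data and shapes are made up, and log_sinkhorn is a hypothetical stand-in for SuperTagger.Linker.Sinkhorn.sinkhorn_fn_no_exp, whose body is not shown in this diff.

import torch
from torch import logsumexp


def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    # Same padding helper that this diff centralises in SuperTagger/utils.py.
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor
    return out_tensor


def log_sinkhorn(weights, iters=3):
    # Hypothetical stand-in for sinkhorn_fn_no_exp: alternate row/column
    # normalisation in log space so each [pos, neg] matrix tends towards a
    # doubly stochastic matching.
    log_alpha = weights
    for _ in range(iters):
        log_alpha = log_alpha - logsumexp(log_alpha, dim=2, keepdim=True)
        log_alpha = log_alpha - logsumexp(log_alpha, dim=1, keepdim=True)
    return log_alpha


# Toy inputs mirroring test.py: two sentences, their atom types and polarities,
# and a random contextual encoding for every atom.
atoms_batch = [["np", "v", "np", "v"],
               ["np", "np", "v", "v"]]
atoms_polarity = [[False, True, True, False],
                  [True, False, True, False]]
atoms_encoding = torch.randn((2, 4, 24))  # [batch, atoms, dim]

link_weights = []
for atom_type in ["np", "v"]:
    # Indices of positive / negative occurrences of this atom type, per sentence.
    pos_idx = [[i for i, pol in enumerate(atoms_polarity[s]) if pol and atoms_batch[s][i] == atom_type]
               for s in range(len(atoms_polarity))]
    neg_idx = [[i for i, pol in enumerate(atoms_polarity[s]) if not pol and atoms_batch[s][i] == atom_type]
               for s in range(len(atoms_polarity))]

    # Gather the corresponding embeddings and pad every sentence to the same
    # number of candidates so the whole batch can be multiplied at once.
    pos_encoding = pad_sequence([atoms_encoding[s].index_select(0, torch.as_tensor(idx))
                                 for s, idx in enumerate(pos_idx)],
                                max_len=3, padding_value=0)
    neg_encoding = pad_sequence([atoms_encoding[s].index_select(0, torch.as_tensor(idx))
                                 for s, idx in enumerate(neg_idx)],
                                max_len=3, padding_value=0)

    # [batch, pos, neg] affinity matrix, then Sinkhorn normalisation.
    weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
    link_weights.append(log_sinkhorn(weights, iters=3))

print([w.shape for w in link_weights])  # 2 tensors of shape [2, 3, 3]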
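
Likewise, a small usage sketch of the SinkhornLoss added to SuperTagger/eval.py, assuming each prediction is a [batch, positive, negative] matrix of log-scores and each truth holds, for every positive atom, the index of its negative partner. The tensors below are random placeholders, not real data.

import torch
from torch.nn import Module
from torch.nn.functional import nll_loss


class SinkhornLoss(Module):
    # Same definition as the loss added to SuperTagger/eval.py.
    def __init__(self):
        super(SinkhornLoss, self).__init__()

    def forward(self, predictions, truths):
        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
                   for link, perm in zip(predictions, truths))


loss_fn = SinkhornLoss()

# Two atom types, batch of 2 sentences, 3 candidate atoms per polarity:
# log-probabilities over the negative candidates for every positive atom.
scores = [torch.randn(2, 3, 3, requires_grad=True) for _ in range(2)]
predictions = [torch.log_softmax(s, dim=2) for s in scores]
# Gold permutation: index of the matched negative atom for each positive atom.
truths = [torch.randint(0, 3, (2, 3)) for _ in range(2)]

loss = loss_fn(predictions, truths)
loss.backward()  # gradients flow back to the (placeholder) link scores
print(float(loss))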