diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py
index 3f4f9d8a44d1006ce1030b2b833afc441b08cf6b..9c6c12b483752bf4797a9269b2c403521984f63d 100644
--- a/SuperTagger/Decoder/RNNDecoderLayer.py
+++ b/SuperTagger/Decoder/RNNDecoderLayer.py
@@ -2,18 +2,11 @@ import random
 
 import torch
 import torch.nn.functional as F
-from torch.nn import (Dropout, Module, Module, Sequential, LayerNorm, Dropout, GELU, Linear, LSTM, GRU)
+from torch.nn import (Module, Dropout, Linear, LSTM)
 
 from Configuration import Configuration
 from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
 
-def FFN(d_model, d_ff, dropout_rate = 0.1, d_out = None) -> Module:
-    return Sequential(
-        Linear(d_model, d_ff, bias=False),
-        GELU(),
-        Dropout(dropout_rate),
-        Linear(d_ff, d_model if d_out is None else d_out, bias=False)
-    )
 
 class RNNDecoderLayer(Module):
     def __init__(self, symbols_map):
@@ -45,12 +38,12 @@ class RNNDecoderLayer(Module):
         # rnn Layer
         if self.use_attention:
             self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                        dropout=dropout,
-                        bidirectional=self.bidirectional, batch_first=True)
-        else :
+                            dropout=dropout,
+                            bidirectional=self.bidirectional, batch_first=True)
+        else:
             self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                        dropout=dropout,
-                        bidirectional=self.bidirectional, batch_first=True)
+                            dropout=dropout,
+                            bidirectional=self.bidirectional, batch_first=True)
 
         # Projection on vocab_size
         if self.bidirectional:
@@ -61,13 +54,6 @@ class RNNDecoderLayer(Module):
         self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence)
         self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder)
 
-        # linking and pos neg weights
-        self.linker =
-        self.positive_transfo = Sequential(
-            FFN(self.dec_dim * 2, self.dec_dim, 0.1, self.dec_dim//2), LayerNorm(self.dec_dim//2, eps=1e-12))
-        self.negative_transfo = Sequential(
-            FFN(self.dec_dim * 2, self.dec_dim, 0.1, self.dec_dim // 2), LayerNorm(self.dec_dim//2, eps=1e-12))
-
     def sos_mask(self, y):
         return torch.eq(y, self.symbols_sos_id)
 
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
new file mode 100644
index 0000000000000000000000000000000000000000..65682306c2269b022e3ef23f1bd83da9aad19bf1
--- /dev/null
+++ b/SuperTagger/Linker/Linker.py
@@ -0,0 +1,79 @@
+from itertools import chain
+
+import torch
+from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
+
+from Configuration import Configuration
+
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+
+
+def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
+    return Sequential(
+        Linear(d_model, d_ff, bias=False),
+        GELU(),
+        Dropout(dropout_rate),
+        Linear(d_ff, d_model if d_out is None else d_out, bias=False)
+    )
+
+
+class Linker:
+    def __init__(self):
+        self.__init__()
+
+        self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
+        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+
+        self.dropout = Dropout(0.1)
+
+        self.pos_transformation = Sequential(
+            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
+            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
+            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+        )
+
+    def forward(self, symbols_batch, symbols_decoding):
+        '''
+        Parameters :
+        symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        '''
+
+        # some sequential for linker with output of decoder and initial ato
+
+        # decompose into batch_size, max symbols in sentence
+        decompose_decoding = find_pos_neg_idexes(symbols_batch)
+
+        # get  tensors of shape (batch_size, max_symbols_in_sentence/2)
+        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
+        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
+
+        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
+        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+
+        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
+        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+
+        distinct_shapes = {tensor.size()[0] for tensor in positives}
+        distinct_shapes = sorted(distinct_shapes)
+
+        # going to match the pos and neg together
+        matches = []
+
+        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
+                                                                                 if tensor.size()[0] == shape])))
+                               for shape in distinct_shapes]
+
+        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
+                                                                                 if tensor.size()[0] == shape])))
+                               for shape in distinct_shapes]
+
+        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
+            weights = torch.bmm(this_shape_positives,
+                                this_shape_negatives.transpose(2, 1))
+            matches.append(sinkhorn(weights, iters=3))
+
+        return matches
diff --git a/SuperTagger/Linker/Sinkhorn.py b/SuperTagger/Linker/Sinkhorn.py
new file mode 100644
index 0000000000000000000000000000000000000000..912abb4a0a070c7eae8af7dd4dd1cf3aafbc3a65
--- /dev/null
+++ b/SuperTagger/Linker/Sinkhorn.py
@@ -0,0 +1,17 @@
+
+from torch import logsumexp
+
+
+def norm(x, dim):
+    return x - logsumexp(x, dim=dim, keepdim=True)
+
+
+def sinkhorn_step(x):
+    return norm(norm(x, dim=1), dim=2)
+
+
+def sinkhorn_fn_no_exp(x, tau=1, iters=3):
+    x = x / tau
+    for _ in range(iters):
+        x = sinkhorn_step(x)
+    return x
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e702c77c7b9bc1c400c57711049fdbac15bfe7
--- /dev/null
+++ b/SuperTagger/Linker/utils.py
@@ -0,0 +1,82 @@
+import re
+
+
+atoms_list = ['r', 'np']
+
+
+def cut_category_in_symbols(category):
+    '''
+    Parameters :
+    category : str of kind AtomCat | CategoryCat
+    Returns :
+    Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
+    '''
+    category_to_weights = []
+
+    if category in atoms_list:
+        category_to_weights.append(True)
+
+    else:
+        # dr = /
+        if category.startswith("dr"):
+            category_cut = re.search(r'dr\(\d+,(.+),(.+)\)', category)
+            left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+            # for the left side
+            if left_side in atoms_list:
+                category_to_weights.append(False)
+            else:
+                category_to_weights += cut_category_in_symbols(left_side)
+
+            # for the right side
+            if right_side in atoms_list:
+                category_to_weights.append(True)
+            else:
+                category_to_weights += cut_category_in_symbols(right_side)
+
+        # dl = \
+        elif category.startswith("dl"):
+            category_cut = re.search(r'dl\(\d+,(.+),(.+)\)', category)
+            left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+            # for the left side
+            if left_side in atoms_list:
+                category_to_weights.append(True)
+            else:
+                category_to_weights += cut_category_in_symbols(left_side)
+
+            # for the right side
+            if right_side in atoms_list:
+                category_to_weights.append(False)
+            else:
+                category_to_weights += cut_category_in_symbols(right_side)
+
+    return category_to_weights
+
+
+print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+
+
+def find_pos_neg_idexes(batch_symbols):
+    '''
+    Parameters :
+    batch_symbols : (batch_size, sequence_length) the batch of symbols
+
+    Returns :
+    (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
+    '''
+    return None
+
+
+def make_sinkhorn_inputs(bsd_tensor, positional_ids):
+    """
+    :param bsd_tensor:
+        Tensor of shape (batch size, sequence length, feature dimensionality).
+    :param positional_ids:
+        A List (batch_size, max_atoms_in_sentence) .
+        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
+    :return:
+    """
+
+    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
+            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6882f3fedce1202da1ebc5fa1b35d6b9cf7c409
--- /dev/null
+++ b/test.py
@@ -0,0 +1,7 @@
+l = [[False, True, True, False],
+        [True, False, True, False]]
+
+print(l)
+print([i for i, x in enumerate(l) if x])
+
+print(list(map(lambda sub_list : [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
diff --git a/train.py b/train.py
index 7f595a9eb61ecd1bbcc092acbe1f6c4f38fbaaba..58ebe4523d2004d52862905df2c09d88aff9dd81 100644
--- a/train.py
+++ b/train.py
@@ -26,7 +26,6 @@ torch.cuda.empty_cache()
 
 # region ParamsModel
 
-max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
 max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
 symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
 num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
@@ -74,7 +73,6 @@ print("##" * 15 + "\nConfiguration : \n")
 
 print("ParamsModel\n")
 
-print("\tmax_symbols_in_sentence :", max_symbols_in_sentence)
 print("\tsymbol_vocab_size :", symbol_vocab_size)
 print("\tbidirectional : ", False)
 print("\tnum_gru_layers : ", num_gru_layers)
@@ -117,7 +115,7 @@ BASE_TOKENIZER = AutoTokenizer.from_pretrained(
     'camembert-base',
     do_lower_case=True)
 BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
-symbols_tokenizer = SymbolTokenizer(symbol_map, max_symbols_in_sentence, max_len_sentence)
+symbols_tokenizer = SymbolTokenizer(symbol_map, max_len_sentence, max_len_sentence)
 sents_tokenizer = EncoderInput(BASE_TOKENIZER)
 model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map)
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")