diff --git a/SuperTagger/Linker/AtomEmbedding.py b/SuperTagger/Linker/AtomEmbedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7be599a0fa145f76a5646b83973a3501ed52d4d
--- /dev/null
+++ b/SuperTagger/Linker/AtomEmbedding.py
@@ -0,0 +1,12 @@
+import torch
+from torch.nn import Module, Embedding
+
+
+class AtomEmbedding(Module):
+    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
+        super(AtomEmbedding, self).__init__()
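+        # one trainable vector per atom type; scale_grad_by_freq rescales gradients
+        # by the inverse frequency of each atom in the current mini-batch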
+        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
+                             scale_grad_by_freq=True)
+
+    def forward(self, x):
+        return self.emb(x)
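+
+# usage sketch (the sizes below are illustrative, not taken from Configuration):
+#   emb = AtomEmbedding(dim_linker=128, atom_vocab_size=26, padding_idx=25)
+#   emb(torch.zeros(2, 10, dtype=torch.long)).shape  # -> torch.Size([2, 10, 128])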
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e400d4ef28a90fda4e8e1d5f13276720c1de9fe2
--- /dev/null
+++ b/SuperTagger/Linker/AtomTokenizer.py
@@ -0,0 +1,51 @@
+import torch
+
+
+class AtomTokenizer(object):
+    def __init__(self, atom_map, max_atoms_in_sentence):
+        self.atom_map = atom_map
+        self.max_atoms_in_sentence = max_atoms_in_sentence
+        self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
+        self.sep_token = '[SEP]'
+        self.pad_token = '[PAD]'
+        self.sos_token = '[SOS]'
+        self.sep_token_id = self.atom_map[self.sep_token]
+        self.pad_token_id = self.atom_map[self.pad_token]
+        self.sos_token_id = self.atom_map[self.sos_token]
+
+    def __len__(self):
+        return len(self.atom_map)
+
+    def convert_atoms_to_ids(self, atom):
+        return self.atom_map[str(atom)]
+
+    def convert_sents_to_ids(self, sentences):
+        return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
+
+    def convert_batchs_to_ids(self, batchs_sentences):
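+        # tokenize every sentence, then pad them all to max_atoms_in_sentence with the [PAD] id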
+        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
+                                            max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
+
+    def convert_ids_to_atoms(self, ids):
+        return [self.inverse_atom_map[int(i)] for i in ids]
+
+
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
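+    # same idea as torch.nn.utils.rnn.pad_sequence, except every sequence is padded
+    # to a fixed max_len rather than to the longest sequence in the batch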
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
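+
+# usage sketch (the max_atoms_in_sentence value is illustrative):
+#   tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence=20)
+#   ids = tokenizer.convert_batchs_to_ids([['np', 'n'], ['s', 'np', 'np']])
+#   ids.shape  # -> torch.Size([2, 20]), padded with tokenizer.pad_token_id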
+
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 65682306c2269b022e3ef23f1bd83da9aad19bf1..745b7d96083aca93f873c325cbcebd34939ecacf 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -4,9 +4,11 @@ import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
 
 from Configuration import Configuration
-
+from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
 
 
 def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
@@ -24,56 +26,67 @@ class Linker:
 
         self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
+        self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
+        self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
 
         self.dropout = Dropout(0.1)
 
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define a proper encoder for the atom embeddings
+        self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
+
         self.pos_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
         self.neg_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
 
-    def forward(self, symbols_batch, symbols_decoding):
+    def forward(self, category_batch):
         '''
         Parameters :
-        symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        category_batch : batch (batch_size, len_sentence) of categories given as strings
+        Returns :
+        link_weights : list of (1, n_pos, n_neg) Sinkhorn-normalised score matrices,
+        one per (sentence, atom type) pair
         '''
 
-        # some sequential for linker with output of decoder and initial ato
-
-        # decompose into batch_size, max symbols in sentence
-        decompose_decoding = find_pos_neg_idexes(symbols_batch)
-
-        # get  tensors of shape (batch_size, max_symbols_in_sentence/2)
-        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
-        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
-
-        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
-        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+        # atoms embedding : flatten the categories into atoms, tokenize them, then embed
+        # (atoms_batch is kept as strings because it is compared against atom_map keys below)
+        atoms_batch = get_atoms_batch(category_batch)
+        atoms_batch_tokenized = self.atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
 
-        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
-        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+        # TODO: contextualise the atom embeddings, e.g. MHA or LSTM over the BERT output
+        # atoms_encoding = self.linker_encoder(atoms_embedding)
+        atoms_encoding = atoms_embedding
 
-        distinct_shapes = {tensor.size()[0] for tensor in positives}
-        distinct_shapes = sorted(distinct_shapes)
+        # atoms polarity : a list (not a tensor) of shape (batch_size, max_atoms_in_sentence)
+        atoms_polarity = find_pos_neg_idexes(category_batch)
 
-        # going to match the pos and neg together
-        matches = []
+        link_weights = []
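+        # for each sentence and each atom type, match the positive occurrences with the
+        # negative ones through a Sinkhorn-normalised score matrix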
+        for sentence_idx in range(len(atoms_polarity)):
+            for atom_type in self.atom_map.keys():
+                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         x and atoms_batch[sentence_idx][i] == atom_type]
+                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         not x and atoms_batch[sentence_idx][i] == atom_type]
 
-        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
 
-        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
-            weights = torch.bmm(this_shape_positives,
-                                this_shape_negatives.transpose(2, 1))
-            matches.append(sinkhorn(weights, iters=3))
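+                # (1, n_pos, n_neg) score matrix between positive and negative atoms of this type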
+                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        return matches
+        return link_weights
diff --git a/SuperTagger/Linker/atom_map.py b/SuperTagger/Linker/atom_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..893fd00518bc8b58bb9b3448243286c8ca14e6b8
--- /dev/null
+++ b/SuperTagger/Linker/atom_map.py
@@ -0,0 +1,28 @@
+atom_map = \
+    {'cl_r': 0,
+     '\\': 1,
+     'n': 2,
+     'p': 3,
+     's_ppres': 4,
+     'dia': 5,
+     's_whq': 6,
+     'let': 7,
+     '/': 8,
+     's_inf': 9,
+     's_pass': 10,
+     'pp_a': 11,
+     'pp_par': 12,
+     'pp_de': 13,
+     'cl_y': 14,
+     'box': 15,
+     'txt': 16,
+     's': 17,
+     's_ppart': 18,
+     's_q': 19,
+     'np': 20,
+     'pp': 21,
+     '[SEP]': 22,
+     '[SOS]': 23,
+     '[START]': 24,
+     '[PAD]': 25
+     }
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 49e702c77c7b9bc1c400c57711049fdbac15bfe7..ddb8cb582d60625fb82b651663dc0a732bf6fb7a 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -1,7 +1,30 @@
 import re
 
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 
-atoms_list = ['r', 'np']
+
+def get_atoms_from_category(category, category_to_atoms):
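+    '''
+    Parameters :
+    category : string of the category to decompose (e.g. 'dr(0,np,n)')
+    category_to_atoms : accumulator list of the atoms already collected
+    Returns :
+    list of the atomic types appearing in category, in reading order
+    '''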
+    if category in atom_map.keys():
+        return category_to_atoms + [category]
+    else:
+        category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
+        left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+        category_to_atoms += get_atoms_from_category(left_side, [])
+        category_to_atoms += get_atoms_from_category(right_side, [])
+
+        return category_to_atoms
+
+
+def get_atoms_batch(category_batch):
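+    '''
+    Parameters :
+    category_batch : batch (batch_size, len_sentence) of lists of category strings
+    Returns :
+    list (batch_size, num_atoms_in_sentence) of the atoms of each sentence
+    '''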
+    batch = []
+    for sentence in category_batch:
+        category_to_atoms = []
+        for category in sentence:
+            category_to_atoms = get_atoms_from_category(category, category_to_atoms)
+        batch.append(category_to_atoms)
+    return batch
 
 
 def cut_category_in_symbols(category):
@@ -11,10 +34,10 @@ def cut_category_in_symbols(category):
     Returns :
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
-    category_to_weights = []
+    category_to_polarity = []
 
-    if category in atoms_list:
-        category_to_weights.append(True)
+    if category in atom_map.keys():
+        category_to_polarity.append(True)
 
     else:
         # dr = /
@@ -23,16 +46,16 @@ def cut_category_in_symbols(category):
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(False)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(True)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
+                category_to_polarity += cut_category_in_symbols(right_side)
 
         # dl = \
         elif category.startswith("dl"):
@@ -40,21 +63,18 @@ def cut_category_in_symbols(category):
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(True)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(False)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
-
-    return category_to_weights
-
+                category_to_polarity += cut_category_in_symbols(right_side)
 
-print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+    return category_to_polarity
 
 
 def find_pos_neg_idexes(batch_symbols):
@@ -65,18 +85,11 @@ def find_pos_neg_idexes(batch_symbols):
     Returns :
-    (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
+    (batch_size, max_symbols_in_sentence) boolean lists indicating pos and neg indexes
     '''
-    return None
-
-
-def make_sinkhorn_inputs(bsd_tensor, positional_ids):
-    """
-    :param bsd_tensor:
-        Tensor of shape (batch size, sequence length, feature dimensionality).
-    :param positional_ids:
-        A List (batch_size, max_atoms_in_sentence) .
-        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
-    :return:
-    """
+    list_batch = []
+    for sentence in batch_symbols:
+        list_symbols = []
+        for category in sentence:
+            # one polarity flag per atom, flattened so it aligns with get_atoms_batch
+            list_symbols += cut_category_in_symbols(category)
+        list_batch.append(list_symbols)
+    return list_batch
 
-    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
-            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
diff --git a/test.py b/test.py
index d6882f3fedce1202da1ebc5fa1b35d6b9cf7c409..f208027894f01b95d1509ccd2fafb58b12c2ac44 100644
--- a/test.py
+++ b/test.py
@@ -1,7 +1,27 @@
-l = [[False, True, True, False],
-        [True, False, True, False]]
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+import torch
 
-print(l)
-print([i for i, x in enumerate(l) if x])
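+# standalone sanity check of the per-atom-type matching loop used in Linker.forward,
+# run on a toy batch with hand-written atoms and polarities (True = positive occurrence)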
+atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"],
+               ["np", "np", "v", "v","np", "np", "v", "v"]]
 
-print(list(map(lambda sub_list : [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
+atoms_polarity = [[False, True, True, False,False, True, True, False],
+                  [True, False, True, False,True, False, True, False]]
+
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
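+# for each sentence, match positive and negative occurrences of each atom type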
+for sentence_idx in range(len(atoms_polarity)):
+
+    for atom_type in ["np", "v"]:
+        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 x and atoms_batch[sentence_idx][i] == atom_type]
+        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 not x and atoms_batch[sentence_idx][i] == atom_type]
+
+        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
+
+        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+        matches.append(sinkhorn(weights, iters=3))
+
+print(matches)