diff --git a/Configuration/config.ini b/Configuration/config.ini
index 1d2d4b08635b2048161e5dc6953c5247ddea3d17..eae7cc331135b9b6330a4d4b5718b1a13a87c9b4 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -6,6 +6,7 @@ symbols_vocab_size=26
 atom_vocab_size=12
 max_len_sentence=148
 max_symbols_in_sentence=1250
+max_atoms_in_one_type=50
 
 [MODEL_ENCODER]
 dim_encoder = 768
diff --git a/SuperTagger/Linker/AttentionDecoderLayer.py b/SuperTagger/Linker/AttentionDecoderLayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3843b19ba246211e7cdbf07a8270f062eaed7b0f
--- /dev/null
+++ b/SuperTagger/Linker/AttentionDecoderLayer.py
@@ -0,0 +1,113 @@
+from torch import Tensor
+from torch.nn import (GELU, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
+                      Sequential)
+
+from Configuration import Configuration
+from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
+
+
+class FullyConnectedFeedForward(Module):
+    "Implements FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(FullyConnectedFeedForward, self).__init__()
+        self.ffn = Sequential(
+            Linear(d_model, d_ff, bias=False),
+            GELU(),
+            Dropout(dropout),
+            Linear(d_ff, d_model, bias=False)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)
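+
+# Shape note (illustrative sketch, assuming `import torch`): the block preserves the model
+# dimension, e.g. FullyConnectedFeedForward(d_model=8, d_ff=32)(torch.randn(2, 5, 8))
+# returns a tensor of shape (2, 5, 8).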
+
+
+class AttentionDecoderLayer(Module):
+    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+    This standard decoder layer is based on the paper "Attention Is All You Need".
+    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+    in a different way during application.
+
+    The constructor takes no arguments: the hyperparameters are read from
+    ``Configuration.modelDecoderConfig`` (and the symbol vocabulary size from
+    ``Configuration.datasetConfig``):
+
+        dim_encoder: the number of expected features coming from the encoder.
+        dim_decoder: the number of expected features in the decoder input and output.
+        max_symbols_in_sentence: the maximum number of symbols in a sentence.
+        nhead: the number of heads in the multi-head attention modules.
+        dim_feedforward: the dimension of the feedforward network model.
+        dropout: the dropout value.
+        layer_norm_eps: the eps value in layer normalization components.
+
+    Inputs and outputs are batch-first, i.e. (batch, seq, feature), and layer normalization
+    is applied after the self-attention, cross-attention and feedforward blocks (post-norm).
+    """
+
+    def __init__(self) -> None:
+        super(AttentionDecoderLayer, self).__init__()
+
+        # hyperparameters read from the configuration
+        dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
+        dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        self.max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
+        nhead = int(Configuration.modelDecoderConfig['nhead'])
+        self.nhead = nhead
+        dropout = float(Configuration.modelDecoderConfig['dropout'])
+        dim_feedforward = int(Configuration.modelDecoderConfig['dim_feedforward'])
+        layer_norm_eps = float(Configuration.modelDecoderConfig['layer_norm_eps'])
+        # symbols_vocab_size sits in the dataset section of config.ini
+        symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
+
+        # embeds the symbol tokens into vectors of dimension dim_decoder
+        self.symbols_embedder = SymbolEmbedding(dim_decoder, symbols_vocab_size)
+
+        # layers
+        self.dropout = Dropout(dropout)
+        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
+                                            kdim=dim_decoder, vdim=dim_decoder)
+        self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+        self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
+                                                 kdim=dim_encoder, vdim=dim_encoder,
+                                                 batch_first=True)
+        self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+        self.ffn = FullyConnectedFeedForward(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
+        self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+
+    def forward(self, symbols_tokens: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
+                decoder_mask: Tensor) -> Tensor:
+        r"""Pass the inputs through the decoder layer.
+
+        Args:
+            symbols_tokens: the symbol token sequence fed to the decoder layer (required).
+            sents_embedding: the sequence from the last layer of the encoder (required).
+            encoder_mask: attention mask over the encoder output (may be None).
+            decoder_mask: attention mask for the symbol self-attention (may be None).
+        """
+        x = self.symbols_embedder(symbols_tokens)
+        x = self.norm1(x + self._mask_mha_block(x, decoder_mask))
+        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
+        x = self.norm3(x + self._ff_block(x))
+
+        return x
+
+    # self-attention block
+    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
+        if decoder_mask is not None:
+            # Same mask applied to all h heads.
+            decoder_mask = decoder_mask.repeat(self.nhead, 1, 1)
+        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
+        return x
+
+    # multihead attention block
+    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
+        if encoder_mask is not None:
+            # Same mask applied to all h heads.
+            encoder_mask = encoder_mask.repeat(self.nhead, 1, 1)
+        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
+        return x
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        return self.ffn(x)
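+
+
+# Usage sketch (illustrative only, assuming the Configuration keys read in __init__ are set):
+#
+#   layer = AttentionDecoderLayer()
+#   # symbols_tokens: (batch, seq_symbols) token ids; sents_embedding: (batch, seq_words, dim_encoder)
+#   out = layer(symbols_tokens, sents_embedding, encoder_mask=None, decoder_mask=None)
+#   # out: (batch, seq_symbols, dim_decoder)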
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 181b33ff4955f2958c4074c7d788117a21910d15..4567362753856bc888a46ed2b79ae759d5216128 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -41,6 +41,7 @@ class Linker(Module):
         self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         self.dropout = Dropout(0.1)
 
@@ -100,11 +101,11 @@ class Linker(Module):
             # to do select with list of list
             pos_encoding = pad_sequence(
                 [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
                 padding_value=0)
             neg_encoding = pad_sequence(
                 [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence,
+                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_one_type//2,
                 padding_value=0)
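+            # pos_encoding and neg_encoding are expected to have shape
+            # (batch_size, max_atoms_in_one_type // 2, dim_embedding_atoms)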
 
             # pos_encoding = self.pos_transformation(pos_encoding)
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index aa2ad6ecdd510d0ac883d5c9e756ff11b2a9223d..b84f9072f83eed1dee87f93deb08a0b4a22d044f 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -4,7 +4,7 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 
 
-def category_to_atoms(category, category_to_atoms):
+def category_to_atoms(category, categories_to_atoms):
     res = [i for i in atom_map.keys() if category in i]
     if len(res) > 0:
         return [category]
@@ -12,19 +12,19 @@ def category_to_atoms(category, category_to_atoms):
         category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
         left_side, right_side = category_cut.group(1), category_cut.group(2)
 
-        category_to_atoms += category_to_atoms(left_side, [])
-        category_to_atoms += category_to_atoms(right_side, [])
+        categories_to_atoms += category_to_atoms(left_side, [])
+        categories_to_atoms += category_to_atoms(right_side, [])
 
-        return category_to_atoms
+        return categories_to_atoms
 
 
 def get_atoms_batch(category_batch):
     batch = []
     for sentence in category_batch:
-        category_to_atoms = []
+        categories_to_atoms = []
         for category in sentence:
-            category_to_atoms = category_to_atoms(category, category_to_atoms)
-        batch.append(category_to_atoms)
+            # accumulate the atoms of every category of the sentence
+            categories_to_atoms += category_to_atoms(category, [])
+        batch.append(categories_to_atoms)
     return batch
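+
+# Example (hypothetical category strings; assumes "np" and "n" are keys of atom_map):
+#   get_atoms_batch([["np", "dr(0,np,n)"]])  ->  [["np", "np", "n"]]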
 
 
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 07441e818de95760681cbfcb4317f005dae48438..6b7e3d82d22eade15059be86b9be59bb73f71c25 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -2,6 +2,7 @@ import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.functional import nll_loss, cross_entropy
+from SuperTagger.Linker.atom_map import atom_map
 
 from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
 
@@ -26,12 +27,20 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
     # then convert into atom_vocab_size lists of (batch_size, max atoms in one category) following a prefix traversal of the graph
     atoms_polarity = find_pos_neg_idexes(atoms_batch)
 
-    axiom_links_true = ""
+    num_correct_links = 0
+    for atom_type in atom_map.keys():
+        # filter atoms_batch down to this atom type, then filter atoms_polarity with those indices
 
-    # match axiom_links_pred and true data
+        # build the list of positive atoms and the list of negative atoms
 
-    correct_links = torch.ones(axiom_links_pred.size())
-    correct_links[axiom_links_pred != axiom_links_true] = 0
-    num_correct_links = correct_links.sum().item()
+        # pair positive and negative occurrences by index
+
+        axiom_links_true = ""
+
+        # match axiom_links_pred and true data
+
+        correct_links = torch.ones(axiom_links_pred.size())
+        correct_links[axiom_links_pred != axiom_links_true] = 0
+        num_correct_links += correct_links.sum().item()
 
     return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1])
diff --git a/test.py b/test.py
index 9e14d08a3992b2c2c6366af2ef3943a2737f3ef8..f0a8139e1ab329a7fc05ea5832c21c2760bc0494 100644
--- a/test.py
+++ b/test.py
@@ -30,25 +30,29 @@ atoms_polarity = [[False, True, True, False, False, True, True, False],
 
 atoms_encoding = torch.randn((2, 8, 24))
 
-matches = []
+link_weights = []
 for atom_type in ["np", "v"]:
     pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
                               x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+    print(pos_idx_per_atom_type)
     neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
-                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in
+                             range(len(atoms_polarity))]
 
     # to do select with list of list
-    pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-            for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0)
-    neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
-            for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0)
-
-    print(neg_encoding.shape)
+    pos_encoding = pad_sequence(
+        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+         for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3,
+        padding_value=0)
+    neg_encoding = pad_sequence(
+        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+         for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3,
+        padding_value=0)
+
+    # pos_encoding = self.pos_transformation(pos_encoding)
+    # neg_encoding = self.neg_transformation(neg_encoding)
 
     weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-    print(weights.shape)
-    print("sinkhorn")
-    print(sinkhorn(weights, iters=3).shape)
-    matches.append(sinkhorn(weights, iters=3))
+    link_weights.append(sinkhorn(weights, iters=3))
 
-print(matches)
+print(torch.stack(link_weights).shape)
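+# With this toy input (2 sentences, 2 atom types, and the custom pad_sequence padding each
+# selection to max_len=3), the printed shape should be torch.Size([2, 2, 3, 3]):
+# one (positive x negative) score matrix per atom type and per sentence.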