diff --git a/Configuration/config.ini b/Configuration/config.ini
index dafdae64b7de87bb404e81a233852de939a013f8..8e6c08cc532c9390444b52bc9da07564ddd04d6c 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -18,7 +18,7 @@ dropout=0.1
 teacher_forcing=0.05
 
 [MODEL_LINKER]
-nhead=8
+nhead=1
 dim_feedforward=246
 dim_embedding_atoms=8
 dim_polarity_transfo=128
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index ef03e0e44d146e8a6ae70bebd9e41123c6b6ba1d..93028fdeaf6cc7f1cc978e796d56d73fd0ff6b5a 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -1,36 +1,21 @@
 from itertools import chain
 
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
+from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
 from torch.nn import Module
 import torch.nn.functional as F
 
 from Configuration import Configuration
 from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.MHA import AttentionDecoderLayer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
 from SuperTagger.eval import mesure_accuracy
 from SuperTagger.utils import pad_sequence
 
 
-class FFN(Module):
-    "Implements FFN equation."
-
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(FFN, self).__init__()
-        self.ffn = Sequential(
-            Linear(d_model, d_ff, bias=False),
-            GELU(),
-            Dropout(dropout),
-            Linear(d_ff, d_model, bias=False)
-        )
-
-    def forward(self, x):
-        return self.ffn(x)
-
-
 class Linker(Module):
     def __init__(self):
         super(Linker, self).__init__()
@@ -39,6 +24,8 @@ class Linker(Module):
         self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
         self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
         self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
@@ -50,7 +37,7 @@ class Linker(Module):
         self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
 
         # TODO: define an encoding
-        # self.linker_encoder =
+        self.linker_encoder = AttentionDecoderLayer()
 
         self.pos_transformation = Sequential(
             FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
@@ -61,23 +48,32 @@ class Linker(Module):
             LayerNorm(self.dim_embedding_atoms, eps=1e-12)
         )
 
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
+    def make_decoder_mask(self, atoms_token):
+        # Boolean self-attention mask: True marks padded atom positions that must not be attended to.
+        padding_mask = atoms_token.eq(self.padding_id)
+        # (batch, seq) -> (batch * nhead, seq, seq), the 3D attn_mask layout expected by
+        # MultiheadAttention (each batch element repeated nhead times, batch-major).
+        return padding_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat_interleave(self.nhead, dim=0)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
         r'''
         Parameters :
         atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
         atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : output of BERT for context
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask : attention mask over the sentence tokens, forwarded to the cross-attention over sents_embedding (optional)
         Returns :
         link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
         '''
 
         # atoms embedding
         atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
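+        # Shape sketch (illustrative, not enforced): atoms_embedding is
+        # (batch_size, max_atoms_in_sentence, dim_embedding_atoms); the linker_encoder below
+        # assumes this last dimension matches the dim_decoder expected by AttentionDecoderLayer.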
 
         # MHA or LSTM with BERT output
-        # decoder_mask = self.make_decoder_mask(atoms_batch)
-        # atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask)
-        atoms_encoding = atoms_embedding
+        decoder_mask = self.make_decoder_mask(atoms_batch_tokenized)
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask)
 
         link_weights = []
         for atom_type in list(self.atom_map.keys())[:-1]:
diff --git a/SuperTagger/Linker/MHA.py b/SuperTagger/Linker/MHA.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85d5e03b29ad33077224bb19f90c44d7b3d630f
--- /dev/null
+++ b/SuperTagger/Linker/MHA.py
@@ -0,0 +1,92 @@
+import torch
+from Configuration import Configuration
+from torch import Tensor
+from torch.nn import Dropout, LayerNorm, Module, MultiheadAttention
+
+from SuperTagger.Linker.utils import FFN
+
+
+class AttentionDecoderLayer(Module):
+    r"""Decoder layer made up of masked self-attention over the atom embeddings, multi-head
+    cross-attention over the encoder output, and a feed-forward network. The layout follows
+    the standard Transformer decoder layer of "Attention Is All You Need" (Ashish Vaswani,
+    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser,
+    and Illia Polosukhin. 2017. In Advances in Neural Information Processing Systems,
+    pages 6000-6010).
+
+    Hyperparameters are read from ``Configuration`` instead of being passed as arguments:
+        dim_decoder: the number of expected features in the decoder (atom) input.
+        dim_encoder: the number of expected features in the encoder (sentence) input.
+        nhead: the number of heads in the multi-head attention modules.
+        dim_feedforward: the dimension of the feed-forward network model.
+        dropout: the dropout value.
+        layer_norm_eps: the eps value in the layer-normalization components.
+
+    Inputs and outputs are batch-first: (batch, seq, feature). Layer norm is applied after
+    each residual block (post-norm), and the feed-forward block uses GELU.
+    """
+
+    def __init__(self) -> None:
+        super(AttentionDecoderLayer, self).__init__()
+
+        # init params
+        dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        nhead = int(Configuration.modelLinkerConfig['nhead'])
+        dropout = float(Configuration.modelLinkerConfig['dropout'])
+        dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward'])
+        layer_norm_eps = float(Configuration.modelLinkerConfig['layer_norm_eps'])
+
+        # layers
+        self.dropout = Dropout(dropout)
+        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
+                                            kdim=dim_decoder, vdim=dim_decoder)
+        self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+        self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
+                                                 kdim=dim_encoder, vdim=dim_encoder,
+                                                 batch_first=True)
+        self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+        self.ffn = FFN(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
+        self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)
+
+    def forward(self, atoms_embedding: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
+                decoder_mask: Tensor) -> Tensor:
+        r"""Pass the inputs through the decoder layer.
+
+        Args:
+            atoms_embedding: the sequence fed to the decoder layer (required).
+            sents_embedding: the sequence from the last layer of the encoder (required).
+            encoder_mask: attention mask used by the cross-attention block over the encoder output.
+            decoder_mask: attention mask used by the masked self-attention block.
+        """
+        x = atoms_embedding
+        x = self.norm1(x + self._mask_mha_block(atoms_embedding, decoder_mask))
+        x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
+        x = self.norm3(x + self._ff_block(x))
+
+        return x
+
+    # self-attention block
+    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
+        x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
+        return x
+
+    # multihead attention block
+    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
+        x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
+        return x
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        return self.ffn(x)
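+
+
+if __name__ == '__main__':
+    # Minimal smoke test, a sketch only: the batch/sequence sizes below are arbitrary assumptions,
+    # and the boolean masks (True = position not attended to) use the (batch * nhead, target, source)
+    # layout that MultiheadAttention expects for a 3D attn_mask.
+    dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+    dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+    nhead = int(Configuration.modelLinkerConfig['nhead'])
+    batch_size, num_atoms, len_sentence = 2, 6, 10
+
+    layer = AttentionDecoderLayer()
+    atoms_embedding = torch.randn(batch_size, num_atoms, dim_decoder)
+    sents_embedding = torch.randn(batch_size, len_sentence, dim_encoder)
+    encoder_mask = torch.zeros(batch_size * nhead, num_atoms, len_sentence, dtype=torch.bool)
+    decoder_mask = torch.zeros(batch_size * nhead, num_atoms, num_atoms, dtype=torch.bool)
+    out = layer(atoms_embedding, sents_embedding, encoder_mask, decoder_mask)
+    print(out.shape)  # expected: (batch_size, num_atoms, dim_decoder)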
diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
index e732b7dd6a2e215b497d136ba944e5e832aeacd6..facf9eafa213710664a3d7f18c9150693e6c0a2c 100644
Binary files a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc and b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..679c41a8ef96c82cce084ed1d859b6f0933c12f5
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
index 0e04fe2ee2a37736df214e8faee8e94eaa55ae64..c4eef1e07886db024496a876b846bd69646a7538 100644
Binary files a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc and b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index d95192689514640f048838aad2030fcf959272bc..abd6814fc0bc8ae839b8efe40d3e50a8921cbfb1 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -2,11 +2,29 @@ import re
 import regex
 import numpy as np
 import torch
-
+from torch.nn import Sequential, Linear, Dropout, GELU
+from torch.nn import Module
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.utils import pad_sequence
 
+
+class FFN(Module):
+    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear."""
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(FFN, self).__init__()
+        self.ffn = Sequential(
+            Linear(d_model, d_ff, bias=False),
+            GELU(),
+            Dropout(dropout),
+            Linear(d_ff, d_model, bias=False)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)
+
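+# Minimal FFN usage sketch (illustrative; the sizes below are arbitrary):
+#   ffn = FFN(d_model=8, d_ff=246, dropout=0.1)
+#   ffn(torch.randn(2, 10, 8)).shape  # -> torch.Size([2, 10, 8]), the model dimension is preserved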
+
 regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 
@@ -29,9 +47,10 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     for atom_type in list(atom_map.keys())[:-1]:
         # filter atoms_batch to this atom type, then use the indices to filter on atom polarity
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in range(len(atoms_batch))]
+                            and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
+                           range(len(atoms_batch))]
         l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in
+                             and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in
                             range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
index fce253cccfad91c0b80dd6cd68cc5e05e3437104..f5cec815f7dbc90d4ab075b8e48682a5c6e119fd 100644
Binary files a/SuperTagger/__pycache__/eval.cpython-38.pyc and b/SuperTagger/__pycache__/eval.cpython-38.pyc differ