diff --git a/Configuration/__pycache__/Configuration.cpython-38.pyc b/Configuration/__pycache__/Configuration.cpython-38.pyc
deleted file mode 100644
index 15ddaa2324daba83e16223382e73fbbbeae0bd57..0000000000000000000000000000000000000000
Binary files a/Configuration/__pycache__/Configuration.cpython-38.pyc and /dev/null differ
diff --git a/Linker/Linker.py b/Linker/Linker.py
index ad39aed0177bdde9e1bcec0f2b0c4ed0069fc0e5..12f9534edfc8fe5a35fbad8b62dbc2fe7ee44d67 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -65,6 +65,8 @@ class Linker(Module):
                                                          num_warmup_steps=0,
                                                          num_training_steps=100)
 
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
         r"""
         Args:
@@ -133,18 +135,18 @@ class Linker(Module):
                                                           atoms_batch_tokenized[s_idx][i] == self.atom_map[
                                                               atom_type] and
                                                           atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms)])
+                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
                                          for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2)
+                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
                                                       if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
                                                           atoms_batch_tokenized[s_idx][i] == self.atom_map[
                                                               atom_type] and
                                                           not atoms_polarity_batch[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms)])
+                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
                                          for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2)
+                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -175,6 +177,7 @@ class Linker(Module):
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             sentences_tokens, sentences_mask,
                                                                             validation_rate)
+        self.to(self.device)
         for epoch_i in range(0, epochs):
             epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
 
@@ -197,16 +200,16 @@ class Linker(Module):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_atoms = batch[0].to(self.device)
+            batch_polarity = batch[1].to(self.device)
+            batch_true_links = batch[2].to(self.device)
+            batch_sentences_tokens = batch[3].to(self.device)
+            batch_sentences_mask = batch[4].to(self.device)
 
             self.optimizer.zero_grad()
 
             # get the sentence embedding from BERT, which is already trained
-            logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+            logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
             # Run the linker on the category predictions
             logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
@@ -273,18 +276,18 @@ class Linker(Module):
                                                           atoms_tokenized[s_idx][i] == self.atom_map[
                                                               atom_type] and
                                                           polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms)])
+                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
                                          for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2)
+                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
                                                       if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
                                                           atoms_tokenized[s_idx][i] == self.atom_map[
                                                               atom_type] and
                                                           not polarities[s_idx][i])] + [
-                                                         torch.zeros(self.dim_embedding_atoms)])
+                                                         torch.zeros(self.dim_embedding_atoms, device=self.device)])
                                          for s_idx in range(len(polarities))], padding_value=0,
-                                        max_len=self.max_atoms_in_one_type // 2)
+                                        max_len=self.max_atoms_in_one_type // 2).to(self.device)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -297,13 +300,13 @@ class Linker(Module):
         return axiom_links
 
     def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_atoms = batch[0].to(self.device)
+        batch_polarity = batch[1].to(self.device)
+        batch_true_links = batch[2].to(self.device)
+        batch_sentences_tokens = batch[3].to(self.device)
+        batch_sentences_mask = batch[4].to(self.device)
 
-        logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+        logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
@@ -363,4 +366,4 @@ class Linker(Module):
             'neg_transformation': self.neg_transformation.state_dict(),
             'optimizer': self.optimizer,
         }, path)
-        #self.to(self.device)
+        self.to(self.device)
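Note on the Linker.py changes above: the device is resolved once in __init__, the module itself is moved with self.to(self.device) at the start of training, and every batch tensor is moved with .to(self.device) before the forward pass, so parameters and inputs always live on the same device. A minimal sketch of that pattern, with a made-up DummyLinker standing in for the real Linker:

    import torch
    from torch.nn import Linear, Module

    class DummyLinker(Module):
        def __init__(self):
            super().__init__()
            self.proj = Linear(16, 8)
            # resolve the device once and reuse it everywhere
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def train_linker(self, dataloader):
            self.to(self.device)                     # move parameters and buffers
            for batch in dataloader:
                features = batch[0].to(self.device)  # move each input tensor
                out = self.proj(features)
                # ... loss, backward, optimizer step ...

Tensors created inside forward (such as the torch.zeros padding above) also need device=self.device; otherwise they stay on the CPU and the stack/pad operations fail with a device mismatch.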
diff --git a/Linker/MHA.py b/Linker/MHA.py
index c66580617a7c665111b0da1711c3f66c5f8abe16..651487b6e841398eed1f55e4fbc2bedcd6c6317b 100644
--- a/Linker/MHA.py
+++ b/Linker/MHA.py
@@ -27,9 +27,8 @@ class AttentionDecoderLayer(Module):
             attention and feedforward operations, respectively. Otherwise it's done after.
             Default: ``False`` (after).
     """
-    __constants__ = ['batch_first', 'norm_first']
 
-    def __init__(self) -> None:
+    def __init__(self):
         super(AttentionDecoderLayer, self).__init__()
 
         # init params
@@ -42,18 +41,17 @@ class AttentionDecoderLayer(Module):
 
         # layers
         self.dropout = Dropout(dropout)
-        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
+        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
                                             kdim=dim_decoder, vdim=dim_decoder)
         self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
         self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
-                                                 kdim=dim_encoder, vdim=dim_encoder,
-                                                 batch_first=True)
+                                                 kdim=dim_encoder, vdim=dim_encoder)
         self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
         self.ffn = FFN(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
         self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)
 
-    def forward(self, atoms_embedding: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
-                decoder_mask: Tensor) -> Tensor:
+    def forward(self, atoms_embedding, sents_embedding, encoder_mask,
+                decoder_mask):
         r"""Pass the inputs through the decoder layer.
 
         Args:
@@ -62,24 +60,27 @@ class AttentionDecoderLayer(Module):
             encoder_mask
             decoder_mask
         """
+        atoms_embedding = atoms_embedding.permute(1, 0, 2)
+        sents_embedding = sents_embedding.permute(1, 0, 2)
+
         x = atoms_embedding
         x = self.norm1(x + self._mask_mha_block(atoms_embedding, decoder_mask))
         x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
         x = self.norm3(x + self._ff_block(x))
 
-        return x
+        return x.permute(1, 0, 2)
 
     # self-attention block
-    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
+    def _mask_mha_block(self, x, decoder_mask):
         x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
         return x
 
     # multihead attention block
-    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
+    def _mha_block(self, x, sents_embs, encoder_mask):
         x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
         return x
 
     # feed forward block
-    def _ff_block(self, x: Tensor) -> Tensor:
+    def _ff_block(self, x):
         x = self.ffn.forward(x)
         return x
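For context on the permute calls added in MHA.py: without batch_first, torch.nn.MultiheadAttention expects inputs shaped (seq_len, batch, embed_dim), so batch-first tensors have to be transposed going in and transposed back coming out. A small standalone check of that convention (the dimensions are arbitrary, not the model's real ones):

    import torch
    from torch.nn import MultiheadAttention

    mha = MultiheadAttention(embed_dim=32, num_heads=4)  # sequence-first by default
    x = torch.randn(8, 10, 32)                           # (batch, seq_len, embed_dim)

    x_seq_first = x.permute(1, 0, 2)                     # -> (seq_len, batch, embed_dim)
    out, _ = mha(x_seq_first, x_seq_first, x_seq_first)  # self-attention
    out = out.permute(1, 0, 2)                           # back to (batch, seq_len, embed_dim)
    assert out.shape == (8, 10, 32)

The attn_mask arguments are unaffected by this convention: their shape is (L, S) or (N*num_heads, L, S) either way, so only the embeddings need permuting.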
diff --git a/bash_GPU.sh b/bash_GPU.sh
new file mode 100644
index 0000000000000000000000000000000000000000..99692203e0a64519649244caa479801da0500a2a
--- /dev/null
+++ b/bash_GPU.sh
@@ -0,0 +1,13 @@
+#!/bin/sh
+#SBATCH --job-name=Deepgrail_Linker
+#SBATCH --partition=RTX6000Node
+#SBATCH --gres=gpu:1
+#SBATCH --mem=32000
+#SBATCH --gres-flags=enforce-binding
+#SBATCH --error="error_rtx1.err"
+#SBATCH --output="out_rtx1.out"
+
+module purge
+module load singularity/3.0.3
+
+srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1491f06d96fdbaac1338114c77226ffc488f38a7..c117e5384efe2b3cf6820f46d061d68059858966 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 numpy==1.22.2
 transformers==4.16.2
-torch==1.10.2
+torch==1.9.0
 huggingface-hub==0.4.0
 pandas==1.4.1
 sentencepiece
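The torch pin drops to 1.9.0, presumably to line up with the PyTorch build inside the CUDA 11 Singularity image used by bash_GPU.sh. If it helps to fail fast on a mismatched environment, a small, optional guard (a hypothetical addition, not part of this patch) could sit next to the imports in train.py:

    import torch

    # hypothetical guard: fail fast if the installed torch does not match requirements.txt
    assert torch.__version__.startswith("1.9"), (
        f"expected torch 1.9.x, got {torch.__version__}"
    )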
diff --git a/train.py b/train.py
index bc2f785c1798eb3e177cfa76453f52ddbc44f789..b2e73259a8ad242fb2103e22dbedf531d9e9c1d2 100644
--- a/train.py
+++ b/train.py
@@ -15,9 +15,12 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()
 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")
+
+
 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
 print("Linker")
 linker = Linker(supertagger)
+linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 print("Linker Training")
 linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
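With train.py now moving the linker onto the GPU when one is available, a quick sanity check right after construction can surface device mismatches before training starts. A possible addition (assuming torch is already imported in train.py, which the hunk above does not show):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    linker = linker.to(device)
    print("Linker parameters on:", next(linker.parameters()).device)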