diff --git a/Linker/Linker.py b/Linker/Linker.py
index a3306db59234b41f443f5e38a1dd7a584504b355..6d14226412d5ca9e58b38ac19f11b7dd8d865597 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -11,7 +11,6 @@ from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
 from Linker.AtomEmbedding import AtomEmbedding
@@ -92,6 +91,9 @@ class Linker(Module):
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
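+        # Move the model's parameters to the selected device once, at construction time.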
+        self.to(self.device)
+
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
         r"""
         Args:
@@ -161,13 +163,13 @@ class Linker(Module):
                 [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)
 
             neg_encoding = pad_sequence(
                 [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                             atoms_polarity_batch, atom_type, s_idx)
                  for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                max_len=self.max_atoms_in_one_type // 2)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -195,8 +197,6 @@ class Linker(Module):
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             validation_rate)
-        self.to(self.device)
-
         if checkpoint or tensorboard:
             checkpoint_dir, writer = output_create_dir()
 
@@ -210,8 +210,8 @@
             print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
             print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
             if validation_rate > 0.0:
-                with torch.no_grad():
-                    loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
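+                # eval_epoch runs in eval mode with gradients disabled, so no wrapper is needed here.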
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
                 print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
 
             if checkpoint:
@@ -241,14 +241,14 @@
              accuracy on validation set
              loss on train set
         """
+        self.train()
 
         # Reset the total loss for this epoch.
         epoch_loss = 0
         accuracy_train = 0
         t0 = time.time()
 
-        self.train()
-
         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
             for batch in tepoch:
@@ -299,44 +299,47 @@
             axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
         """
         self.eval()
-
+        with torch.no_grad():
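+            # Inference only: disable gradient tracking for the whole forward pass.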
-        # get atoms
-        atoms_batch = get_atoms_batch(categories)
-        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+            # get atoms
+            atoms_batch = get_atoms_batch(categories)
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        # get polarities
-        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+            # get polarities
+            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
 
-        # atoms embedding
-        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+            # atoms embedding
+            atoms_embedding = self.atoms_embedding(atoms_tokenized)
 
-        # MHA ou LSTM avec sortie de BERT
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
-                                             self.make_decoder_mask(atoms_tokenized))
+            # MHA or LSTM over the BERT output
+            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                                 self.make_decoder_mask(atoms_tokenized))
 
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence(
-                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                            polarities, atom_type, s_idx)
-                 for s_idx in range(len(polarities))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+            link_weights = []
+            for atom_type in list(self.atom_map.keys())[:-1]:
+                pos_encoding = pad_sequence(
+                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                                polarities, atom_type, s_idx)
+                     for s_idx in range(len(polarities))], padding_value=0,
+                    max_len=self.max_atoms_in_one_type // 2)
 
-            neg_encoding = pad_sequence(
-                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
-                                            polarities, atom_type, s_idx)
-                 for s_idx in range(len(polarities))], padding_value=0,
-                max_len=self.max_atoms_in_one_type // 2).to(self.device)
+                neg_encoding = pad_sequence(
+                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
+                                                polarities, atom_type, s_idx)
+                     for s_idx in range(len(polarities))], padding_value=0,
+                    max_len=self.max_atoms_in_one_type // 2)
 
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
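+                # Score each positive/negative atom pairing, then Sinkhorn-normalise towards a doubly stochastic matching.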
+                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-        axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
-        return axiom_links
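+            # For each positive atom, link it to the negative atom with the highest score.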
+            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
+            return axiom_links
 
     def eval_batch(self, batch, cross_entropy_loss):
         batch_atoms = batch[0].to(self.device)
@@ -361,12 +364,15 @@
         Args:
             dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
         """
+        self.eval()
         accuracy_average = 0
         loss_average = 0
-        for step, batch in enumerate(dataloader):
-            loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += float(loss)
+        with torch.no_grad():
+            for step, batch in enumerate(dataloader):
+                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
+                accuracy_average += accuracy
+                loss_average += float(loss)
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)