diff --git a/Configuration/config.ini b/Configuration/config.ini
index e26d80f63bf4ff172cee6b5946384c06ca8c9786..b36e88f259b5e560822b321854e2cbfed73c7323 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -24,8 +24,8 @@ sinkhorn_iters = 5
 
 [MODEL_TRAINING]
 batch_size = 32
-pretrain_linker_epochs = 1
-epoch = 1
+pretrain_linker_epochs = 10
+epoch = 20
 seed_val = 42
 learning_rate = 2e-3
 
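Review note: the [MODEL_TRAINING] bump swaps the smoke-test budgets (one epoch each) for real ones: 10 epochs of linker pretraining and 20 epochs of joint training. For reference, a minimal sketch of how these keys would be read with Python's stdlib configparser; the repo's Configuration package may wrap this differently, so the snippet is illustrative only:

    # Illustrative only: assumes the INI file is read via stdlib configparser.
    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("Configuration/config.ini")

    train_cfg = cfg["MODEL_TRAINING"]
    batch_size = train_cfg.getint("batch_size")                   # 32
    pretrain_epochs = train_cfg.getint("pretrain_linker_epochs")  # 10 after this change
    epochs = train_cfg.getint("epoch")                            # 20 after this change
    learning_rate = train_cfg.getfloat("learning_rate")           # 2e-3
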
diff --git a/NeuralProofNet/NeuralProofNet.py b/NeuralProofNet/NeuralProofNet.py
index bb12a552d5efcb8cf31f982ac7022bb74092efb5..92c783b4c9aab66e5c7c25cdb75dae1c0ff1f96c 100644
--- a/NeuralProofNet/NeuralProofNet.py
+++ b/NeuralProofNet/NeuralProofNet.py
@@ -67,7 +67,6 @@ class NeuralProofNet(Module):
         if linker_path_model is not None:
             linker.load_weights(linker_path_model)
         self.linker = linker
-        self.Supertagger = self.linker.Supertagger
 
         # Learning
         self.linker_loss = SinkhornLoss()
@@ -96,7 +95,7 @@ class NeuralProofNet(Module):
         """
         print("Start preprocess Data")
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
-        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+        sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
         _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
         atoms_polarity_batch = pad_sequence(
@@ -126,10 +125,10 @@ class NeuralProofNet(Module):
     def forward(self, batch_sentences_tokens, batch_sentences_mask):
 
         # get sentence embedding from BERT which is already trained
-        output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
+        output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         last_hidden_state = output['logit']
         pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
-        pred_categories = self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
+        pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
 
         # get information from tagger predictions
         atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
@@ -140,6 +139,12 @@ class NeuralProofNet(Module):
         batch_pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
         batch_neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
 
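+        # move the atom tensors onto the model device to match the linker's parameters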
+        batch_num_atoms_per_word = batch_num_atoms_per_word.to(self.device)
+        atoms_batch_tokenized = atoms_batch_tokenized.to(self.device)
+        batch_pos_idx = batch_pos_idx.to(self.device)
+        batch_neg_idx = batch_neg_idx.to(self.device)
+
         logits_links = self.linker(batch_num_atoms_per_word, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx,
                                    output['word_embeding'])
 
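Review note: two fixes land in this file. Removing the self.Supertagger alias leaves self.linker.Supertagger as the single reference to the tagger, so a later linker.load_weights cannot leave a stale copy behind; and the new .to(self.device) calls put every linker input on the same device as its parameters, avoiding a CPU/CUDA mismatch when training on GPU. A standalone sketch of that placement pattern (tensor names mirror the hunk, shapes are placeholders):

    # Illustrative device-placement sketch; shapes and values are placeholders.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_pos_idx = torch.zeros(32, 10, dtype=torch.long)  # index tensors are built on CPU
    batch_pos_idx = batch_pos_idx.to(device)               # no-op if already on `device`
    assert batch_pos_idx.device.type == device.type
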
diff --git a/train.py b/train.py
index 1bdeb557da3ede0362de02ab05b4f75514d7cd06..d50fd88805d5589b0fb8dd89dd65bcc575e5da3d 100644
--- a/train.py
+++ b/train.py
@@ -6,9 +6,8 @@ from utils import read_csv_pgbar
 from find_config import configurate
 from Configuration import Configuration
 
-
 torch.cuda.empty_cache()
-nb_sentences = 100000000
+nb_sentences = 1000000000
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
 
@@ -29,7 +28,7 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 print("#" * 20)
 print("#" * 20)
 neural_proof_net = NeuralProofNet(model_tagger)
-neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.05, epochs=epochs, batch_size=batch_size,
-                    checkpoint=True, tensorboard=True)
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size,
+                                      checkpoint=True, tensorboard=True)
+print("#" * 20)
 print("#" * 20)
-print("#" * 20)
\ No newline at end of file
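
Review note: in train.py, the nb_sentences cap handed to read_csv_pgbar grows from 1e8 to 1e9 so the full goldANDsilver dataset is read, the validation split rises from 5% to 10%, the train_neuralproofnet continuation line is realigned with its opening parenthesis, and the trailing print is reordered so the script now ends with a newline. A sketch of the row-cap idea, assuming read_csv_pgbar ultimately defers to pandas (its real signature lives in utils):

    # Illustrative: an oversized nrows cap makes pandas read every row in the file.
    import pandas as pd

    def read_capped(path: str, nb_sentences: int) -> pd.DataFrame:
        # pandas stops after nb_sentences rows; 1_000_000_000 means "all of them" here
        return pd.read_csv(path, nrows=nb_sentences)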