diff --git a/Configuration/config.ini b/Configuration/config.ini
index db5609f127407dc272bf282fa63e8d8163923c0a..425af5e5890518ebe30742150896c3f9e1c758d2 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -11,7 +11,12 @@ max_symbols_in_sentence=1250
 max_len_sentence=112
 [MODEL_TRAINING]
 device=cpu
-batch_size=32
+batch_size=16
 epoch=20
 seed_val=42
-learning_rate=0.0005
\ No newline at end of file
+learning_rate=0.005
+use_checkpoint_SAVE=1
+output_path=Output
+use_checkpoint_LOAD=0
+input_path=Input
+model_to_load=model_check.pt
\ No newline at end of file
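
Note: train.py reads the new [MODEL_TRAINING] keys through the project's Configuration wrapper, which returns raw strings (configparser-style), so numeric and 0/1 values need explicit conversion. A minimal sketch of the intended reads, assuming Configuration.modelTrainingConfig behaves like a configparser section:

    from Configuration import Configuration

    # Values arrive as strings; 0/1 flags must go through int() first,
    # because bool("0") evaluates to True.
    batch_size = int(Configuration.modelTrainingConfig['batch_size'])
    use_checkpoint_save = bool(int(Configuration.modelTrainingConfig['use_checkpoint_SAVE']))
    use_checkpoint_load = bool(int(Configuration.modelTrainingConfig['use_checkpoint_LOAD']))
    output_path = Configuration.modelTrainingConfig['output_path']
    model_to_load = Configuration.modelTrainingConfig['model_to_load']
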
diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py
index 6e7dd74874ea8066ad6e60f823f1d812f8fbf8bb..89f0798647deed1bc950aae0b34a8e49ad14fa39 100644
--- a/SuperTagger/Decoder/RNNDecoderLayer.py
+++ b/SuperTagger/Decoder/RNNDecoderLayer.py
@@ -57,16 +57,22 @@ class RNNDecoderLayer(Module):
 
         # y_hat[batch_size, max_len_sentence, vocab_size] init with probability pad =1
         y_hat = torch.zeros(batch_size, self.max_symbols_in_sentence, self.symbols_vocab_size,
-                            dtype=torch.float)
+                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
         y_hat[:, :, self.symbols_padding_id] = 1
-        decoded_i = torch.ones(batch_size, 1, dtype=torch.long) * self.symbols_start_id
+        use_teacher_forcing = random.random() < 0.05
+        if use_teacher_forcing:
+            print("\n TEACHER FORCING \n")
+            decoded_i = symbols_tokenized_batch[:, 0].unsqueeze(1)
+        else:
+            decoded_i = torch.ones(batch_size, 1, dtype=torch.long, device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
 
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool)
+        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
 
         # hidden_state goes through multiple linear layers
         hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
         c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float)
+                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
+
 
         # for each symbol
         for i in range(self.max_symbols_in_sentence):
@@ -80,7 +86,6 @@ class RNNDecoderLayer(Module):
             # Projection of the output of the rnn omitting the last probability (which is pad) so we dont predict PAD
             proj = self.proj(output)[:, :, :-2]
 
-            use_teacher_forcing = True if random.random() < self.teacher_forcing else False
             if use_teacher_forcing:
                 decoded_i = symbols_tokenized_batch[:, i].unsqueeze(1)
             else:
@@ -92,6 +97,8 @@ class RNNDecoderLayer(Module):
             y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :]
             sos_mask = sos_mask_i | sos_mask
 
+            use_teacher_forcing = random.random() < 0.25
+
             # Stop if every sentence says padding or if we are full
             if not torch.any(~sos_mask):
                 break
@@ -109,16 +116,16 @@ class RNNDecoderLayer(Module):
 
         # contains the predictions
         y_hat = torch.zeros(batch_size, self.max_symbols_in_sentence, self.symbols_vocab_size,
-                            dtype=torch.float)
+                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
         y_hat[:, :, self.symbols_padding_id] = 1
         # input of the embedder, a created vector that replace the true value
-        decoded_i = torch.ones(batch_size, 1, dtype=torch.long) * self.symbols_start_id
+        decoded_i = torch.ones(batch_size, 1, dtype=torch.long, device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
 
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool)
+        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
 
         hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
         c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float)
+                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
 
         for i in range(self.max_symbols_in_sentence):
             symbols_embedding = self.symbols_embedder(decoded_i)
@@ -139,4 +146,4 @@ class RNNDecoderLayer(Module):
             if not torch.any(~sos_mask):
                 break
 
-        return torch.argmax(y_hat, dim=2)
+        return y_hat
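
Note: the greedy decoding path now returns the raw score tensor y_hat of shape [batch_size, max_symbols_in_sentence, symbols_vocab_size] instead of torch.argmax(y_hat, dim=2), so callers take the argmax themselves (as EncoderDecoder.eval_batch now does). A minimal caller-side sketch, assuming model is the EncoderDecoder built in train.py and sents_tokenized / sents_mask are the tensors produced by EncoderInput:

    import torch

    y_hat = model.decode_greedy_rnn(sents_tokenized, sents_mask)  # [batch, max_symbols, vocab] scores
    predictions = torch.argmax(y_hat, dim=2)                      # [batch, max_symbols] symbol ids
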
diff --git a/SuperTagger/Encoder/EncoderLayer.py b/SuperTagger/Encoder/EncoderLayer.py
index a3243d6a7c8dc4eb71da5d11bb1b94c60c9b5cd6..c954584f332ff6207371cda0bc93aae8fe6edfea 100644
--- a/SuperTagger/Encoder/EncoderLayer.py
+++ b/SuperTagger/Encoder/EncoderLayer.py
@@ -3,16 +3,17 @@ import sys
 import torch
 from torch import nn
 
+from Configuration import Configuration
+
 
 class EncoderLayer(nn.Module):
     """Encoder class, imput of supertagger"""
 
-    def __init__(self, model, device=torch.device('cpu')):
+    def __init__(self, model):
         super(EncoderLayer, self).__init__()
         self.name = "Encoder"
 
         self.bert = model
-        self.device = device
 
         self.hidden_size = self.bert.config.hidden_size
 
@@ -24,8 +25,8 @@ class EncoderLayer(nn.Module):
                 last_hidden_state: [batch_size, max_len_sentence, dim_encoder]  Sequence of hidden-states at the output of the last layer of the model.
                 pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
         """
-        b_input_ids = batch[0].to(self.device)
-        b_input_mask = batch[1].to(self.device)
+        b_input_ids = batch[0]
+        b_input_mask = batch[1]
 
         outputs = self.bert(
             input_ids=b_input_ids, attention_mask=b_input_mask)
@@ -33,7 +34,7 @@ class EncoderLayer(nn.Module):
         return outputs[0], outputs[1]
 
     @staticmethod
-    def load(model_path: str, device):
+    def load(model_path: str):
         r""" Load the model from a file.
         Args :
             model_path (str): path to model
@@ -43,7 +44,7 @@ class EncoderLayer(nn.Module):
         params = torch.load(
             model_path, map_location=lambda storage, loc: storage)
         args = params['args']
-        model = EncoderLayer(device=device, **args)
+        model = EncoderLayer(**args)
         model.load_state_dict(params['state_dict'])
 
         return model
diff --git a/SuperTagger/EncoderDecoder.py b/SuperTagger/EncoderDecoder.py
index 346363549894909925cea70b774d75a9e7367f8e..4efa870990d8f2b5b4fda48f1d6623d7e1a830f6 100644
--- a/SuperTagger/EncoderDecoder.py
+++ b/SuperTagger/EncoderDecoder.py
@@ -22,7 +22,6 @@ class EncoderDecoder(Module):
         self.max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
         self.max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        self.device = torch.device(Configuration.modelTrainingConfig['device'])
 
         self.symbols_map = symbols_map
         self.sents_padding_id = BASE_TOKENIZER.pad_token_id
@@ -42,7 +41,7 @@ class EncoderDecoder(Module):
             symbols_tokenized_batch: [batch_size, max_symbols_in_sentence] the true symbols for each sentence.
         """
         last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
+        #last_hidden_state = self.dropout(last_hidden_state)
         return self.decoder(symbols_tokenized_batch, last_hidden_state, pooler_output)
 
     def decode_greedy_rnn(self, sents_tokenized_batch, sents_mask_batch):
@@ -59,29 +58,32 @@ class EncoderDecoder(Module):
 
         return predictions
 
-    def eval_batch(self, batch):
+    def eval_batch(self, batch, cross_entropy_loss):
         r"""Calls the evaluating methods after predicting the symbols from the sentence contained in batch
 
         Args:
             batch: contains the tokenized sentences, their masks and the true symbols.
         """
-        b_sents_tokenized = batch[0]
-        b_sents_mask = batch[1]
-        b_symbols_tokenized = batch[2]
+        b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
 
         type_predictions = self.decode_greedy_rnn(b_sents_tokenized, b_sents_mask)
 
-        predict_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in type_predictions[0]]
+        pred = torch.argmax(type_predictions, dim=2)
+
+        predict_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in pred[0]]
         true_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in b_symbols_tokenized[0]]
-        print("\nsub true (", len([i for i in true_trad if i != '[PAD]']), ") : ",
+        l = len([i for i in true_trad if i != '[PAD]'])
+        print("\nsub true (", l, ") : ",
               [token for token in true_trad if token != '[PAD]'])
         print("\nsub predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ",
-              [token for token in predict_trad if token != '[PAD]'])
+              [token for token in predict_trad[:l] if token != '[PAD]'])
 
-        return measure_supertagging_accuracy(type_predictions, b_symbols_tokenized,
-                                             ignore_idx=self.symbols_map["[PAD]"])
+        return measure_supertagging_accuracy(pred, b_symbols_tokenized,
+                                             ignore_idx=self.symbols_map["[PAD]"]), float(cross_entropy_loss(type_predictions, b_symbols_tokenized))
 
-    def eval_epoch(self, dataloader):
+    def eval_epoch(self, dataloader, cross_entropy_loss):
         r"""Average the evaluation of all the batch.
 
         Args:
@@ -90,11 +92,11 @@
         s_total, s_correct, w_total, w_correct = (0.1,) * 4
 
         for step, batch in enumerate(dataloader):
-            batch_output = self.eval_batch(batch)
+            batch_output, loss = self.eval_batch(batch, cross_entropy_loss)
             ((bs_correct, bs_total), (bw_correct, bw_total)) = batch_output
             s_total += bs_total
             s_correct += bs_correct
             w_total += bw_total
             w_correct += bw_correct
 
-        return s_correct / s_total, w_correct / w_total
+        return s_correct / s_total, w_correct / w_total, loss
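
Note: eval_batch and eval_epoch now take the loss criterion as an argument and return a loss value alongside the two accuracies. A minimal usage sketch, matching the validation block in train.py (cross_entropy_loss is the NormCrossEntropy instance built there):

    with torch.no_grad():
        accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss)
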
diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py
index c6b3301731f063400b1150f180d43b3d56de02ed..5c0ec76ed9082510c5b0d5264e10d42e48d9070e 100644
--- a/SuperTagger/utils.py
+++ b/SuperTagger/utils.py
@@ -1,11 +1,11 @@
-import copy
 import datetime
 
 import pandas as pd
-from torch.nn import ModuleList
+import torch
 from tqdm import tqdm
 
 
+
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
     print("\n" + "#" * 20)
     print("Loading csv...")
@@ -36,4 +36,34 @@ def format_time(elapsed):
     elapsed_rounded = int(round(elapsed))
 
     # Format as hh:mm:ss
-    return str(datetime.timedelta(seconds=elapsed_rounded))
\ No newline at end of file
+    return str(datetime.timedelta(seconds=elapsed_rounded))
+
+def checkpoint_save(model, opt, epoch, dir, loss):
+
+    torch.save({
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': opt.state_dict(),
+        'loss': loss,
+    }, dir + '/model_check.pt')
+
+
+def checkpoint_load(model, opt, path):
+    epoch = 0
+    loss = 0
+    print("#" *15)
+
+    try:
+        checkpoint = torch.load(path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        opt.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print("\n The loading checkpoint was successful  ! \n")
+        print("#" * 10)
+    except Exception as e:
+        print("\nCan't load checkpoint model because : "+ str(e) +"\n\nUse default model \n")
+    print("#" * 15)
+    return model, opt, epoch, loss
+
+
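
Note: the two new helpers follow the usual torch.save / torch.load checkpoint pattern and are wired into train.py roughly as below, assuming the surrounding names (model, optimizer, checkpoint_dir, model_to_load_path, epoch_i, loss) are those defined there:

    # Resume before training if requested.
    if use_checkpoint_LOAD:
        model, optimizer, last_epoch, loss = checkpoint_load(model, optimizer, model_to_load_path)
        epochs = epochs - last_epoch

    # Save a checkpoint at the end of each epoch.
    if use_checkpoint_SAVE:
        checkpoint_save(model, optimizer, epoch_i, checkpoint_dir, loss)
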
diff --git a/bash_GPU.sh b/bash_GPU.sh
new file mode 100644
index 0000000000000000000000000000000000000000..834ebf556347599d1c9d108c5aa4740dfcb844dd
--- /dev/null
+++ b/bash_GPU.sh
@@ -0,0 +1,13 @@
+#!/bin/sh
+#SBATCH --job-name=N-tensorboard
+#SBATCH --partition=RTX6000Node
+#SBATCH --gres=gpu:1
+#SBATCH --mem=32000
+#SBATCH --gres-flags=enforce-binding
+#SBATCH --error="/users/celdev/jrabault/PNRIA - DeepGrail/OUT/error_rtx1.err"
+#SBATCH --output="/users/celdev/jrabault/PNRIA - DeepGrail/OUT/out_rtx1.out"
+
+module purge
+module load singularity/3.0.3
+
+srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "/users/celdev/jrabault/PNRIA - DeepGrail/train.py" 
\ No newline at end of file
diff --git a/train.py b/train.py
index ee2a402569441d2869fdb5f290cfcfcde99a9ca6..6353e912a4068745a9f3a1ad5ad1e01dd88c4a20 100644
--- a/train.py
+++ b/train.py
@@ -1,13 +1,16 @@
 import math
+import os
 import time
 
 import numpy as np
 import torch
 from torch.optim import AdamW
+import transformers
 import torch.nn.functional as F
 from torch.utils.data import Dataset, TensorDataset, random_split
 from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
 from transformers import (CamembertModel)
+from torch.utils.tensorboard import SummaryWriter
 
 from Configuration import Configuration
 from SuperTagger.Encoder.EncoderInput import EncoderInput
@@ -15,7 +18,12 @@ from SuperTagger.EncoderDecoder import EncoderDecoder
 from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
 from SuperTagger.Symbol.symbol_map import symbol_map
 from SuperTagger.eval import NormCrossEntropy
-from SuperTagger.utils import format_time, read_csv_pgbar
+from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load
+
+from datetime import datetime
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+torch.cuda.empty_cache()
 
 # region ParamsModel
 
@@ -30,8 +38,7 @@ num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
 
 file_path = 'Datasets/m2_dataset.csv'
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 11
-device = torch.device(Configuration.modelTrainingConfig['device'])
+nb_sentences = batch_size * 10
 # Number of training epochs. The BERT authors recommend between 2 and 4.
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 seed_val = int(Configuration.modelTrainingConfig['seed_val'])
@@ -40,6 +47,30 @@ loss_scaled_by_freq = True
 
 # endregion ParamsTraining
 
+# region OutputTraining
+
+output_path = str(Configuration.modelTrainingConfig['output_path'])
+
+training_dir = os.path.join(output_path, 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+logs_dir = os.path.join(training_dir, 'logs')
+
+checkpoint_dir = training_dir
+writer = SummaryWriter(log_dir=logs_dir)
+
+use_checkpoint_SAVE = bool(int(Configuration.modelTrainingConfig['use_checkpoint_SAVE']))
+
+# endregion OutputTraining
+
+# region InputTraining
+
+input_path = str(Configuration.modelTrainingConfig['input_path'])
+model_to_load = str(Configuration.modelTrainingConfig['model_to_load'])
+model_to_load_path = os.path.join(input_path, model_to_load)
+use_checkpoint_LOAD = bool(int(Configuration.modelTrainingConfig['use_checkpoint_LOAD']))
+print(use_checkpoint_LOAD)
+
+# endregion InputTraining
+
 # region Print config
 
 print("##" * 15 + "\nConfiguration : \n")
@@ -59,11 +90,24 @@ print("\tbatch_size :", batch_size)
 print("\tepochs :", epochs)
 print("\tseed_val :", seed_val)
 
+print("\n Output\n")
+print("\tuse checkpoint save :", use_checkpoint_SAVE)
+print("\tcheckpoint_dir :", checkpoint_dir)
+print("\tlogs_dir :", logs_dir)
+
+print("\n Input\n")
+print("\tModel to load :", model_to_load_path)
+print("\tLoad checkpoint :", use_checkpoint_LOAD)
+
 print("\nLoss and optimizer : ")
 
 print("\tlearning_rate :", learning_rate)
 print("\twith loss scaled by freq :", loss_scaled_by_freq)
 
+print("\n Device\n")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("\t", device)
+
 print()
 print("##" * 15)
 
@@ -79,6 +123,7 @@ BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
 symbols_tokenizer = SymbolTokenizer(symbol_map, max_symbols_in_sentence, max_len_sentence)
 sents_tokenizer = EncoderInput(BASE_TOKENIZER)
 model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map)
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 # endregion Model
 
@@ -91,7 +136,7 @@ sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentence
 dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized)
 
 # Calculate the number of samples to include in each set.
-train_size = int(0.9 * len(dataset))
+train_size = int(0.95 * len(dataset))
 val_size = len(dataset) - train_size
 
 # Divide the dataset by randomly selecting samples.
@@ -108,19 +153,23 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc
 # region Fit tunning
 
 # Optimizer
-optimizer_encoder = AdamW(model.encoder.parameters(),
-                          lr=5e-5,
-                          eps=1e-8)
-optimizer_decoder = AdamW(model.decoder.parameters(),
-                          lr=learning_rate,
-                          eps=1e-8)
+# optimizer_encoder = AdamW(model.encoder.parameters(),
+#                           lr=5e-5,
+#                           eps=1e-8)
+# optimizer_decoder = AdamW(model.decoder.parameters(),
+#                           lr=learning_rate,
+#                           eps=1e-8)
+
+optimizer = AdamW(model.parameters(),
+                  lr=learning_rate,
+                  eps=1e-8)
 
 # Total number of training steps is [number of batches] x [number of epochs].
 # (Note that this is not the same as the number of training samples).
 total_steps = len(training_dataloader) * epochs
 
 # Create the learning rate scheduler.
-scheduler = get_cosine_schedule_with_warmup(optimizer_decoder,
+scheduler = get_cosine_schedule_with_warmup(optimizer,
                                             num_warmup_steps=0,
                                             num_training_steps=total_steps)
 
@@ -128,7 +177,8 @@ scheduler = get_cosine_schedule_with_warmup(optimizer_decoder,
 if loss_scaled_by_freq:
     weights = torch.as_tensor(
         [6.9952, 1.0763, 1.0317, 43.274, 16.5276, 11.8821, 28.2416, 2.7548, 1.0728, 3.1847, 8.4521, 6.77, 11.1887,
-         6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0])
+         6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0],
+        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
     cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id,
                                           weights=weights)
 else:
@@ -148,6 +198,10 @@ total_t0 = time.time()
 
 validate = True
 
+if use_checkpoint_LOAD:
+    model, optimizer, last_epoch, loss = checkpoint_load(model, optimizer, model_to_load_path)
+    epochs = epochs - last_epoch
+
 
 def run_epochs(epochs):
     # For each epoch...
@@ -173,49 +227,62 @@ def run_epochs(epochs):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
 
+            # if epoch_i == 0 and step == 0:
+            #     writer.add_graph(model, input_to_model=batch[0], verbose=False)
+
             # Progress update every 40 batches.
-            if step % math.ceil(batch_size / 10) == 0 and not step == 0:
+            if step % 40 == 0 and not step == 0:
                 # Calculate elapsed time in minutes.
                 elapsed = format_time(time.time() - t0)
                 # Report progress.
                 print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(training_dataloader), elapsed))
 
-            # Unpack this training batch from our dataloader.
-            b_sents_tokenized = batch[0]
-            b_sents_mask = batch[1]
-            b_symbols_tokenized = batch[2]
+            # Unpack this training batch from our dataloader.
+            b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+            b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+            b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+
+            # optimizer_encoder.zero_grad()
+            # optimizer_decoder.zero_grad()
 
-            optimizer_encoder.zero_grad()
-            optimizer_decoder.zero_grad()
+            optimizer.zero_grad()
 
             logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized)
 
             predict_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in
                             torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]]
             true_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in b_symbols_tokenized[0]]
-            print("\ntrain true (", len([i for i in true_trad if i != '[PAD]']), ") : ",
-                  [token for token in true_trad if token != '[PAD]'])
-            print("\ntrain predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ",
-                  [token for token in predict_trad if token != '[PAD]'])
+
+            l = len([i for i in true_trad if i != '[PAD]'])
+
+            if step % 40 == 0 and not step == 0:
+                writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str([token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str(len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str([token for token in predict_trad[:l] if token != '[PAD]']))
 
             loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized)
             # Perform a backward pass to calculate the gradients.
-            total_train_loss += loss.item()
+            total_train_loss += float(loss)
             loss.backward()
 
             # Clip the norm of the gradients to 1.0.
             # This is to help prevent the "exploding gradients" problem.
-            torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
-            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1.0)
+            # torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
+            # torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1.0)
 
             # Update parameters and take a step using the computed gradient.
             # The optimizer dictates the "update rule"--how the parameters are
             # modified based on their gradients, the learning rate, etc.
-            optimizer_encoder.step()
-            optimizer_decoder.step()
+            # optimizer_encoder.step()
+            # optimizer_decoder.step()
+
+            optimizer.step()
 
             scheduler.step()
 
+        # checkpoint
+
+        if use_checkpoint_SAVE:
+            checkpoint_save(model, optimizer, epoch_i, checkpoint_dir, loss)
+
         avg_train_loss = total_train_loss / len(training_dataloader)
 
         # Measure how long this epoch took.
@@ -225,15 +292,26 @@ def run_epochs(epochs):
             model.eval()
             with torch.no_grad():
                 print("Start eval")
-                accuracy_sents, accuracy_symbol = model.eval_epoch(validation_dataloader)
+                accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss)
                 print("")
                 print("  Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents))
                 print("  Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol))
+                writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1)
+                writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1)
 
         print("")
         print("  Average training loss: {0:.2f}".format(avg_train_loss))
         print("  Training epcoh took: {:}".format(training_time))
 
+        # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1)
+
+        writer.add_scalars('Training vs. Validation Loss',
+                           {'Training': avg_train_loss, 'Validation': v_loss},
+                           epoch_i + 1)
+        writer.flush()
+
 
 run_epochs(epochs)
 # endregion Train
+
+
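
Note: the expression "cuda" if torch.cuda.is_available() else "cpu" is repeated at every tensor creation and .to() call, even though train.py also stores it once in the device variable. An equivalent sketch that reuses that single variable (batch stands for one element of the training dataloader):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    b_sents_tokenized = batch[0].to(device)  # same effect as the inline expression used in the patch
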