diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index e1208bdcec6d56f74dd61c57d539e11c51cdc25a..3332a7269fc3ead4fdba1e3d22bdbe7d4cf35f9d 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -21,16 +21,18 @@ from SuperTagger.Utils.SentencesTokenizer import SentencesTokenizer
 from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
 from SuperTagger.Utils.Tagging_bert_model import Tagging_bert_model
 
-
 def categorical_accuracy(preds, truth):
-    flat_preds = preds.flatten()
-    flat_labels = truth.flatten()
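+    # token-level accuracy over padded batches: positions whose gold label
+    # is the padding index 0 are excluded from both counts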
     good_label = 0
-    for i in range(len(flat_preds)):
-        if flat_labels[i] == flat_preds[i] and flat_labels[i]!=0:
-            good_label += 1
-
-    return good_label / len(flat_labels)
+    nb_label = 0
+    for i in range(len(truth)):
+        sublist_truth = truth[i]
+        sublist_preds = preds[i]
+        for j in range(len(sublist_truth)):
+            if sublist_truth[j] != 0:
+                if sublist_truth[j] == sublist_preds[j]:
+                    good_label += 1
+                nb_label += 1
+    return good_label / nb_label
 
 
 def format_time(elapsed):
@@ -102,7 +104,6 @@ class SuperTagger:
         self.model = Tagging_bert_model(bert_name, num_label + 1)
         index_to_tags = {k + 1: v for k, v in index_to_tags.items()}
         index_to_tags[0] = '<unk>'
-        print(index_to_tags)
         self.index_to_tags = index_to_tags
         self.bert_name = bert_name
         self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(
@@ -116,16 +117,17 @@ class SuperTagger:
     def predict(self, sentences):
 
-        assert self.trainable or self.model is None, "Please use the create_new_model(...) or load_weights(...) function before the predict, the model is not integrated"
+        assert self.model is not None, "Call create_new_model(...) or load_weights(...) before predict: no model is loaded"
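+        # eval() disables dropout and no_grad() skips autograd bookkeeping;
+        # neither is needed at inference time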
+        self.model.eval()
+        with torch.no_grad():
+            sents_tokenized_t, sents_mask_t = self.sent_tokenizer.fit_transform_tensors(sentences)
 
-        sents_tokenized_t, sents_mask_t = self.sent_tokenizer.fit_transform_tensors(sentences)
-
-        self.model = self.model.cpu()
+            self.model = self.model.cpu()
 
-        pred = self.model.predict((sents_tokenized_t, sents_mask_t))
+            pred = self.model.predict((sents_tokenized_t, sents_mask_t))
 
-        print(pred)
 
-        return self.tags_tokenizer.convert_ids_to_tags(pred.detach())
+            return pred, self.tags_tokenizer.convert_ids_to_tags(pred.detach())
 
     def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=32, tensorboard=False,
               checkpoint=False):
@@ -158,13 +160,15 @@ class SuperTagger:
                 print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%')
 
             if tensorboard:
-                writer.add_scalars(f'Train_Accuracy/Loss', {
-                    'Accuracy_train': epoch_acc,
-                    'Loss_train': epoch_loss}, epoch_i + 1)
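+                # shared 'Accuracy' and 'Loss' tags group the train and
+                # validation curves together in TensorBoard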
+                writer.add_scalars('Accuracy', {
+                    'Train': epoch_acc}, epoch_i + 1)
+                writer.add_scalars('Loss', {
+                    'Train': epoch_loss}, epoch_i + 1)
                 if validation_rate > 0.0:
-                    writer.add_scalars(f'Validation_Accuracy/Loss', {
-                        'Accuracy_val': eval_accuracy,
-                        'Loss_val': eval_loss, }, epoch_i + 1)
+                    writer.add_scalars('Accuracy', {
+                        'Validation': eval_accuracy}, epoch_i + 1)
+                    writer.add_scalars('Loss', {
+                        'Validation': eval_loss}, epoch_i + 1)
 
             if checkpoint:
                 self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
@@ -210,13 +214,10 @@ class SuperTagger:
                 b_sents_mask = batch[1].to(self.device)
                 targets = batch[2].to(self.device)
 
-                self.optimizer.zero_grad()
-
                 _, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
 
                 predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
                 label_ids = targets.cpu().numpy()
-                print()
                 #torch.nn.functional.one_hot(targets).long()
                 # torch.argmax(logit)
 
@@ -229,6 +230,8 @@ class SuperTagger:
                 epoch_acc += acc
                 epoch_loss += loss.item()
 
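+                # clip exploding gradients before the update, then reset them
+                # after stepping so the next batch starts from zero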
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                 self.optimizer.step()
+                self.optimizer.zero_grad()
                 i += 1
 
diff --git a/SuperTagger/Utils/SentencesTokenizer.py b/SuperTagger/Utils/SentencesTokenizer.py
index f14f89ff7f8c2539b422a6179e2ef2d840b3a443..f1fbea51286ffb4f86e8a0b4f199bd78eb292772 100644
--- a/SuperTagger/Utils/SentencesTokenizer.py
+++ b/SuperTagger/Utils/SentencesTokenizer.py
@@ -13,7 +13,7 @@ class SentencesTokenizer():
 
     def fit_transform_tensors(self, sents):
         # , return_tensors = 'pt'
-        temp = self.tokenizer(sents, padding=True, return_offsets_mapping = True, return_tensors = 'pt')
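+        # pad to the longest sentence in the batch; the offset mapping is no
+        # longer requested since nothing here consumes it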
+        temp = self.tokenizer(sents, padding=True, return_tensors='pt')
         #
         # len_sent_max = len(temp['attention_mask'][0])
         #
diff --git a/SuperTagger/Utils/utils.py b/SuperTagger/Utils/utils.py
index efe708c3cc493f57676908a2c36611d1da98ea6f..03aadfeebc90e85a8b15d912c62459efdc2c9cc1 100644
--- a/SuperTagger/Utils/utils.py
+++ b/SuperTagger/Utils/utils.py
@@ -22,7 +22,6 @@ def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=100):
 
     df = pd.concat((f for f in chunk_list), axis=0)
     print("#" * 20)
-
     return df
 
 
diff --git a/main.py b/main.py
index 5e881a78a04e28b8c33844d92ad43f6a48014081..bc66aec3e7b48ee90d5ce83b52955f2004ce363b 100644
--- a/main.py
+++ b/main.py
@@ -7,10 +7,6 @@ from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
 from SuperTagger.Utils.utils import read_csv_pgbar
 
 
-def categorical_accuracy(preds, truth):
-    flat_preds = preds[:len(truth)].flatten()
-    flat_labels = truth.flatten()
-    return np.sum(flat_preds == flat_labels) / len(flat_labels)
 
 
 def load_obj(name):
@@ -21,13 +17,13 @@ def load_obj(name):
 
 file_path = 'Datasets/m2_dataset_V2.csv'
 
-df = read_csv_pgbar(file_path, 10)
+df = read_csv_pgbar(file_path, 1000)
 
 texts = df['X'].tolist()
 tags = df['Z'].tolist()
-
-texts = texts[:3]
-tags = tags[:3]
+print(f"Loaded {len(tags)} tagged sentences")
 
 tagger = SuperTagger()
 
@@ -49,22 +45,32 @@ tagger.load_weights("models/model_check.pt")
 #
 # print(sent_tokenizer.convert_ids_to_tokens(sents_tokenized_t))
 
-pred = tagger.predict(texts)
-
-print(tags[1])
-print()
-print(pred[1])
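+# predict() now returns both the raw prediction tensor and the decoded tags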
+pred, pred_convert = tagger.predict(texts)
 
 
 def categorical_accuracy(preds, truth):
-    flat_preds = preds.flatten()
-    flat_labels = truth.flatten()
+    nb_label = 0
     good_label = 0
-    for i in range(len(flat_preds)):
-        if flat_labels[i] == flat_preds[i] and flat_labels[i] != 0:
-            good_label += 1
+    for i in range(len(truth)):
+        sublist_truth = truth[i]
+        sublist_preds = preds[i]
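+        # compare only aligned positions: the decoded prediction can differ
+        # in length from the gold sequence after tokenization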
+        for j in range(min(len(sublist_truth), len(sublist_preds))):
+            if str(sublist_truth[j]) == str(sublist_preds[j]):
+                good_label += 1
+            nb_label += 1
+
+    return good_label / nb_label
 
-    return good_label / len(flat_labels)
 
 
-print(categorical_accuracy(np.array(pred), np.array(tags)))
+print(categorical_accuracy(pred_convert, tags))
diff --git a/train.py b/train.py
index 1188d8adc1a89b2d633a5a6d0f61a43d77a9f23d..daeef4c30f968b6f9d9db7c595fe93e1fb3363ec 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ def load_obj(name):
 file_path = 'Datasets/m2_dataset_V2.csv'
 
 
-df = read_csv_pgbar(file_path,50)
+df = read_csv_pgbar(file_path, 1000)
 
 
 texts = df['X'].tolist()