diff --git a/Datasets/Utils/PostpreprocesTXT.py b/Datasets/Utils/PostpreprocesTXT.py
index 9c388d0a0bb2fc39ffe8e7409363793fccc1de31..b10564ec2b0437e60ec5cdd7fa483e28f43beffa 100644
--- a/Datasets/Utils/PostpreprocesTXT.py
+++ b/Datasets/Utils/PostpreprocesTXT.py
@@ -131,10 +131,10 @@ X, Y1, Y2, Z, vocabulary, vnorm, partsofspeech1, partsofspeech2, superset, maxle
 
 df = pd.DataFrame(columns = ["X", "Y1", "Y2", "Z"])
 
-df['X'] = X
-df['Y1'] = Y1
-df['Y2'] = Y2
-df['Z'] = Z
+df['X'] = X[:-1]    # the four lists are parallel; drop the trailing empty entry
+df['Y1'] = Y1[:-1]
+df['Y2'] = Y2[:-1]
+df['Z'] = Z[:-1]
 
 df.to_csv("../m2_dataset_V2.csv", index=False)
 
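Note: these slices drop the empty trailing entry that previously became the `,[],[],[]` row removed from `m2_dataset_V2.csv` below. A minimal sketch of a more defensive variant, with toy data standing in for the script's `X`/`Y1`/`Y2`/`Z` lists (the filter predicate is an assumption, not code from this repo):

```python
import pandas as pd

# Toy stand-ins for the parallel lists built by PostpreprocesTXT.py.
X = ["Une phrase .", ""]
Y1 = [["DET", "NC", "PONCT"], []]
Y2 = [["DET:ART", "NOM", "PUN"], []]
Z = [["dr(0,np,n)", "n", "dl(0,s,txt)"], []]

df = pd.DataFrame({"X": X, "Y1": Y1, "Y2": Y2, "Z": Z})
# Filter out every empty sentence, wherever it occurs, instead of
# assuming the only bad row is the last one.
df = df[df["X"].str.strip().astype(bool)]
df.to_csv("m2_dataset_V2.csv", index=False)
```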
diff --git a/Datasets/m2_dataset_V2.csv b/Datasets/m2_dataset_V2.csv
index 662c963b449b5cf4e1807bc233f5d4c821c3530b..f6d823ddeb68a45919be5cad65c55635ebe8a664 100644
--- a/Datasets/m2_dataset_V2.csv
+++ b/Datasets/m2_dataset_V2.csv
@@ -15768,4 +15768,3 @@ X,Y1,Y2,Z
  L' effet indésirable le plus fréquent avec Angiox ( observé chez plus d' un patient sur 10 ) est le saignement bénin .,"['DET', 'NC', 'ADJ', 'DET', 'ADV', 'ADJ', 'P', 'NPP', 'PONCT', 'VPP', 'P', 'ADV', 'P', 'DET', 'NC', 'P', 'PRO', 'PONCT', 'V', 'DET', 'NC', 'ADJ', 'PONCT']","['DET:ART', 'NOM', 'ADJ', 'DET:ART', 'ADV', 'ADJ', 'PRP', 'NAM', 'PUN', 'VER:pper', 'PRP', 'ADV', 'PRP', 'DET:ART', 'NOM', 'PRP', 'NUM', 'PUN', 'VER:pres', 'DET:ART', 'NOM', 'ADJ', 'PUN']","['dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,n,n),dl(0,n,n))', 'dr(0,dl(0,n,n),dl(0,n,n))', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,np),dl(0,n,n))', 'dl(0,n,n)', 'dr(0,dl(1,dl(0,n,n),dl(0,n,n)),np)', 'dr(0,np,pp_de)', 'dr(0,pp_de,np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'np', 'let', 'dr(0,dl(0,np,s),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dl(0,s,txt)']"
 " Pour avoir le détail de tous les effets indésirables observés lors de l' utilisation de Angiox , voir la notice .","['P', 'VINF', 'DET', 'NC', 'P', 'ADV', 'DET', 'NC', 'ADJ', 'VPP', 'ADV', 'P', 'DET', 'NC', 'P', 'NPP', 'PONCT', 'VINF', 'DET', 'NC', 'PONCT']","['PRP', 'VER:infi', 'DET:ART', 'NOM', 'PRP', 'ADV', 'DET:ART', 'NOM', 'ADJ', 'VER:pper', 'ADV', 'PRP', 'DET:ART', 'NOM', 'PRP', 'NAM', 'PUN', 'VER:infi', 'DET:ART', 'NOM', 'PUN']","['dr(0,dr(0,dl(0,np,s),dl(0,np,s)),dl(0,np,s_inf))', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),np)', 'dr(0,np,np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dl(0,n,n)', 'dr(0,dl(1,dl(0,n,n),dl(0,n,n)),pp_de)', 'dr(0,pp_de,np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),np)', 'np', 'let', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dl(0,dl(0,np,s),txt)']"
 " Angiox ne doit pas être utilisé chez les personnes pouvant présenter une hypersensibilité ( allergie ) à la bivalirudine , aux autres hirudines , ou à l' un des autres composants constituant Angiox .","['NPP', 'ADV', 'V', 'ADV', 'VINF', 'VPP', 'P', 'DET', 'NC', 'VPR', 'VINF', 'DET', 'NC', 'PONCT', 'NC', 'PONCT', 'P', 'DET', 'NC', 'PONCT', 'P+D', 'ADJ', 'NC', 'PONCT', 'CC', 'P', 'DET', 'NC', 'P+D', 'ADJ', 'NC', 'VPR', 'NPP', 'PONCT']","['NAM', 'ADV', 'VER:pres', 'ADV', 'VER:infi', 'VER:pper', 'PRP', 'DET:ART', 'NOM', 'VER:ppre', 'VER:infi', 'DET:ART', 'NOM', 'PUN', 'NOM', 'PUN', 'PRP', 'DET:ART', 'NOM', 'PUN', 'PRP:det', 'ADJ', 'NOM', 'PUN', 'KON', 'PRP', 'DET:ART', 'NUM', 'PRP:det', 'ADJ', 'NOM', 'VER:ppre', 'NAM', 'PUN']","['np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),dl(0,np,s_inf))', 'dl(1,s,s)', 'dr(0,dl(0,np,s_inf),dl(0,np,s_pass))', 'dl(0,np,s_pass)', 'dr(0,dl(1,s,s),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),dl(0,np,s_inf))', 'dr(0,dl(0,np,s_inf),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,n,n),n)', 'n', 'let', 'dr(0,dl(0,n,n),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dr(0,dl(0,n,n),n)', 'dr(0,n,n)', 'n', 'let', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dr(0,dl(0,n,n),np)', 'dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),n)', 'dr(0,n,n)', 'n', 'dr(0,dl(0,n,n),np)', 'np', 'dl(0,s,txt)']"
-,[],[],[]
diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index 54072932deb45f25312f819ee8285ec6ef6b7b00..e1208bdcec6d56f74dd61c57d539e11c51cdc25a 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -55,6 +55,7 @@ class SuperTagger:
         self.model = None
 
         self.optimizer = None
+        self.loss = nn.CrossEntropyLoss(ignore_index=0)  # padding targets (id 0) are excluded from the loss
 
         self.epoch_i = 0
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -208,10 +209,12 @@ class SuperTagger:
 
                 self.optimizer.zero_grad()
 
-                loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
+                _, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
 
                 predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
                 label_ids = targets.cpu().numpy()
+                # CrossEntropyLoss expects (batch, classes, seq_len), so swap the last two axes
+                loss = self.loss(logit.transpose(1, 2), targets)
 
                 acc = categorical_accuracy(predictions, label_ids)
 
@@ -243,7 +246,9 @@
                 b_sents_mask = batch[1].to(self.device)
                 b_symbols_tokenized = batch[2].to(self.device)
 
-                loss, logits = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+                _, logits = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+
+                loss = self.loss(logits.transpose(1, 2), b_symbols_tokenized)
 
                 predictions = torch.argmax(logits, dim=2).detach().cpu().numpy()
                 label_ids = b_symbols_tokenized.cpu().numpy()
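The substantive change in this file is that the loss is now computed outside the model with `nn.CrossEntropyLoss(ignore_index=0)`, so padded positions no longer contribute to the gradient. A minimal sketch of the shape handling, with toy dimensions:

```python
import torch
from torch import nn

batch_size, seq_len, num_classes = 2, 5, 10
logits = torch.randn(batch_size, seq_len, num_classes)  # model output
targets = torch.tensor([[3, 1, 4, 0, 0],                # 0 = padding id
                        [2, 2, 5, 7, 0]])

# CrossEntropyLoss wants (batch, classes, seq_len), hence the transpose;
# ignore_index=0 masks the padded targets out of the average.
criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(logits.transpose(1, 2), targets)
print(loss.item())
```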
diff --git a/SuperTagger/Utils/SentencesTokenizer.py b/SuperTagger/Utils/SentencesTokenizer.py
index 1cdb1ee9aea72095840845970bde33f09959b9e0..f14f89ff7f8c2539b422a6179e2ef2d840b3a443 100644
--- a/SuperTagger/Utils/SentencesTokenizer.py
+++ b/SuperTagger/Utils/SentencesTokenizer.py
@@ -13,29 +13,10 @@ class SentencesTokenizer():
 
     def fit_transform_tensors(self, sents):
         # , return_tensors = 'pt'
-        temp = self.tokenizer(sents, padding=True, return_offsets_mapping = True)
-
-        len_sent_max = len(temp['attention_mask'][0])
-
-        input_ids = np.ones((len(sents),len_sent_max))
-        attention_mask = np.zeros((len(sents),len_sent_max))
-
-        for i in range(len(temp['offset_mapping'])):
-            h = 1
-            input_ids[i][0] = self.tokenizer.cls_token_id
-            attention_mask[i][0] = 1
-            for j in range (1,len_sent_max-1):
-                if temp['offset_mapping'][i][j][1] != temp['offset_mapping'][i][j+1][0]:
-                    input_ids[i][h] = temp['input_ids'][i][j]
-                    attention_mask[i][h] = 1
-                    h += 1
-            input_ids[i][h] = self.tokenizer.eos_token_id
-            attention_mask[i][h] = 1
-
-        input_ids = torch.tensor(input_ids).long()
-        attention_mask = torch.tensor(attention_mask)
-
-        return input_ids, attention_mask
+        # the tokenizer itself now pads the batch and returns PyTorch tensors
+        temp = self.tokenizer(sents, padding=True, return_tensors='pt')
+
+        return temp["input_ids"], temp["attention_mask"]
 
     def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False):
         return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens)
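Note that the replaced loop used the offset mapping to keep only one piece per word; returning the tokenizer's raw output keeps every subword piece, so the tag sequences must now align with subword positions rather than words. A short sketch of what the simplified method relies on (`camembert-base` assumed, as in `train.py`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("camembert-base")
batch = tokenizer(
    ["Une phrase courte .", "Une autre phrase un peu plus longue ."],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)       # (2, longest sequence in the batch)
print(batch["attention_mask"].shape)  # same shape; 1 = token, 0 = padding
```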
diff --git a/main.py b/main.py
index 846264bf573c7d8f8c2079fca06ff095a05903bc..5e881a78a04e28b8c33844d92ad43f6a48014081 100644
--- a/main.py
+++ b/main.py
@@ -6,11 +6,13 @@ from SuperTagger.Utils.SentencesTokenizer import SentencesTokenizer
 from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
 from SuperTagger.Utils.utils import read_csv_pgbar
 
+
 def categorical_accuracy(preds, truth):
     flat_preds = preds[:len(truth)].flatten()
     flat_labels = truth.flatten()
     return np.sum(flat_preds == flat_labels) / len(flat_labels)
 
+
 def load_obj(name):
     with open(name + '.pkl', 'rb') as f:
         import pickle
@@ -19,16 +21,13 @@ def load_obj(name):
 
 file_path = 'Datasets/m2_dataset_V2.csv'
 
-
 df = read_csv_pgbar(file_path, 10)
 
 texts = df['X'].tolist()
 tags = df['Z'].tolist()
 
-
-texts = texts[:1]
-tags = tags[:1]
-
+texts = texts[:3]
+tags = tags[:3]
 
 tagger = SuperTagger()
 
@@ -52,12 +51,20 @@ tagger.load_weights("models/model_check.pt")
 
 pred = tagger.predict(texts)
 
-print(tags)
+print(tags[1])
 print()
-print(pred[0])
+print(pred[1])
+
 
-print(pred[0][0] == tags[0])
+def categorical_accuracy_ignore_pad(preds, truth):
+    flat_preds = preds.flatten()
+    flat_labels = truth.flatten()
+    good_label = 0
+    for i in range(len(flat_preds)):
+        if flat_labels[i] == flat_preds[i] and flat_labels[i] != 0:
+            good_label += 1
 
-print(np.sum(pred[0][:len(tags)] == tags) / len(tags))
+    return good_label / len(flat_labels)
 
 
+print(categorical_accuracy_ignore_pad(np.array(pred), np.array(tags)))
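The redefined accuracy skips positions whose gold label is the padding id 0, though it still divides by the total length, padding included, and `np.array` on ragged per-sentence lists yields an object array whose `flatten` does not descend into the inner lists. A vectorized sketch of the same idea, assuming rectangular integer id arrays:

```python
import numpy as np

def accuracy_ignore_pad(preds, truth):
    # Count matches only where the gold label is not the padding id 0.
    preds = np.asarray(preds).flatten()
    truth = np.asarray(truth).flatten()
    correct = np.sum((preds == truth) & (truth != 0))
    return correct / len(truth)

print(accuracy_ignore_pad([[3, 1, 0]], [[3, 2, 0]]))  # 1 match / 3 labels
```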
diff --git a/train.py b/train.py
index e1d34ea0e5ab7293519709701fe3e8a4dc7c7feb..1188d8adc1a89b2d633a5a6d0f61a43d77a9f23d 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ def load_obj(name):
 file_path = 'Datasets/m2_dataset_V2.csv'
 
 
-df = read_csv_pgbar(file_path,100)
+df = read_csv_pgbar(file_path, 50)
 
 
 texts = df['X'].tolist()
@@ -31,7 +31,7 @@ tagger = SuperTagger()
 
 tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
 
-tagger.train(texts,tags,validation_rate=0,tensorboard=True,checkpoint=True)
+tagger.train(texts, tags, validation_rate=0.1, tensorboard=True, checkpoint=True)
 
 pred = tagger.predict(test_s)
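With `validation_rate=0.1`, a tenth of the 50 loaded sentences is held out for validation. A sketch of the kind of split this implies (hypothetical; the real split logic lives inside `SuperTagger.train`):

```python
import torch
from torch.utils.data import TensorDataset, random_split

data = TensorDataset(torch.arange(50))  # stand-in for 50 tokenized sentences
n_val = int(0.1 * len(data))            # 5 held-out examples
train_set, val_set = random_split(data, [len(data) - n_val, n_val])
print(len(train_set), len(val_set))     # 45 5
```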