From 90032fa2c3393cf589a7f61e671770e58bb7705f Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Mon, 9 May 2022 14:49:41 +0200
Subject: [PATCH] v.0.0

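Refactor the supertagger around transformers.BertForTokenClassification: the old
EncoderDecoder, EncoderTokenizer and SymbolTokenizer modules are replaced by a
reusable SuperTagger class with Utils/SentencesTokenizer, Utils/SymbolTokenizer and
Utils/Tagging_bert_model; the preprocessing script moves to Datasets/Utils; train.py
is simplified to use the new API and a new main.py loads a saved checkpoint for
prediction. TensorBoard logs and checkpoints are written to a per-run directory.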
---
 .gitignore                                    |   3 +
 Configuration/config.ini                      |   8 +-
 {Utils => Datasets/Utils}/PostpreprocesTXT.py |   7 +-
 {Utils => Datasets/Utils}/m2.txt              |   0
 Datasets/index_to_pos1.pkl                    | Bin 255 -> 0 bytes
 Datasets/index_to_pos2.pkl                    | Bin 350 -> 0 bytes
 SuperTagger/EncoderDecoder.py                 | 103 ------
 SuperTagger/SuperTagger.py                    | 268 +++++++++++++++
 SuperTagger/SymbolTokenizer.py                |  56 ----
 .../SentencesTokenizer.py}                    |  10 +-
 SuperTagger/Utils/SymbolTokenizer.py          |  43 +++
 SuperTagger/Utils/Tagging_bert_model.py       |  40 +++
 SuperTagger/Utils/utils.py                    |  29 ++
 SuperTagger/eval.py                           |  46 ---
 SuperTagger/utils.py                          |  66 ----
 main.py                                       |  63 ++++
 train.py                                      | 314 ++----------------
 17 files changed, 482 insertions(+), 574 deletions(-)
 rename {Utils => Datasets/Utils}/PostpreprocesTXT.py (95%)
 rename {Utils => Datasets/Utils}/m2.txt (100%)
 delete mode 100644 Datasets/index_to_pos1.pkl
 delete mode 100644 Datasets/index_to_pos2.pkl
 delete mode 100644 SuperTagger/EncoderDecoder.py
 create mode 100644 SuperTagger/SuperTagger.py
 delete mode 100644 SuperTagger/SymbolTokenizer.py
 rename SuperTagger/{EncoderTokenizer.py => Utils/SentencesTokenizer.py} (79%)
 create mode 100644 SuperTagger/Utils/SymbolTokenizer.py
 create mode 100644 SuperTagger/Utils/Tagging_bert_model.py
 create mode 100644 SuperTagger/Utils/utils.py
 delete mode 100644 SuperTagger/eval.py
 delete mode 100644 SuperTagger/utils.py
 create mode 100644 main.py

diff --git a/.gitignore b/.gitignore
index 46fc94c..4da91f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,6 @@ venv
 push pull texte
 logs
 Output
+.data
+TensorBoard
+models
diff --git a/Configuration/config.ini b/Configuration/config.ini
index 802a83f..3fb0157 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -8,11 +8,11 @@ dropout=0.1
 teacher_forcing=0.05
 [MODEL_TRAINING]
 batch_size=16
-epoch=20
+epoch=10
 seed_val=42
 learning_rate=0.005
-use_checkpoint_SAVE=0
+use_checkpoint_SAVE=1
 output_path=Output
-use_checkpoint_LOAD=0
-input_path=Input
+use_checkpoint_LOAD=1
+input_path=models_save
 model_to_load=model_check.pt
\ No newline at end of file
diff --git a/Utils/PostpreprocesTXT.py b/Datasets/Utils/PostpreprocesTXT.py
similarity index 95%
rename from Utils/PostpreprocesTXT.py
rename to Datasets/Utils/PostpreprocesTXT.py
index fbfcb8d..9c388d0 100644
--- a/Utils/PostpreprocesTXT.py
+++ b/Datasets/Utils/PostpreprocesTXT.py
@@ -97,7 +97,8 @@ def read_maxentdata(file):
                     partsofspeech1.add(pos1)
                     partsofspeech2.add(pos2)
                     superset.add(supertag)
-                    words +=  ' ' +(str(orig_word))
+                    # words +=  ' ' +(str(orig_word))
+                    words += ' ' + (str(orig_word))
                     postags1.append(pos1)
                     postags2.append(pos2)
                     supertags.append(supertag)
@@ -135,7 +136,7 @@ df['Y1'] = Y1
 df['Y2'] = Y2
 df['Z'] = Z
 
-df.to_csv("../Datasets/m2_dataset_V2.csv", index=False)
+df.to_csv("../m2_dataset_V2.csv", index=False)
 
 
 t =  np.unique(np.array(list(itertools.chain(*Z))))
@@ -144,4 +145,4 @@ print(t.size)
 
 dict = { i : t[i] for i in range(0, len(t) ) }
 
-save_obj(dict,"../Datasets/index_to_super")
\ No newline at end of file
+save_obj(dict,"../index_to_super")
\ No newline at end of file
diff --git a/Utils/m2.txt b/Datasets/Utils/m2.txt
similarity index 100%
rename from Utils/m2.txt
rename to Datasets/Utils/m2.txt
diff --git a/Datasets/index_to_pos1.pkl b/Datasets/index_to_pos1.pkl
deleted file mode 100644
index 005a0bc309f087dd5b4af75a12d59239b9b723b6..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 255
zcmZo*nfiqR0&1sdcr*4eIR{ViX6j)L2=Wj2nBvXc!|WFjFvXjtht<V31SG=R!x9kW
z>Ndrjt%o^4+hvM3dk<@Xzn^o+6mO0m#(*i_oIT9WK0#Bwxq6rbg8Zj=bN8?YXoFci
zJ*<u{VIV_zdzk#3r+D-Au!MQ~22Anh2kG&j;w{j_7&gUQ5G3s9HpN@0hdC@DXo|OR
z50h)i6mJobs^BT!qCHH`&QrX_dKf*Yc#HQiyMQDmKxzY~cuV%M1qAsAx%y1;mg-@4
fbn%+vE#1Qu7CgmU2E+}U;w=kuhF7@9lu|tavtCO8

diff --git a/Datasets/index_to_pos2.pkl b/Datasets/index_to_pos2.pkl
deleted file mode 100644
index ab2916fbf6a280de8ecbffdf6108dcad81cce6a2..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 350
zcmZo*nHtQ<00y;FG`tymm;-_Wrg$^;aD=%ASrupI7EJMGhHx@-3(}@|v-B|g`TI`s
zX6<1P2=$xd&DO&n5ae&=>E|-Vn;pV-arK?z%>iKt_y<q%=Imki_V=6O&DF#1;u>P*
z7!)$ao4beE(Isq(H%||HK&YQpa%RaCZ(gV#*Ptohd=N_uic*WGc=Lmt<v+z+poiHn
z)OU)vAVe}VFD-M5w-AJrR$5Xz#akG{NzTtpnc^)1;S`r96-@CK1sNV-m6BRA#apb0
z+0W5;inn+Vv!`Fk6mJQToWE6&tIrf~Nr<L`f}+$Z-cmixj!r>Syrp}XgCl*Xc*}rX
e?KQ<)7UC1I@8n<_Qj4Z|%R@Msxdo|HO7#F9gKS^`

diff --git a/SuperTagger/EncoderDecoder.py b/SuperTagger/EncoderDecoder.py
deleted file mode 100644
index 1c8e510..0000000
--- a/SuperTagger/EncoderDecoder.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import Dropout, LSTM
-from torch.nn import Module
-
-from Configuration import Configuration
-from torch.nn.utils.rnn import pack_padded_sequence
-from SuperTagger.eval import measure_supertagging_accuracy
-
-
-class EncoderDecoder(Module):
-    """
-    A standard Encoder-Decoder architecture. Base for this and many 
-    other models.
-
-    decoder : instance of Decoder
-    """
-
-    def __init__(self, BASE_MODEL, numPos1Classes, numPos2Classes, numSuperClasses):
-        super(EncoderDecoder, self).__init__()
-
-        self.bert = BASE_MODEL
-
-        self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
-        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
-        self.bidirectional = True
-        dropout = float(Configuration.modelDecoderConfig['dropout'])
-        self.dropout = Dropout(dropout)
-
-        self.bert = BASE_MODEL
-
-        self.lstm_shared = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                        dropout=dropout,
-                        bidirectional=self.bidirectional, batch_first=True, )
-
-        #Pos1
-        self.pos1_1 = nn.Linear(self.dim_encoder,self.dim_decoder)
-        self.pos1_2 = nn.Linear(self.dim_decoder, numPos1Classes)
-
-        #Pos2
-        self.pos2_1 = nn.Linear(self.dim_encoder, self.dim_decoder)
-        self.pos2_2 = nn.Linear(self.dim_decoder, numPos2Classes)
-
-        #super
-        self.lstm_super = LSTM(input_size=self.dim_encoder*2, hidden_size=self.dim_encoder,
-                                num_layers=self.num_rnn_layers,
-                                dropout=dropout,
-                                bidirectional=self.bidirectional, batch_first=True, )
-        self.pos_super_1 = nn.Linear(self.dim_encoder,self.dim_decoder)
-        self.pos_super_2 = nn.Linear(self.dim_decoder, numSuperClasses)
-
-
-
-    def forward(self, batch):
-
-        b_input_ids = batch[0]
-        b_input_mask = batch[1]
-
-        encoded_layers, _ = self.bert(
-            input_ids=b_input_ids, attention_mask=b_input_mask, return_dict=False)
-
-        lstm_output = self.dropout(encoded_layers)
-
-        print("encoded_layers : ", encoded_layers.size())
-
-        # lstm_output, _ = self.lstm_shared(encoded_layers)  ## extract the 1st token's embeddings
-        # print("last_hidden : ", lstm_output.size())
-        #
-        # print("output_shared : ", lstm_output.size())
-
-        # Pos1
-        pos_1_output= self.pos1_1(lstm_output)
-        pos_1_output = self.dropout(pos_1_output)
-        pos_1_output = self.pos1_2(pos_1_output)
-
-        # Pos1
-        pos_2_output = self.pos2_1(lstm_output)
-        pos_2_output = self.dropout(pos_2_output)
-        pos_2_output = self.pos2_2(pos_2_output)
-
-        # super
-        # enc_hiddens, _ = self.lstm_super(lstm_output)
-        super_output = self.pos_super_1(lstm_output)
-        super_output = self.dropout(super_output)
-        super_output = self.pos_super_2(super_output)
-
-
-        return pos_1_output, pos_2_output, super_output
-
-    def predict(self, sents_tokenized_batch, sents_mask_batch):
-        r"""Predicts the symbols for each sentence in sents_tokenized_batch.
-
-        Args:
-            sents_tokenized_batch: [batch_size, max_len_sentence] the tokenized sentences
-            sents_mask_batch : mask output from the encoder tokenizer
-        """
-        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
-
-        predictions = self.decoder.predict_rnn(last_hidden_state, pooler_output)
-
-        return predictions
\ No newline at end of file
diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
new file mode 100644
index 0000000..3222dcf
--- /dev/null
+++ b/SuperTagger/SuperTagger.py
@@ -0,0 +1,268 @@
+import os
+import sys
+
+import time
+import datetime
+
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from torch.utils.data import Dataset, TensorDataset, random_split
+
+from SuperTagger.Utils.SentencesTokenizer import SentencesTokenizer
+from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
+from SuperTagger.Utils.Tagging_bert_model import Tagging_bert_model
+
+
+def categorical_accuracy(preds, truth):
+    flat_preds = preds.flatten()
+    flat_labels = truth.flatten()
+
+    return np.sum(flat_preds == flat_labels) / len(flat_labels)
+
+def format_time(elapsed):
+    '''
+    Takes a time in seconds and returns a string hh:mm:ss
+    '''
+    # Round to the nearest second.
+    elapsed_rounded = int(round(elapsed))
+
+    # Format as hh:mm:ss
+    return str(datetime.timedelta(seconds=elapsed_rounded))
+
+
+class SuperTagger:
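+    """Bundles a Tagging_bert_model with its sentence and tag tokenizers.
+
+    Minimal usage sketch (mirrors train.py and main.py in this patch):
+
+        tagger = SuperTagger()
+        tagger.create_new_model(len(index_to_super), 'camembert-base', index_to_super)
+        tagger.train(texts, tags, tensorboard=True, checkpoint=True)
+        pred = tagger.predict(sentences)   # or resume with tagger.load_weights('models/model_check.pt')
+    """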
+
+    def __init__(self):
+
+        self.index_to_tags = None
+        self.num_label = None
+        self.bert_name = None
+        self.sent_tokenizer = None
+        self.tags_tokenizer = None
+        self.model = None
+
+        self.optimizer = None
+        self.loss = None
+
+        self.epoch_i = 0
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.trainable = False
+        self.model_load = False
+
+
+    def load_weights(self, model_file):
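+        """Restores a model saved by __checkpoint_save (the file must hold 'args', 'state_dict' and 'optimizer')."""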
+        self.trainable = False
+
+        print("#" * 15)
+        try:
+            params = torch.load(model_file , map_location=self.device)
+            args = params['args']
+            self.bert_name = args['bert_name']
+            self.index_to_tags = args['index_to_tags']
+            self.num_label = len(self.index_to_tags)
+            self.model = Tagging_bert_model(self.bert_name, self.num_label)
+            self.tags_tokenizer = SymbolTokenizer(self.index_to_tags)
+            self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(
+                self.bert_name,
+                do_lower_case=True))
+            self.model.load_state_dict(params['state_dict'])
+            self.optimizer=params['optimizer']
+            self.epoch_i = args['epoch']
+            print("\nCheckpoint loaded successfully!\n")
+            print("\tBert model : ", self.bert_name)
+            print("\tLast epoch : ", self.epoch_i)
+            print()
+        except Exception as e:
+            print("\n/!\\ Could not load the checkpoint because:\n\n " + str(e), file=sys.stderr)
+            raise e
+        print("#" * 15)
+
+        self.model_load = True
+        self.trainable = True
+
+    def create_new_model(self, num_label, bert_name, index_to_tags: dict):
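+        """Builds a fresh Tagging_bert_model for num_label tags; id 0 is reserved for '<unk>' padding,
+        so the classifier gets num_label + 1 classes and index_to_tags is shifted up by one."""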
+
+        assert len(index_to_tags) == num_label, \
+            f"len(index_to_tags) ({len(index_to_tags)}) must equal num_label ({num_label})"
+
+        self.model = Tagging_bert_model(bert_name, num_label+1)
+        index_to_tags = {k + 1: v for k, v in index_to_tags.items()}
+        index_to_tags[0] = '<unk>'
+        self.index_to_tags = index_to_tags
+        self.bert_name = bert_name
+        self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(
+            bert_name,
+            do_lower_case=True))
+        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=2e-05)
+        self.tags_tokenizer = SymbolTokenizer(index_to_tags)
+        self.trainable = True
+        self.model_load = True
+
+    def predict(self, sentences):
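+        """Tags raw sentences: takes a list of sentence strings and returns one list of predicted tag strings per sentence."""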
+
+        assert self.trainable and self.model is not None, "Call create_new_model(...) or load_weights(...) before predict(): no model is loaded"
+
+        sents_tokenized_t, sents_mask_t = self.sent_tokenizer.fit_transform_tensors(sentences)
+
+        self.model = self.model.cpu()
+
+        pred = self.model.predict((sents_tokenized_t, sents_mask_t))
+
+        return self.tags_tokenizer.convert_ids_to_tags(pred.detach())
+
+    def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=32, tensorboard=False,
+              checkpoint=False):
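+        """Fine-tunes the current model on (sentences, tags); validation_rate is the held-out fraction,
+        and tensorboard / checkpoint enable per-run logging and per-epoch saving under TensorBoard/."""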
+
+        assert self.trainable and self.model is not None, "Call create_new_model(...) or load_weights(...) before train(): no model is loaded"
+
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = self.__output_create()
+
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, sentences, tags,
+                                                                            1-validation_rate)
+        epochs = epochs - self.epoch_i
+        self.model = self.model.to(self.device)
+        self.model.train()
+
+        for epoch_i in range(0, epochs):
+            print("")
+            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
+            print('Training...')
+
+            epoch_acc, epoch_loss, training_time = self.__train_epoch(training_dataloader)
+
+            if validation_rate>0.0:
+                eval_accuracy, eval_loss, nb_eval_steps = self.__eval_epoch(validation_dataloader)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {epoch_loss:.3f} | Train Acc: {epoch_acc * 100:.2f}%')
+            if validation_rate > 0.0:
+                print(f'\tVal Loss: {eval_loss:.3f} | Val Acc: {eval_accuracy * 100:.2f}%')
+
+            if tensorboard:
+                writer.add_scalars(f'Train_Accuracy/Loss', {
+                    'Accuracy_train': epoch_acc,
+                    'Loss_train': epoch_loss}, epoch_i + 1)
+                if validation_rate > 0.0:
+                    writer.add_scalars(f'Validation_Accuracy/Loss', {
+                        'Accuracy_val': eval_accuracy,
+                        'Loss_val': eval_loss,}, epoch_i + 1)
+
+            # Track completed epochs so that saved checkpoints can be resumed correctly.
+            self.epoch_i += 1
+
+            if checkpoint:
+                self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
+    def __preprocess_data(self, batch_size, sentences, tags, train_rate):
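+        """Tokenizes sentences and tags and builds the training / validation DataLoaders."""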
+
+        validation_dataloader=None
+
+        sents_tokenized_t, sents_mask_t = self.sent_tokenizer.fit_transform_tensors(sentences)
+        tags_t = self.tags_tokenizer.convert_batchs_to_ids(tags, sents_tokenized_t)
+        dataset = TensorDataset(sents_tokenized_t, sents_mask_t, tags_t)
+        train_size = int(train_rate * len(dataset))
+        print('{:>5,} training samples'.format(train_size))
+
+        if train_rate < 1:
+            val_size = len(dataset) - train_size
+            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+            print('{:>5,} validation samples'.format(val_size))
+            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
+        else:
+            train_dataset = dataset
+        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+        return training_dataloader, validation_dataloader
+
+    def __output_create(self):
+        output_path = 'TensorBoard'
+        training_dir = os.path.join(output_path, 'Training_' + datetime.datetime.today().strftime('%d-%m_%H-%M'))
+        logs_dir = os.path.join(training_dir, 'logs')
+        writer = SummaryWriter(log_dir=logs_dir)
+        return training_dir, writer
+
+    def __train_epoch(self, training_dataloader):
+        # Put the model back in training mode (re-enables dropout) in case the previous epoch ended with eval.
+        self.model.train()
+        epoch_loss = 0
+        epoch_acc = 0
+        t0 = time.time()
+        i = 0
+        with tqdm(training_dataloader, unit="batch") as tepoch:
+            for batch in tepoch:
+
+                # Unpack this training batch from our dataloader.
+                b_sents_tokenized = batch[0].to(self.device)
+                b_sents_mask = batch[1].to(self.device)
+                targets = batch[2].to(self.device)
+
+                self.optimizer.zero_grad()
+
+                loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
+
+                acc = categorical_accuracy(np.argmax(logit.detach().cpu().numpy(), axis=2), targets.detach().cpu().numpy())
+
+                epoch_acc += acc.item()
+                epoch_loss += loss.item()
+
+                loss.backward()
+
+                self.optimizer.step()
+                i+=1
+
+
+        # Measure how long this epoch took.
+        training_time = format_time(time.time() - t0)
+
+        epoch_acc = epoch_acc / len(training_dataloader)
+        epoch_loss = epoch_loss / len(training_dataloader)
+
+        return epoch_acc, epoch_loss, training_time
+
+    def __eval_epoch(self, validation_dataloader):
+        self.model.eval()
+        eval_loss = 0
+        eval_accuracy = 0
+        predictions, true_labels = [], []
+        nb_eval_steps, nb_eval_examples = 0, 0
+        with torch.no_grad():
+            print("Start eval")
+            for step, batch in enumerate(validation_dataloader):
+                b_sents_tokenized = batch[0].to(self.device)
+                b_sents_mask = batch[1].to(self.device)
+                b_symbols_tokenized = batch[2].to(self.device)
+
+                # Forward pass with labels so the model returns both the loss and per-token logits.
+                loss, logits = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+
+                logits = logits.detach().cpu().numpy()
+                label_ids = b_symbols_tokenized.cpu().numpy()
+
+                predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
+                true_labels.append(label_ids)
+                accuracy = categorical_accuracy(np.argmax(logits, axis=2), label_ids)
+
+                eval_loss += loss.item()
+                eval_accuracy += accuracy
+                nb_eval_examples += b_sents_tokenized.size(0)
+                nb_eval_steps += 1
+
+            eval_loss = eval_loss / nb_eval_steps
+            eval_accuracy = eval_accuracy / nb_eval_steps
+        return eval_accuracy, eval_loss, nb_eval_steps
+
+    def __checkpoint_save(self, path='/model_check.pt'):
+        self.model.cpu()
+        # print('save model parameters to [%s]' % path, file=sys.stderr)
+
+        torch.save({
+            'args': dict(bert_name=self.bert_name, index_to_tags=self.index_to_tags, epoch=self.epoch_i),
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optimizer,
+        }, path)
+        self.model.to(self.device)
+
diff --git a/SuperTagger/SymbolTokenizer.py b/SuperTagger/SymbolTokenizer.py
deleted file mode 100644
index 8a4948c..0000000
--- a/SuperTagger/SymbolTokenizer.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import pickle
-
-import numpy as np
-import torch
-
-
-def load_obj(name):
-    with open(name + '.pkl', 'rb') as f:
-        return pickle.load(f)
-
-
-class SymbolTokenizer():
-
-    def __init__(self):
-        """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """
-        self.index_to_super = load_obj('Datasets/index_to_super')
-        self.index_to_pos1 = load_obj('Datasets/index_to_pos1')
-        self.index_to_pos2 = load_obj('Datasets/index_to_pos2')
-        self.super_to_index = {v: int(k) for k, v in self.index_to_super.items()}
-        self.pos1_to_index = {v: int(k) for k, v in self.index_to_pos1.items()}
-        self.pos2_to_index = {v: int(k) for k, v in self.index_to_pos2.items()}
-
-    def lenPOS1(self):
-        print(self.pos1_to_index)
-        return len(self.index_to_pos1) + 1
-
-    def lenPOS2(self):
-        return len(self.index_to_pos2) + 1
-
-    def lenSuper(self):
-        return len(self.index_to_super) + 1
-
-    def convert_batchs_to_ids(self, Y1, Y2, Super):
-        max_len_Y1 = max(len(elem) for elem in Y1)
-        max_len_Y2 = max(len(elem) for elem in Y2)
-        max_len_S = max(len(elem) for elem in Super)
-        Y1_tok = torch.as_tensor(pad_sequence([[self.pos1_to_index[str(symbol)] for symbol in sents] for sents in Y1]))
-        Y2_tok = torch.as_tensor(pad_sequence(
-            [[self.pos2_to_index[str(symbol)] for symbol in sents] for sents in Y2]))
-        super_tok = torch.as_tensor(pad_sequence(
-            [[self.super_to_index[str(symbol)] for symbol in sents] for sents in Super]))
-
-        return Y1_tok, Y2_tok, super_tok
-
-    # def convert_ids_to_symbols(self, ids):
-    #     return [self.inverse_symbol_map[int(i)] for i in ids]
-
-def pad_sequence(sequences, max_len=400):
-    sequences_pad = []
-    for s in sequences:
-        padded = [0] * max_len
-        padded[:len(s)] = s
-        sequences_pad.append(padded)
-    return sequences_pad
-
-
diff --git a/SuperTagger/EncoderTokenizer.py b/SuperTagger/Utils/SentencesTokenizer.py
similarity index 79%
rename from SuperTagger/EncoderTokenizer.py
rename to SuperTagger/Utils/SentencesTokenizer.py
index 865f5a7..7aee1d4 100644
--- a/SuperTagger/EncoderTokenizer.py
+++ b/SuperTagger/Utils/SentencesTokenizer.py
@@ -1,17 +1,21 @@
+import numpy as np
 import torch
 
 
-class EncoderTokenizer():
+class SentencesTokenizer():
 
     def __init__(self, tokenizer):
         """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """
         self.tokenizer = tokenizer
 
     def fit_transform(self, sents):
-        return self.tokenizer(sents, padding=True,)
+        return self.tokenizer(sents, padding=True)
 
     def fit_transform_tensors(self, sents):
-        temp = self.tokenizer(sents, padding=True, return_tensors='pt', )
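+        # Returns (input_ids, attention_mask) as padded torch tensors.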
+        temp = self.tokenizer(sents, padding=True, return_tensors='pt')
         return temp['input_ids'], temp['attention_mask']
 
     def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False):
diff --git a/SuperTagger/Utils/SymbolTokenizer.py b/SuperTagger/Utils/SymbolTokenizer.py
new file mode 100644
index 0000000..48543da
--- /dev/null
+++ b/SuperTagger/Utils/SymbolTokenizer.py
@@ -0,0 +1,43 @@
+import pickle
+
+import numpy as np
+import torch
+
+
+def load_obj(name):
+    with open(name + '.pkl', 'rb') as f:
+        return pickle.load(f)
+
+
+class SymbolTokenizer():
+
+
+    def __init__(self, index_to_super):
+        """@params index_to_super (dict): mapping from tag id to supertag string"""
+        self.index_to_super = index_to_super
+        self.super_to_index = {v: int(k) for k, v in self.index_to_super.items()}
+
+
+    def lenSuper(self):
+        return len(self.index_to_super) + 1
+
+    def convert_batchs_to_ids(self, tags, sents_tokenized):
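+        # Pad each label sequence with 0 up to the length of its tokenized sentence so the
+        # label tensor lines up position-for-position with the BERT input ids.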
+        encoded_labels = []
+        labels = [[self.super_to_index[str(symbol)] for symbol in sents] for sents in tags]
+        for l, s in zip(labels, sents_tokenized):
+            super_tok = torch.tensor(pad_sequence(l,len(s)))
+            encoded_labels.append(super_tok.tolist())
+
+        return torch.tensor(encoded_labels)
+
+    def convert_ids_to_tags(self, tags_ids):
+        labels = [[self.index_to_super[int(symbol)] for symbol in sents if self.index_to_super[int(symbol)] != '<unk>'] for sents in tags_ids]
+
+        return labels
+
+def pad_sequence(sequences, max_len=400):
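+    # e.g. pad_sequence([3, 7], max_len=5) -> [3, 7, 0, 0, 0]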
+    padded = [0] * max_len
+    padded[:len(sequences)] = sequences
+    return padded
+
+
diff --git a/SuperTagger/Utils/Tagging_bert_model.py b/SuperTagger/Utils/Tagging_bert_model.py
new file mode 100644
index 0000000..7a896ca
--- /dev/null
+++ b/SuperTagger/Utils/Tagging_bert_model.py
@@ -0,0 +1,40 @@
+import torch
+import transformers
+from torch.nn import Module
+
+from transformers import logging
+
+
+class Tagging_bert_model(Module):
+    """
+    Thin wrapper around transformers.BertForTokenClassification used for supertagging:
+    forward() takes (input_ids, attention_mask, labels) and returns the (loss, logits)
+    pair computed by the underlying model; predict() returns per-token argmax tag ids.
+    """
+
+    def __init__(self, bert_name, num_labels):
+        super(Tagging_bert_model, self).__init__()
+        self.bert_name = bert_name
+        self.num_labels = num_labels
+        self.bert = transformers.BertForTokenClassification.from_pretrained(bert_name, num_labels=num_labels)
+
+    def forward(self, batch):
+        b_input_ids = batch[0]
+        b_input_mask = batch[1]
+        labels = batch[2]
+
+        output = self.bert(
+            input_ids=b_input_ids, attention_mask=b_input_mask, labels=labels)
+        loss, logits = output[:2]
+
+        return loss, logits
+
+    def predict(self, batch):
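+        # Inference without labels: return the argmax tag id for every token position.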
+        b_input_ids = batch[0]
+        b_input_mask = batch[1]
+
+        output = self.bert(
+            input_ids=b_input_ids, attention_mask=b_input_mask)
+
+        return torch.argmax(output[0], dim=2)
diff --git a/SuperTagger/Utils/utils.py b/SuperTagger/Utils/utils.py
new file mode 100644
index 0000000..efe708c
--- /dev/null
+++ b/SuperTagger/Utils/utils.py
@@ -0,0 +1,29 @@
+import datetime
+
+import pandas as pd
+import torch
+from tqdm import tqdm
+
+
+def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=100):
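+    """Reads the dataset CSV in chunks with a progress bar; the Y1/Y2/Z columns store list
+    literals and are parsed back into Python lists via the pd.eval converters."""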
+    print("\n" + "#" * 20)
+    print("Loading csv...")
+
+    rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
+    chunk_list = []
+
+    if rows > nrows:
+        rows = nrows
+
+    with tqdm(total=rows, desc='Rows read: ') as bar:
+        for chunk in pd.read_csv(csv_path, converters={'Y1': pd.eval,'Y2': pd.eval,'Z': pd.eval}, chunksize=chunksize, nrows=rows):
+            chunk_list.append(chunk)
+            bar.update(len(chunk))
+
+    df = pd.concat((f for f in chunk_list), axis=0)
+    print("#" * 20)
+
+    return df
+
+
+
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
deleted file mode 100644
index dd99757..0000000
--- a/SuperTagger/eval.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch
-from torch import Tensor
-from torch.nn import Module
-from torch.nn.functional import cross_entropy
-
-
-# Another from Kokos function to calculate the accuracy of our predictions vs labels
-def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
-    r"""Evaluation of the decoder's output and the true symbols without considering the padding in the true targets.
-
-     Args:
-         pred: [batch_size, max_symbols_in_sentence] prediction of symbols for each symbols.
-         truth: [batch_size, max_symbols_in_sentence] true values of symbols
-     """
-    correct_symbols = torch.ones(pred.size())
-    correct_symbols[pred != truth] = 0
-    correct_symbols[truth == ignore_idx] = 1
-    num_correct_symbols = correct_symbols.sum().item()
-    num_masked_symbols = len(truth[truth == ignore_idx])
-
-    correct_sentences = correct_symbols.prod(dim=1)
-    num_correct_sentences = correct_sentences.sum().item()
-
-    return (num_correct_sentences, pred.shape[0]), \
-           (num_correct_symbols - num_masked_symbols, pred.shape[0] * pred.shape[1] - num_masked_symbols)
-
-
-def count_sep(xs, sep_id, dim=-1):
-    return xs.eq(sep_id).sum(dim=dim)
-
-
-class NormCrossEntropy(Module):
-    r"""Loss based on the cross entropy, it considers the number of words and ignore the padding.
-     """
-
-    def __init__(self, ignore_index, weights=None):
-        super(NormCrossEntropy, self).__init__()
-        self.ignore_index = ignore_index
-        self.weights = weights
-
-    def forward(self, predictions, truths):
-        print()
-        print("predictions : ", predictions.size())
-        print("truths : ", truths.size())
-        return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=torch.tensor(self.weights,device="cuda" if torch.cuda.is_available() else "cpu"),
-        reduction='sum', ignore_index=self.ignore_index)
diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py
deleted file mode 100644
index e257eb7..0000000
--- a/SuperTagger/utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import datetime
-
-import pandas as pd
-import torch
-from tqdm import tqdm
-
-
-def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=100):
-    print("\n" + "#" * 20)
-    print("Loading csv...")
-
-    rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
-    chunk_list = []
-
-    if rows > nrows:
-        rows = nrows
-        chunksize = nrows
-
-    with tqdm(total=rows, desc='Rows read: ') as bar:
-        for chunk in pd.read_csv(csv_path, converters={'Y1': pd.eval,'Y2': pd.eval,'Z': pd.eval}, chunksize=chunksize, nrows=rows):
-            chunk_list.append(chunk)
-            bar.update(len(chunk))
-
-    df = pd.concat((f for f in chunk_list), axis=0)
-    print("#" * 20)
-
-    return df
-
-
-def format_time(elapsed):
-    '''
-    Takes a time in seconds and returns a string hh:mm:ss
-    '''
-    # Round to the nearest second.
-    elapsed_rounded = int(round(elapsed))
-
-    # Format as hh:mm:ss
-    return str(datetime.timedelta(seconds=elapsed_rounded))
-
-
-def checkpoint_save(model, opt, epoch, dir, loss):
-    torch.save({
-        'epoch': epoch,
-        'model_state_dict': model.state_dict(),
-        'optimizer_state_dict': opt.state_dict(),
-        'loss': loss,
-    }, dir + '/model_check.pt')
-
-
-def checkpoint_load(model, opt, path):
-    epoch = 0
-    loss = 0
-    print("#" * 15)
-
-    try:
-        checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        opt.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print("\n The loading checkpoint was successful  ! \n")
-        print("#" * 10)
-    except Exception as e:
-        print("\nCan't load checkpoint model because : " + str(e) + "\n\nUse default model \n")
-    print("#" * 15)
-    return model, opt, epoch, loss
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..846264b
--- /dev/null
+++ b/main.py
@@ -0,0 +1,63 @@
+import numpy as np
+from transformers import AutoTokenizer
+
+from SuperTagger.SuperTagger import SuperTagger
+from SuperTagger.Utils.SentencesTokenizer import SentencesTokenizer
+from SuperTagger.Utils.SymbolTokenizer import SymbolTokenizer
+from SuperTagger.Utils.utils import read_csv_pgbar
+
+def categorical_accuracy(preds, truth):
+    flat_preds = preds[:len(truth)].flatten()
+    flat_labels = truth.flatten()
+    return np.sum(flat_preds == flat_labels) / len(flat_labels)
+
+def load_obj(name):
+    with open(name + '.pkl', 'rb') as f:
+        import pickle
+        return pickle.load(f)
+
+
+file_path = 'Datasets/m2_dataset_V2.csv'
+
+
+df = read_csv_pgbar(file_path, 10)
+
+texts = df['X'].tolist()
+tags = df['Z'].tolist()
+
+
+texts = texts[:1]
+tags = tags[:1]
+
+
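+# Load the trained checkpoint and compare its predicted supertags with the gold tags
+# for a single sentence from the dataset.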
+tagger = SuperTagger()
+
+tagger.load_weights("models/model_check.pt")
+#
+# sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(
+#                 "camembert-base",
+#                 do_lower_case=True))
+# #
+# from tokenizers.pre_tokenizers import Whitespace
+# pre_tokenizer = Whitespace()
+# print(pre_tokenizer.pre_tokenize_str("Lyonnaise-Dumez frite"))
+
+# sents_tokenized_t, sents_mask_t = sent_tokenizer.fit_transform_tensors(["Lyonnaise-Dumez frite", "Lyonnaise-Dumez vient d' hispaniser sa filiale espagnole "])
+#
+# print(sents_tokenized_t)
+# print(sents_mask_t)
+#
+#
+# print(sent_tokenizer.convert_ids_to_tokens(sents_tokenized_t))
+
+pred = tagger.predict(texts)
+
+print(tags)
+print()
+print(pred[0])
+
+print(pred[0] == tags[0])
+
+# Per-token accuracy, truncated to the shorter of the two sequences.
+n = min(len(pred[0]), len(tags[0]))
+print(np.sum(np.array(pred[0][:n]) == np.array(tags[0][:n])) / max(len(tags[0]), 1))
+
+
diff --git a/train.py b/train.py
index a46b4a5..dbf105a 100644
--- a/train.py
+++ b/train.py
@@ -1,314 +1,42 @@
-import os
-import time
-from datetime import datetime
+from SuperTagger.SuperTagger import SuperTagger
+from SuperTagger.Utils.utils import read_csv_pgbar
 
-import numpy as np
-import torch
-import torch.nn.functional as F
-import transformers
-from torch.optim import Adam, RMSprop
-from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from transformers import (CamembertModel)
 
-from Configuration import Configuration
+def load_obj(name):
+    with open(name + '.pkl', 'rb') as f:
+        import pickle
+        return pickle.load(f)
 
-from SuperTagger.EncoderDecoder import EncoderDecoder
-from SuperTagger.EncoderTokenizer import EncoderTokenizer
-from SuperTagger.SymbolTokenizer import SymbolTokenizer
-from SuperTagger.eval import NormCrossEntropy
-from SuperTagger.utils import format_time, read_csv_pgbar
-
-from torch.utils.tensorboard import SummaryWriter
-
-transformers.TOKENIZERS_PARALLELISM = True
-torch.cuda.empty_cache()
-
-# region ParamsModel
-
-
-num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
-
-# endregion ParamsModel
-
-# region ParamsTraining
 
 file_path = 'Datasets/m2_dataset_V2.csv'
-batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 50
-epochs = int(Configuration.modelTrainingConfig['epoch'])
-seed_val = int(Configuration.modelTrainingConfig['seed_val'])
-learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
-loss_scaled_by_freq = True
-
-# endregion ParamsTraining
-
-# region OutputTraining
-
-outpout_path = str(Configuration.modelTrainingConfig['output_path'])
 
-training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-logs_dir = os.path.join(training_dir, 'logs')
 
-checkpoint_dir = training_dir
-writer = SummaryWriter(log_dir=logs_dir)
+df = read_csv_pgbar(file_path, 100)
 
-use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE'))
 
-# endregion OutputTraining
+texts = df['X'].tolist()
+tags = df['Z'].tolist()
 
-# region InputTraining
+test_s = texts[:4]
+tags_s = tags[:4]
 
-input_path = str(Configuration.modelTrainingConfig['input_path'])
-model_to_load = str(Configuration.modelTrainingConfig['model_to_load'])
-model_to_load_path = os.path.join(input_path, model_to_load)
-use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD'))
+texts = texts[4:]
+tags = tags[4:]
 
-# endregion InputTraining
 
-# region Print config
+index_to_super = load_obj('Datasets/index_to_super')
+super_to_index = {v: int(k) for k, v in index_to_super.items()}
 
-print("##" * 15 + "\nConfiguration : \n")
+tagger = SuperTagger()
 
-print("ParamsModel\n")
+tagger.create_new_model(len(index_to_super), 'camembert-base', index_to_super)
 
-# print("\tmax_symbols_per_word :", max_symbols_per_word)
-# print("\tsymbol_vocab_size :", symbol_vocab_size)
-print("\tbidirectional : ", False)
-print("\tnum_gru_layers : ", num_gru_layers)
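+# Fine-tune camembert-base on the full supertag inventory; TensorBoard logs and the
+# per-epoch checkpoint are written to a per-run directory under TensorBoard/.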
+tagger.train(texts, tags, tensorboard=True, checkpoint=True)
 
-print("\n ParamsTraining\n")
-
-print("\tDataset :", file_path)
-print("\tb_sentences :", nb_sentences)
-print("\tbatch_size :", batch_size)
-print("\tepochs :", epochs)
-print("\tseed_val :", seed_val)
-
-print("\n Output\n")
-print("\tuse checkpoint save :", use_checkpoint_SAVE)
-print("\tcheckpoint_dir :", checkpoint_dir)
-print("\tlogs_dir :", logs_dir)
-
-print("\n Input\n")
-print("\tModel to load :", model_to_load_path)
-print("\tLoad checkpoint :", use_checkpoint_LOAD)
-
-print("\nLoss and optimizer : ")
-
-print("\tlearning_rate :", learning_rate)
-print("\twith loss scaled by freq :", loss_scaled_by_freq)
-
-print("\n Device\n")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("\t", device)
+pred = tagger.predict(test_s)
 
+print(test_s)
 print()
-print("##" * 15)
-
-# endregion Print config
-
-# region Model
-
-file_path = 'Datasets/m2_dataset_V2.csv'
-BASE_TOKENIZER = AutoTokenizer.from_pretrained(
-    'camembert-base',
-    do_lower_case=True)
-BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
-sents_tokenizer = EncoderTokenizer(BASE_TOKENIZER)
-symbol_tokenizer = SymbolTokenizer()
-
-# endregion Model
-
-# region Data loader
-df = read_csv_pgbar(file_path, nb_sentences)
-
-sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['X'].tolist())
-
-y1, y2, super = symbol_tokenizer.convert_batchs_to_ids(df['Y1'].tolist(),df['Y2'].tolist(),df['Z'].tolist())
-
-dataset = TensorDataset(sents_tokenized, sents_mask, y1, y2, super)
-# , torch.tensor(df['Y1'].tolist()), torch.tensor(df['Y2'].tolist()), torch.tensor(df['Z'].tolist())
-
-# Calculate the number of samples to include in each set.
-train_size = int(0.9 * len(dataset))
-val_size = len(dataset) - train_size
-
-# Divide the dataset by randomly selecting samples.
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-
-print('{:>5,} training samples'.format(train_size))
-print('{:>5,} validation samples'.format(val_size))
-
-training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
-
-# endregion Data loader
-
-
-model = EncoderDecoder(BASE_MODEL, symbol_tokenizer.lenPOS1(),symbol_tokenizer.lenPOS2(),symbol_tokenizer.lenSuper())
-model = model.to("cuda" if torch.cuda.is_available() else "cpu")
-
-# region Fit tunning
-
-# Optimizer
-optimizer = RMSprop(model.parameters())
-
-# Total number of training steps is [number of batches] x [number of epochs].
-# (Note that this is not the same as the number of training samples).
-total_steps = len(training_dataloader) * epochs
-
-# Create the learning rate scheduler.
-# scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder,
-#                                                     num_warmup_steps=0,
-#                                                     num_training_steps=5)
-# scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder,
-#                                                     num_warmup_steps=0,
-#                                                     num_training_steps=total_steps)
-
-# # Loss
-cross_entropy_loss_Y1 = NormCrossEntropy(0,0.15)
-cross_entropy_loss_Y2 = NormCrossEntropy(0,.35)
-cross_entropy_loss_S = NormCrossEntropy(0,.5)
-
-np.random.seed(seed_val)
-torch.manual_seed(seed_val)
-torch.cuda.manual_seed_all(seed_val)
-torch.autograd.set_detect_anomaly(True)
-
-# endregion Fit tunning
-
-# region Train
-
-# Measure the total training time for the whole run.
-total_t0 = time.time()
-
-validate = True
-
-# if use_checkpoint_LOAD:
-#     model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path)
-#     epochs = epochs - last_epoch
-
-
-def run_epochs(epochs):
-    # For each epoch...
-    for epoch_i in range(0, epochs):
-        # ========================================
-        #               Training
-        # ========================================
-
-        # Perform one full pass over the training set.
-
-        print("")
-        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-        print('Training...')
-
-        # Measure how long the training epoch takes.
-        t0 = time.time()
-
-        # Reset the total loss for this epoch.
-        total_train_loss_Y1 =0
-        total_train_loss_Y2 =0
-        total_train_loss_S =0
-
-        model.train()
-
-        # For each batch of training data...
-        for step, batch in enumerate(training_dataloader):
-
-            # if epoch_i == 0 and step == 0:
-            #     writer.add_graph(model, input_to_model=batch[0], verbose=False)
-
-            # Progress update every 10 batches.
-            if step % 10 == 0 and not step == 0:
-                # Calculate elapsed time in minutes.
-                elapsed = format_time(time.time() - t0)
-                # Report progress.
-                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(training_dataloader), elapsed))
-
-                # Unpack this training batch from our dataloader.
-            b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-            b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-
-            optimizer.zero_grad()
-
-            logits_predictions = model((b_sents_tokenized, b_sents_mask))
-
-            output_dim_Y1 = logits_predictions[0].shape[1]
-            print(output_dim_Y1)
-            # output_Y1 = logits_predictions[0][1:].view(-1, output_dim_Y1)
-            output_dim_Y2 = logits_predictions[1].shape[1]
-            # output_Y2 = logits_predictions[1][1:].view(-1, output_dim_Y2)
-            output_dim_S = logits_predictions[2].shape[1]
-            # output_S = logits_predictions[2][1:].view(-1, output_dim_S)
-
-            loss_Y1 = cross_entropy_loss_Y1(logits_predictions[0], batch[2][:output_dim_Y1])
-            loss_Y2 = cross_entropy_loss_Y2(logits_predictions[1], batch[3][:output_dim_Y2])
-            loss_S = cross_entropy_loss_S(logits_predictions[2], batch[4][:output_dim_S])
-
-            total_train_loss_Y1 += float(loss_Y1)
-            total_train_loss_Y2 += float(loss_Y2)
-            total_train_loss_S += float(loss_S)
-
-            loss_Y1.backward()
-            loss_Y2.backward()
-            loss_S.backward()
-
-            # This is to help prevent the "exploding gradients" problem.
-            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
-
-            # Update parameters and take a step using the computed gradient.
-            optimizer.step()
-            #
-            # scheduler_encoder.step()
-            # scheduler_decoder.step()
-
-        # checkpoint
-
-        # if use_checkpoint_SAVE:
-        #     checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss)
-
-        avg_train_loss_Y1 = total_train_loss_Y1 / len(training_dataloader)
-        avg_train_loss_Y2 = total_train_loss_Y2 / len(training_dataloader)
-        avg_train_loss_S = total_train_loss_S / len(training_dataloader)
-
-        # Measure how long this epoch took.
-        training_time = format_time(time.time() - t0)
-
-        if validate:
-            model.eval()
-            with torch.no_grad():
-                print("Start eval")
-                # accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss)
-                # print("")
-                # print("  Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents))
-                # print("  Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol))
-                # writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1)
-                # writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1)
-
-        print("")
-        print("  Average training loss: {0:.2f}".format(avg_train_loss_Y1))
-        print("  Average training loss: {0:.2f}".format(avg_train_loss_Y2))
-        print("  Average training loss: {0:.2f}".format(avg_train_loss_S))
-        print("  Training epcoh took: {:}".format(training_time))
-
-        # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1)
-
-        # writer.add_scalars('Training vs. Validation Loss',
-        #                    {'Training': avg_train_loss, 'Validation': v_loss},
-        #                    epoch_i + 1)
-        writer.flush()
-
-
-# run_epochs(epochs)
-# endregion Train
-# b1, b2 , y1,y2,y3 = next(iter(training_dataloader))
-# b =(b1, b2)
-# # , y1,y2,y3
-# a = model(b)
-# print(len(b))
-# print(a[0].size(),a[1].size(),a[2].size())
-print(symbol_tokenizer.lenPOS1())
-
+print(pred)
 
 
-- 
GitLab