From 7b10151214babc2c3f1bc474eb9bec25458a8347 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Wed, 25 May 2022 17:15:23 +0200
Subject: [PATCH] Add hidden state
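
Return the BERT hidden states alongside the logits. Tagging_bert_model's
forward() and predict() now return a dict ('loss', 'logit',
'word_embedding', 'last_hidden_state'; predict() omits 'loss') instead of
a positional tuple, so callers can pick the tensors they need by name.
SuperTagger.predict() accordingly returns two values (logits and decoded
tags), and SuperTagger.forward() passes the whole dict through to the
linker.

Also point predict.py at the checkpoint's subdirectory, switch train.py to
the 'Z' tag column and the index_to_super mapping, ignore *.zip, and trim
requirements.txt.

A sketch of the new calling convention (identifiers as introduced by this
patch):

    tagger = SuperTagger()
    tagger.load_weights("models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt")
    logits, tags = tagger.predict(sentences)  # previously: logits, tags, hidden
    output = tagger.forward(tokens, mask)     # dict: 'logit', 'word_embedding',
                                              # 'last_hidden_state'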

---
 .gitignore                              |   1 +
 SuperTagger/SuperTagger.py              |  18 ++++++++++--------
 SuperTagger/Utils/Tagging_bert_model.py |   9 ++++++---
 predict.py                              |   4 ++--
 requirements.txt                        | Bin 476 -> 436 bytes
 train.py                                |   4 ++--
 6 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/.gitignore b/.gitignore
index f57c436..75d2899 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ good_models
 main.py
 *.pt
 Datasets/Utils
+*.zip
diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index a55c8f2..2f5dceb 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -182,17 +182,17 @@ class SuperTagger:
 
             self.model = self.model.cpu()
 
-            preds, hidden = self.model.predict((sents_tokenized_t, sents_mask_t))
+            output = self.model.predict((sents_tokenized_t, sents_mask_t))
 
-            return preds, self.tags_tokenizer.convert_ids_to_tags(torch.argmax(preds, dim=2).detach()), hidden
+            return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach())
 
     def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor):
         """
         Function used for the linker (same as predict)
         """
         with torch.no_grad():
-            logit, hidden = self.model.predict((b_sents_tokenized, b_sents_mask))
-            return logit, hidden
+            output = self.model.predict((b_sents_tokenized, b_sents_mask))
+            return output
 
     def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16,
               tensorboard=False,
@@ -311,9 +311,10 @@ class SuperTagger:
                 targets = batch[2].to(self.device)
                 self.optimizer.zero_grad()
 
-                loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets))
+                output = self.model((b_sents_tokenized, b_sents_mask, targets))
+                loss = output['loss']
 
-                predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
+                predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy()
                 label_ids = targets.cpu().numpy()
 
                 acc = categorical_accuracy(predictions, label_ids)
@@ -353,9 +354,10 @@ class SuperTagger:
                 b_sents_mask = batch[1].to(self.device)
                 b_symbols_tokenized = batch[2].to(self.device)
 
-                loss, logits, _ = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+                output = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+                loss = output['loss']
 
-                predictions = torch.argmax(logits, dim=2).detach().cpu().numpy()
+                predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy()
                 label_ids = b_symbols_tokenized.cpu().numpy()
 
                 accuracy = categorical_accuracy(predictions, label_ids)
diff --git a/SuperTagger/Utils/Tagging_bert_model.py b/SuperTagger/Utils/Tagging_bert_model.py
index 83ef46f..5d008ff 100644
--- a/SuperTagger/Utils/Tagging_bert_model.py
+++ b/SuperTagger/Utils/Tagging_bert_model.py
@@ -24,9 +24,10 @@ class Tagging_bert_model(Module):
 
         output = self.bert(
             input_ids=b_input_ids, attention_mask=b_input_mask, labels=labels)
-        loss, logits, hidden = output[:3]
 
-        return loss, logits, hidden[0]
+            result = {'loss': output[0], 'logit': output[1], 'word_embedding': output[2][0], 'last_hidden_state': output[2][1]}
+
+        return result
 
     def predict(self, batch):
         b_input_ids = batch[0]
@@ -35,4 +36,6 @@ class Tagging_bert_model(Module):
         output = self.bert(
             input_ids=b_input_ids, attention_mask=b_input_mask)
 
-        return output[0], output[1][0]
+            result = {'logit': output[0], 'word_embedding': output[1][0], 'last_hidden_state': output[1][1]}
+
+        return result
diff --git a/predict.py b/predict.py
index 5834d01..7d71919 100644
--- a/predict.py
+++ b/predict.py
@@ -13,12 +13,12 @@ tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,
 #### MODEL ####
 tagger = SuperTagger()
 
-model = "models/flaubert_super_98%_V2_50e.pt"
+model = "models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt"
 
 tagger.load_weights(model)
 
 #### TEST ####
-_, pred_convert, _ = tagger.predict(a_s)
+_, pred_convert = tagger.predict(a_s)
 
 print("Model : ", model)
 
diff --git a/requirements.txt b/requirements.txt
index 41ce3b5b133a0d8ce94376ac0e32283b3c9f2417..34e4e866d8569ca33a964edbf27b80e80f75644e 100644
GIT binary patch
delta 22
dcmcb^yoGr~5aZ-3#t1G;20aD?23`g(1^`T61h@bI

delta 54
zcmdnOe1~~M5TkrHLn=caLncENLn=cNLotIbgDryrgC2t+gAs_F?8%tOWx=4wV8Fo3
Hz{LOnO$-UM

diff --git a/train.py b/train.py
index 9ac77d4..8c2290a 100644
--- a/train.py
+++ b/train.py
@@ -7,7 +7,7 @@ file_path = 'Datasets/m2_dataset.csv'
 df = read_csv_pgbar(file_path,100)
 
 texts = df['X'].tolist()
-tags = df['Y1'].tolist()
+tags = df['Z'].tolist()
 
 test_s = texts[:4]
 tags_s = tags[:4]
@@ -15,7 +15,7 @@ tags_s = tags[:4]
 texts = texts[4:]
 tags = tags[4:]
 
-index_to_super = load_obj('Datasets/index_to_pos1')
+index_to_super = load_obj('Datasets/index_to_super')
 
 #### MODEL ####
 tagger = SuperTagger()
-- 
GitLab