From 7b10151214babc2c3f1bc474eb9bec25458a8347 Mon Sep 17 00:00:00 2001 From: PNRIA - Julien <julien.rabault@irit.fr> Date: Wed, 25 May 2022 17:15:23 +0200 Subject: [PATCH] Add hidden state --- .gitignore | 1 + SuperTagger/SuperTagger.py | 18 ++++++++++-------- SuperTagger/Utils/Tagging_bert_model.py | 9 ++++++--- predict.py | 4 ++-- requirements.txt | Bin 476 -> 436 bytes train.py | 4 ++-- 6 files changed, 21 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index f57c436..75d2899 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ good_models main.py *.pt Datasets/Utils +*.zip diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index a55c8f2..2f5dceb 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -182,17 +182,17 @@ class SuperTagger: self.model = self.model.cpu() - preds, hidden = self.model.predict((sents_tokenized_t, sents_mask_t)) + output = self.model.predict((sents_tokenized_t, sents_mask_t)) - return preds, self.tags_tokenizer.convert_ids_to_tags(torch.argmax(preds, dim=2).detach()), hidden + return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach()) def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor): """ Function used for the linker (same of predict) """ with torch.no_grad(): - logit, hidden = self.model.predict((b_sents_tokenized, b_sents_mask)) - return logit, hidden + output = self.model.predict((b_sents_tokenized, b_sents_mask)) + return output def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16, tensorboard=False, @@ -311,9 +311,10 @@ class SuperTagger: targets = batch[2].to(self.device) self.optimizer.zero_grad() - loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets)) + output = self.model((b_sents_tokenized, b_sents_mask, targets)) + loss = output['loss'] - predictions = torch.argmax(logit, dim=2).detach().cpu().numpy() + predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy() label_ids = targets.cpu().numpy() acc = categorical_accuracy(predictions, label_ids) @@ -353,9 +354,10 @@ class SuperTagger: b_sents_mask = batch[1].to(self.device) b_symbols_tokenized = batch[2].to(self.device) - loss, logits, _ = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized)) + output = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized)) + loss = output['loss'] - predictions = torch.argmax(logits, dim=2).detach().cpu().numpy() + predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy() label_ids = b_symbols_tokenized.cpu().numpy() accuracy = categorical_accuracy(predictions, label_ids) diff --git a/SuperTagger/Utils/Tagging_bert_model.py b/SuperTagger/Utils/Tagging_bert_model.py index 83ef46f..5d008ff 100644 --- a/SuperTagger/Utils/Tagging_bert_model.py +++ b/SuperTagger/Utils/Tagging_bert_model.py @@ -24,9 +24,10 @@ class Tagging_bert_model(Module): output = self.bert( input_ids=b_input_ids, attention_mask=b_input_mask, labels=labels) - loss, logits, hidden = output[:3] - return loss, logits, hidden[0] + result = {'loss': output[0],'logit': output[1], 'word_embeding': output[2][0], 'last_hidden_state': output[2][1]} + + return result def predict(self, batch): b_input_ids = batch[0] @@ -35,4 +36,6 @@ class Tagging_bert_model(Module): output = self.bert( input_ids=b_input_ids, attention_mask=b_input_mask) - return output[0], output[1][0] + result = {'logit' : output[0], 'word_embeding': output[1][0], 'last_hidden_state':output[1][1]} + + return result diff --git a/predict.py b/predict.py index 5834d01..7d71919 100644 --- a/predict.py +++ b/predict.py @@ -13,12 +13,12 @@ tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0, #### MODEL #### tagger = SuperTagger() -model = "models/flaubert_super_98%_V2_50e.pt" +model = "models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt" tagger.load_weights(model) #### TEST #### -_, pred_convert, _ = tagger.predict(a_s) +_, pred_convert = tagger.predict(a_s) print("Model : ", model) diff --git a/requirements.txt b/requirements.txt index 41ce3b5b133a0d8ce94376ac0e32283b3c9f2417..34e4e866d8569ca33a964edbf27b80e80f75644e 100644 GIT binary patch delta 22 dcmcb^yoGr~5aZ-3#t1G;20aD?23`g(1^`T61h@bI delta 54 zcmdnOe1~~M5TkrHLn=caLncENLn=cNLotIbgDryrgC2t+gAs_F?8%tOWx=4wV8Fo3 Hz{LOnO$-UM diff --git a/train.py b/train.py index 9ac77d4..8c2290a 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ file_path = 'Datasets/m2_dataset.csv' df = read_csv_pgbar(file_path,100) texts = df['X'].tolist() -tags = df['Y1'].tolist() +tags = df['Z'].tolist() test_s = texts[:4] tags_s = tags[:4] @@ -15,7 +15,7 @@ tags_s = tags[:4] texts = texts[4:] tags = tags[4:] -index_to_super = load_obj('Datasets/index_to_pos1') +index_to_super = load_obj('Datasets/index_to_super') #### MODEL #### tagger = SuperTagger() -- GitLab