diff --git a/.gitignore b/.gitignore
index f57c436fa6463b2f0a0f382938ddf148bd516eea..75d289914e587cf2ccde09d2df1d9a06db85e5df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ good_models
 main.py
 *.pt
 Datasets/Utils
+*.zip
diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index a55c8f272a681f89132313929c44c55a7625586c..2f5dcebbc95a1f7d9c009d37a2c2b74630ede9e2 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -182,17 +182,17 @@ class SuperTagger:
 
             self.model = self.model.cpu()
 
-            preds, hidden = self.model.predict((sents_tokenized_t, sents_mask_t))
+            output = self.model.predict((sents_tokenized_t, sents_mask_t))
 
-            return preds, self.tags_tokenizer.convert_ids_to_tags(torch.argmax(preds, dim=2).detach()), hidden
+            return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach())
 
     def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor):
         """
         Function used for the linker (same of predict)
         """
         with torch.no_grad():
-            logit, hidden = self.model.predict((b_sents_tokenized, b_sents_mask))
-            return logit, hidden
+            output = self.model.predict((b_sents_tokenized, b_sents_mask))
+            return output
 
     def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16,
               tensorboard=False,
@@ -311,9 +311,10 @@ class SuperTagger:
             targets = batch[2].to(self.device)
 
             self.optimizer.zero_grad()
-            loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets))
+            output = self.model((b_sents_tokenized, b_sents_mask, targets))
+            loss = output['loss']
 
-            predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
+            predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy()
             label_ids = targets.cpu().numpy()
 
             acc = categorical_accuracy(predictions, label_ids)
@@ -353,9 +354,10 @@ class SuperTagger:
                 b_sents_mask = batch[1].to(self.device)
                 b_symbols_tokenized = batch[2].to(self.device)
 
-                loss, logits, _ = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+                output = self.model((b_sents_tokenized, b_sents_mask, b_symbols_tokenized))
+                loss = output['loss']
 
-                predictions = torch.argmax(logits, dim=2).detach().cpu().numpy()
+                predictions = torch.argmax(output['logit'], dim=2).detach().cpu().numpy()
                 label_ids = b_symbols_tokenized.cpu().numpy()
 
                 accuracy = categorical_accuracy(predictions, label_ids)
diff --git a/SuperTagger/Utils/Tagging_bert_model.py b/SuperTagger/Utils/Tagging_bert_model.py
index 83ef46f6791561549123c5901ea65f691bc07283..5d008ffb6602f30e9238ae06aa0ed3de80280ea6 100644
--- a/SuperTagger/Utils/Tagging_bert_model.py
+++ b/SuperTagger/Utils/Tagging_bert_model.py
@@ -24,9 +24,10 @@ class Tagging_bert_model(Module):
 
         output = self.bert(
             input_ids=b_input_ids, attention_mask=b_input_mask, labels=labels)
-        loss, logits, hidden = output[:3]
-        return loss, logits, hidden[0]
+        result = {'loss': output[0], 'logit': output[1], 'word_embeding': output[2][0], 'last_hidden_state': output[2][1]}
+
+        return result
 
     def predict(self, batch):
         b_input_ids = batch[0]
@@ -35,4 +36,6 @@ class Tagging_bert_model(Module):
         output = self.bert(
             input_ids=b_input_ids, attention_mask=b_input_mask)
 
-        return output[0], output[1][0]
+        result = {'logit': output[0], 'word_embeding': output[1][0], 'last_hidden_state': output[1][1]}
+
+        return result
diff --git a/predict.py b/predict.py
index 5834d013fbb51afd1da8ba3680a7647bb4e54725..7d7191976e80e48b5c35a428231658887a7e1420 100644
--- a/predict.py
+++ b/predict.py
@@ -13,12 +13,12 @@ tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,
 
 #### MODEL ####
 tagger = SuperTagger()
 
-model = "models/flaubert_super_98%_V2_50e.pt"
+model = "models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt"
 
 tagger.load_weights(model)
 
 #### TEST ####
-_, pred_convert, _ = tagger.predict(a_s)
+_, pred_convert = tagger.predict(a_s)
 
 print("Model : ", model)
diff --git a/requirements.txt b/requirements.txt
index 41ce3b5b133a0d8ce94376ac0e32283b3c9f2417..34e4e866d8569ca33a964edbf27b80e80f75644e 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/train.py b/train.py
index 9ac77d44ad546ba5ef61d2f120f7bcffbdad79b2..8c2290a9cf724d9c82c3f29128679cb5e9c356db 100644
--- a/train.py
+++ b/train.py
@@ -7,7 +7,7 @@ file_path = 'Datasets/m2_dataset.csv'
 df = read_csv_pgbar(file_path,100)
 
 texts = df['X'].tolist()
-tags = df['Y1'].tolist()
+tags = df['Z'].tolist()
 
 test_s = texts[:4]
 tags_s = tags[:4]
@@ -15,7 +15,7 @@ tags_s = tags[:4]
 texts = texts[4:]
 tags = tags[4:]
 
-index_to_super = load_obj('Datasets/index_to_super')
+index_to_super = load_obj('Datasets/index_to_super')
 
 #### MODEL ####
 tagger = SuperTagger()
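Note on the calling convention after this change: Tagging_bert_model now returns dicts instead of tuples, SuperTagger.predict returns two values (logits and decoded tags) instead of three, and SuperTagger.forward hands the model's output dict through unchanged. Below is a minimal sketch of the new usage, assuming the import path matches the repo layout and the checkpoint path from predict.py; the input sentence and the tensors fed to forward are illustrative placeholders, not values from this diff.

import torch
from SuperTagger.SuperTagger import SuperTagger  # import path assumed from repo layout

tagger = SuperTagger()
# The checkpoint now lives inside a directory of the same name (cf. predict.py).
tagger.load_weights("models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt")

# predict() now returns a 2-tuple: raw logits and the decoded tag sequences.
logits, tags = tagger.predict(["le chat dort"])  # hypothetical input sentence

# forward() (used by the linker) returns the model's output dict as-is,
# so downstream code indexes by key instead of unpacking a positional tuple.
ids = torch.zeros((1, 10), dtype=torch.long)   # placeholder token ids
mask = torch.ones((1, 10), dtype=torch.long)   # placeholder attention mask
output = tagger.forward(ids, mask)
embeddings = output['word_embeding']           # key spelling as introduced in the diff
hidden = output['last_hidden_state']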