diff --git a/.gitignore b/.gitignore index 844b4af0825b9c4c45bf5e3dea1189dc919711ce..371503a80d900dde316cbdc92efdf8fa6ce1812f 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ models *.pkl good_models/model_check.pt main.py +*.pt diff --git a/SuperTagger/Utils/Tagging_bert_model.py b/SuperTagger/Utils/Tagging_bert_model.py index 3afad6caaf7e9f99894410f15e7dbb0f65a0d105..aaa1c771a4dabed21efbee7eb5a167fc5e5a39cc 100644 --- a/SuperTagger/Utils/Tagging_bert_model.py +++ b/SuperTagger/Utils/Tagging_bert_model.py @@ -29,7 +29,7 @@ class Tagging_bert_model(Module): input_ids=b_input_ids, attention_mask=b_input_mask, labels=labels) loss, logits, hidden = output[:3] - return loss, logits, hidden + return loss, logits, hidden[0] def predict(self, batch): b_input_ids = batch[0] @@ -38,4 +38,4 @@ class Tagging_bert_model(Module): output = self.bert( input_ids=b_input_ids, attention_mask=b_input_mask) - return torch.argmax(output[0], dim=2), output[1] + return torch.argmax(output[0], dim=2), output[1][0] diff --git a/good_models/camenbert_classique_80%.pt b/good_models/camenbert_classique_80%.pt deleted file mode 100644 index c7e569fba3fcfd837dfd4bd1a0ad0eec52a1ed4a..0000000000000000000000000000000000000000 Binary files a/good_models/camenbert_classique_80%.pt and /dev/null differ