From 77ac0b8697ea67c236efdfb11cea70b5fa0422cf Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Tue, 17 May 2022 15:28:25 +0200 Subject: [PATCH] brouillon main predict --- Linker/Linker.py | 3 +++ SuperTagger/__init__.py | 0 main.py | 16 +++++++++------- 3 files changed, 12 insertions(+), 7 deletions(-) delete mode 100644 SuperTagger/__init__.py diff --git a/Linker/Linker.py b/Linker/Linker.py index 4d4eaaa..7f7462d 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -192,6 +192,7 @@ class Linker(Module): self.scheduler.step() avg_train_loss = epoch_loss / len(training_dataloader) + print("Average Loss on train dataset : ", avg_train_loss) if checkpoint: checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) @@ -200,6 +201,8 @@ class Linker(Module): if validate: with torch.no_grad(): accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss) + print("Average Loss on test dataset : ", average_test_loss) + print("Average Accuracy on test dataset : ", accuracy) return accuracy, avg_train_loss diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/main.py b/main.py index 723f0f1..55e8c52 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ import torch.nn.functional as F - +import torch from Configuration import Configuration from Linker.Linker import Linker from Supertagger.SuperTagger.SuperTagger import SuperTagger @@ -7,15 +7,17 @@ from Supertagger.SuperTagger.SuperTagger import SuperTagger max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) # categories tagger -tagger = SuperTagger() -tagger.load_weights("models/model_check.pt") +supertagger = SuperTagger() +supertagger.load_weights("models/model_supertagger.pt") # axiom linker -linker = Linker() +linker = Linker(supertagger) linker.load_weights("models/linker.pt") # predict categories and links for this sentence -sentence = [[]] -categories, sentence_embedding = tagger.predict(sentence) +sentence = ["le chat est noir"] +sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentence) +logits, sentence_embedding = supertagger.foward(sents_tokenized, sents_mask) +categories = torch.argmax(F.softmax(logits, dim=2), dim=2) -axiom_links = linker.predict(categories, sentence_embedding) +axiom_links = linker.predict(categories, sentence_embedding, sents_mask) -- GitLab