diff --git a/Linker/Linker.py b/Linker/Linker.py index 4d4eaaadee8560ef15667203cafdb3d8e9f0598f..7f7462dfa9bcc615c9234e9bd474adfe47f99e02 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -192,6 +192,7 @@ class Linker(Module): self.scheduler.step() avg_train_loss = epoch_loss / len(training_dataloader) + print("Average Loss on train dataset : ", avg_train_loss) if checkpoint: checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) @@ -200,6 +201,8 @@ class Linker(Module): if validate: with torch.no_grad(): accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss) + print("Average Loss on test dataset : ", average_test_loss) + print("Average Accuracy on test dataset : ", accuracy) return accuracy, avg_train_loss diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/main.py b/main.py index 723f0f188b2e9a969ecffb267967e05cee16390a..55e8c527e8a70afb39389339c74bab66a3bbe4e1 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ import torch.nn.functional as F - +import torch from Configuration import Configuration from Linker.Linker import Linker from Supertagger.SuperTagger.SuperTagger import SuperTagger @@ -7,15 +7,17 @@ from Supertagger.SuperTagger.SuperTagger import SuperTagger max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) # categories tagger -tagger = SuperTagger() -tagger.load_weights("models/model_check.pt") +supertagger = SuperTagger() +supertagger.load_weights("models/model_supertagger.pt") # axiom linker -linker = Linker() +linker = Linker(supertagger) linker.load_weights("models/linker.pt") # predict categories and links for this sentence -sentence = [[]] -categories, sentence_embedding = tagger.predict(sentence) +sentence = ["le chat est noir"] +sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentence) +logits, sentence_embedding = supertagger.foward(sents_tokenized, sents_mask) +categories = torch.argmax(F.softmax(logits, dim=2), dim=2) -axiom_links = linker.predict(categories, sentence_embedding) +axiom_links = linker.predict(categories, sentence_embedding, sents_mask)