Skip to content
Snippets Groups Projects
Commit 77ac0b86 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

brouillon main predict

parent 8d109c5a
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -192,6 +192,7 @@ class Linker(Module): ...@@ -192,6 +192,7 @@ class Linker(Module):
self.scheduler.step() self.scheduler.step()
avg_train_loss = epoch_loss / len(training_dataloader) avg_train_loss = epoch_loss / len(training_dataloader)
print("Average Loss on train dataset : ", avg_train_loss)
if checkpoint: if checkpoint:
checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
...@@ -200,6 +201,8 @@ class Linker(Module): ...@@ -200,6 +201,8 @@ class Linker(Module):
if validate: if validate:
with torch.no_grad(): with torch.no_grad():
accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss) accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
print("Average Loss on test dataset : ", average_test_loss)
print("Average Accuracy on test dataset : ", accuracy)
return accuracy, avg_train_loss return accuracy, avg_train_loss
......
import torch.nn.functional as F import torch.nn.functional as F
import torch
from Configuration import Configuration from Configuration import Configuration
from Linker.Linker import Linker from Linker.Linker import Linker
from Supertagger.SuperTagger.SuperTagger import SuperTagger from Supertagger.SuperTagger.SuperTagger import SuperTagger
...@@ -7,15 +7,17 @@ from Supertagger.SuperTagger.SuperTagger import SuperTagger ...@@ -7,15 +7,17 @@ from Supertagger.SuperTagger.SuperTagger import SuperTagger
max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
# categories tagger # categories tagger
tagger = SuperTagger() supertagger = SuperTagger()
tagger.load_weights("models/model_check.pt") supertagger.load_weights("models/model_supertagger.pt")
# axiom linker # axiom linker
linker = Linker() linker = Linker(supertagger)
linker.load_weights("models/linker.pt") linker.load_weights("models/linker.pt")
# predict categories and links for this sentence # predict categories and links for this sentence
sentence = [[]] sentence = ["le chat est noir"]
categories, sentence_embedding = tagger.predict(sentence) sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentence)
logits, sentence_embedding = supertagger.foward(sents_tokenized, sents_mask)
categories = torch.argmax(F.softmax(logits, dim=2), dim=2)
axiom_links = linker.predict(categories, sentence_embedding) axiom_links = linker.predict(categories, sentence_embedding, sents_mask)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment