Skip to content
Snippets Groups Projects
Commit 77ac0b86 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

brouillon main predict

parent 8d109c5a
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -192,6 +192,7 @@ class Linker(Module):
self.scheduler.step()
avg_train_loss = epoch_loss / len(training_dataloader)
print("Average Loss on train dataset : ", avg_train_loss)
if checkpoint:
checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
......@@ -200,6 +201,8 @@ class Linker(Module):
if validate:
with torch.no_grad():
accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
print("Average Loss on test dataset : ", average_test_loss)
print("Average Accuracy on test dataset : ", accuracy)
return accuracy, avg_train_loss
......
import torch.nn.functional as F
import torch
from Configuration import Configuration
from Linker.Linker import Linker
from Supertagger.SuperTagger.SuperTagger import SuperTagger
......@@ -7,15 +7,17 @@ from Supertagger.SuperTagger.SuperTagger import SuperTagger
max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
# categories tagger
tagger = SuperTagger()
tagger.load_weights("models/model_check.pt")
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")
# axiom linker
linker = Linker()
linker = Linker(supertagger)
linker.load_weights("models/linker.pt")
# predict categories and links for this sentence
sentence = [[]]
categories, sentence_embedding = tagger.predict(sentence)
sentence = ["le chat est noir"]
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentence)
logits, sentence_embedding = supertagger.foward(sents_tokenized, sents_mask)
categories = torch.argmax(F.softmax(logits, dim=2), dim=2)
axiom_links = linker.predict(categories, sentence_embedding)
axiom_links = linker.predict(categories, sentence_embedding, sents_mask)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment