-
Caroline DE POURTALES authored
main.py 786 B
import torch.nn.functional as F
import torch
from Configuration import Configuration
from Linker import *
from Supertagger import *
# Upper bound on atoms per sentence, read from the dataset configuration.
# NOTE(review): not used in this script; kept at module level in case other
# code imports it from here.
max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])

# Categories tagger: build the supertagger and restore its trained weights.
# NOTE(review): weight paths are relative to the current working directory.
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")

# Axiom linker: built on top of the supertagger, with its own trained weights.
linker = Linker(supertagger)
linker.load_weights("models/linker.pt")

# Predict categories and axiom links for this sentence.
sentence = ["le chat est noir"]
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentence)

# Pure inference: disable autograd so no computation graph is recorded.
with torch.no_grad():
    # NOTE(review): `foward` (sic) - confirm this matches the method actually
    # defined on SuperTagger; it may be a misspelling of `forward`.
    logits, sentence_embedding = supertagger.foward(sents_tokenized, sents_mask)
    # softmax is monotonic, so argmax over the raw logits selects the same
    # category index without an extra normalization pass. dim=2 is assumed to
    # be the category axis, i.e. logits shaped (batch, seq_len, n_categories)
    # - TODO confirm against SuperTagger.
    categories = torch.argmax(logits, dim=2)
    axiom_links = linker.predict(categories, sentence_embedding, sents_mask)