# Skip to content
# Snippets Groups Projects
# predict_links.py 1.15 KiB
from NeuralProofNet.NeuralProofNet import NeuralProofNet
from postprocessing import draw_sentence_output

# region data
# Example input: a list holding one pre-tokenised French sentence. Tokens are
# separated by single spaces (the __main__ block splits on " ").
a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."]
# Categories for use with `linker.predict_with_categories(a_s, tags_s)`.
# a_s[0] has 26 tokens and this list has 26 entries; each entry is annotated
# with the token it lines up with positionally.
tags_s = [
    'let',                                          # (
    'dr(0,s,s)',                                    # 1
    'let',                                          # )
    'dr(0,dr(0,s,s),np)',                           # parmi
    'dr(0,np,n)',                                   # les
    'dr(0,n,n)',                                    # huit
    'let',                                          # "
    'n',                                            # partants
    'let',                                          # "
    'dl(0,n,n)',                                    # acquis
    'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))',    # ou
    'dl(0,n,n)',                                    # potentiels
    'let',                                          # ,
    'dr(0,np,np)',                                  # MM.
    'np',                                           # Lacombe
    'dr(0,dl(0,np,np),np)',                         # ,
    'np',                                           # Koehler
    'dr(0,dl(0,np,np),np)',                         # et
    'np',                                           # Laroze
    'dr(0,dl(0,np,s),dl(0,np,s))',                  # ne
    'dr(0,dl(0,np,s),np)',                          # sont
    'dl(1,s,s)',                                    # pas
    'np',                                           # membres
    'dr(0,dl(0,np,np),n)',                          # du
    'n',                                            # PCF
    'dl(0,s,txt)',                                  # .
]
# endregion


# region model
# Checkpoint locations, given as relative paths (presumably resolved against
# the current working directory, so run the script from the project root).
model_tagger = "models/flaubert_super_98_V2_50e.pt"  # handed to the NeuralProofNet constructor
model = "Output/linker.pt"  # weights loaded into neuralproofnet.linker
# NOTE(review): these statements run at import time, outside the __main__
# guard at the bottom of the file, so merely importing this module builds the
# network and loads the linker weights.
neuralproofnet = NeuralProofNet(model_tagger)
neuralproofnet.linker.load_weights(model)
# endregion


# region prediction
linker = neuralproofnet.linker
# Predict from the raw sentences only: returns both categories and links.
categories, links = linker.predict_without_categories(a_s)
# Alternative that supplies the categories in tags_s. If enabled after the
# line above, it overwrites `links` but keeps the predicted `categories`:
#     links = linker.predict_with_categories(a_s, tags_s)
# endregion

if __name__ == '__main__':
    # Render the prediction for the first (and only) example sentence.
    idx = 0
    # `links` is indexed with the sentence on axis 1 and converted via
    # `.numpy()`, so it is a torch tensor (the checkpoints are `.pt` files).
    # `Tensor.numpy()` raises if the tensor lives on a GPU or requires grad;
    # `.detach().cpu()` is a no-op for a plain CPU tensor, so this works
    # whichever device the linker ran on.
    sentence_links = links[:, idx, :].detach().cpu().numpy()
    draw_sentence_output(a_s[idx].split(" "), categories[idx], sentence_links)