# predict.py (1.27 KiB) - authored by Julien Rabault
from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.helpers import categorical_accuracy_str
# ---------------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------------
# A single sentence whose tokens are already separated by spaces (26 tokens
# when split on whitespace), together with the tag sequence the prediction is
# scored against (26 tags, wrapped in an outer list of length one).
a_s = (
    "( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres "
    "du PCF . "
)

reference_tags = [
    'let',
    'dr(0,s,s)',
    'let',
    'dr(0,dr(0,s,s),np)',
    'dr(0,np,n)',
    'dr(0,n,n)',
    'let',
    'n',
    'let',
    'dl(0,n,n)',
    'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))',
    'dl(0,n,n)',
    'let',
    'dr(0,np,np)',
    'np',
    'dr(0,dl(0,np,np),np)',
    'np',
    'dr(0,dl(0,np,np),np)',
    'np',
    'dr(0,dl(0,np,s),dl(0,np,s))',
    'dr(0,dl(0,np,s),np)',
    'dl(1,s,s)',
    'np',
    'dr(0,dl(0,np,np),n)',
    'n',
    'dl(0,s,txt)',
]
tags_s = [reference_tags]

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Path of the weights file handed to SuperTagger.load_weights below.
model = "models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt"
tagger = SuperTagger()
tagger.load_weights(model)

# ---------------------------------------------------------------------------
# Prediction and report
# ---------------------------------------------------------------------------
# predict() returns a pair; only its second element is used here.
_, pred_convert = tagger.predict(a_s)

# Header: which model was used, then the three lengths side by side so a
# token/tag count mismatch is visible at a glance.
print("Model : ", model)
token_count = len(a_s.split())
print("\tLen Text : ", token_count)
print("\tLen tags : ", len(reference_tags))
predicted_tags = pred_convert[0]
print("\tLen pred_convert : ", len(predicted_tags))

# Each detail line is preceded by an empty line.
for label, value in (
    ("\tText : ", a_s),
    ("\tTags : ", reference_tags),
    ("\tPred_convert : ", predicted_tags),
):
    print()
    print(label, value)

# The score is computed last, after the final blank separator line.
print()
score = categorical_accuracy_str(pred_convert, tags_s)
print("\tScore :", score)