# predict.py: tag one sample sentence with a trained SuperTagger and compare
# the predicted supertags against hand-written gold tags.
import sys

from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.helpers import categorical_accuracy_str

#### DATA ####
# Sample sentence: tokens are separated by single spaces (note the trailing
# space after the final ".").
a_s = (
    '( 1 ) parmi les huit " partants " acquis ou potentiels , '
    'MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF . '
)

# Gold supertags. The outer list holds one entry per sentence (here a single
# sentence). The inner list has 26 tags, the same count as the 26
# whitespace-separated tokens of a_s.
tags_s = [
    [
        "let",
        "dr(0,s,s)",
        "let",
        "dr(0,dr(0,s,s),np)",
        "dr(0,np,n)",
        "dr(0,n,n)",
        "let",
        "n",
        "let",
        "dl(0,n,n)",
        "dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))",
        "dl(0,n,n)",
        "let",
        "dr(0,np,np)",
        "np",
        "dr(0,dl(0,np,np),np)",
        "np",
        "dr(0,dl(0,np,np),np)",
        "np",
        "dr(0,dl(0,np,s),dl(0,np,s))",
        "dr(0,dl(0,np,s),np)",
        "dl(1,s,s)",
        "np",
        "dr(0,dl(0,np,np),n)",
        "n",
        "dl(0,s,txt)",
    ]
]

#### MODEL ####
# Checkpoint evaluated when no path is given on the command line. It is a
# relative path, so it only resolves when the script is launched from a
# directory that contains "models/".
DEFAULT_MODEL = "models/flaubert_super_98%_V2_50e/flaubert_super_98%_V2_50e.pt"

# An optional first command-line argument overrides the checkpoint, so other
# trained models can be scored without editing this file:
#     python predict.py path/to/other_model.pt
model = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL

tagger = SuperTagger()
tagger.load_weights(model)

#### TEST ####
# predict() returns a pair; only its second element is used here. Both
# pred_convert and tags_s are indexed with [0] below, i.e. each is treated as
# a batch holding the single sample sentence.
_, pred_convert = tagger.predict(a_s)


def _report(label, value, blank_after=False):
    """Print one labelled report line, optionally followed by an empty line."""
    print(label, value)
    if blank_after:
        print()


_report("Model : ", model)

# Lengths first, so token/tag count mismatches are visible at a glance.
_report("\tLen Text           : ", len(a_s.split()))
_report("\tLen tags           : ", len(tags_s[0]))
_report("\tLen pred_convert   : ", len(pred_convert[0]), blank_after=True)

# Then the raw sentence, the gold tags and the predicted tags.
_report("\tText               : ", a_s, blank_after=True)
_report("\tTags               : ", tags_s[0], blank_after=True)
_report("\tPred_convert       : ", pred_convert[0], blank_after=True)

# Score computed by categorical_accuracy_str from predictions vs. gold tags.
_report("\tScore              :", categorical_accuracy_str(pred_convert, tags_s))