Skip to content
Snippets Groups Projects
Commit a7008a57 authored by Julien Rabault's avatar Julien Rabault
Browse files

Add predict and train file

parent eeb4774c
No related branches found
No related tags found
No related merge requests found
......@@ -105,7 +105,8 @@ class SuperTagger:
"""
self.trainable = False
print("#" * 15)
print("#" * 20)
print("\n Loading...")
try:
params = torch.load(model_file, map_location=self.device)
args = params['args']
......@@ -127,7 +128,7 @@ class SuperTagger:
except Exception as e:
print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
raise e
print("#" * 15)
print("#" * 20)
self.model_load = True
self.trainable = True
......@@ -224,7 +225,7 @@ class SuperTagger:
for epoch_i in range(0, epochs):
print("")
print('======== Epoch {:} / {:} ========'.format(epoch_i, epochs))
print('======== Epoch {:} / {:} ========'.format(epoch_i+1, epochs))
print('Training...')
# Train
......@@ -310,7 +311,7 @@ class SuperTagger:
targets = batch[2].to(self.device)
self.optimizer.zero_grad()
loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets))
predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
label_ids = targets.cpu().numpy()
......
File moved
from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.helpers import categorical_accuracy_str
#### DATA ####
a_s = "( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres " \
"du PCF . "
tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)',
'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)',
'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np',
'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']]
#### MODEL ####
tagger = SuperTagger()
model = "models/flaubert_super_98%_V2_50e.pt"
tagger.load_weights(model)
#### TEST ####
_, pred_convert, _ = tagger.predict(a_s)
print("Model : ", model)
print("\tLen Text : ", len(a_s.split()))
print("\tLen tags : ", len(tags_s[0]))
print("\tLen pred_convert : ", len(pred_convert[0]))
print()
print("\tText : ", a_s)
print()
print("\tTags : ", tags_s[0])
print()
print("\tPred_convert : ", pred_convert[0])
print()
print("\tScore :", categorical_accuracy_str(pred_convert, tags_s))
from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.utils import read_csv_pgbar, load_obj
from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj
#### DATA ####
file_path = 'Datasets/m2_dataset.csv'
df = read_csv_pgbar(file_path,1000)
texts = df['X'].tolist()
tags = df['Y1'].tolist()
......@@ -16,9 +15,9 @@ tags_s = tags[:4]
texts = texts[4:]
tags = tags[4:]
index_to_super = load_obj('Datasets/index_to_pos1')
#### MODEL ####
tagger = SuperTagger()
tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
......@@ -26,6 +25,8 @@ tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
tagger.train(texts,tags,batch_size=16,validation_rate=0.1,tensorboard=True,checkpoint=True)
#### TEST ####
pred = tagger.predict(test_s)
print(test_s)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment