From a7008a57b0f2b5faa3706086e509295c2aaec1e8 Mon Sep 17 00:00:00 2001 From: PNRIA - Julien <julien.rabault@irit.fr> Date: Wed, 18 May 2022 10:18:40 +0200 Subject: [PATCH] Add predict and train file --- SuperTagger/SuperTagger.py | 9 +++--- SuperTagger/Utils/{utils.py => helpers.py} | 0 predict.py | 35 ++++++++++++++++++++++ train.py | 9 +++--- 4 files changed, 45 insertions(+), 8 deletions(-) rename SuperTagger/Utils/{utils.py => helpers.py} (100%) create mode 100644 predict.py diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index b49485e..8e211f7 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -105,7 +105,8 @@ class SuperTagger: """ self.trainable = False - print("#" * 15) + print("#" * 20) + print("\n Loading...") try: params = torch.load(model_file, map_location=self.device) args = params['args'] @@ -127,7 +128,7 @@ class SuperTagger: except Exception as e: print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr) raise e - print("#" * 15) + print("#" * 20) self.model_load = True self.trainable = True @@ -224,7 +225,7 @@ class SuperTagger: for epoch_i in range(0, epochs): print("") - print('======== Epoch {:} / {:} ========'.format(epoch_i, epochs)) + print('======== Epoch {:} / {:} ========'.format(epoch_i+1, epochs)) print('Training...') # Train @@ -310,7 +311,7 @@ class SuperTagger: targets = batch[2].to(self.device) self.optimizer.zero_grad() - loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets)) + loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets)) predictions = torch.argmax(logit, dim=2).detach().cpu().numpy() label_ids = targets.cpu().numpy() diff --git a/SuperTagger/Utils/utils.py b/SuperTagger/Utils/helpers.py similarity index 100% rename from SuperTagger/Utils/utils.py rename to SuperTagger/Utils/helpers.py diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..5834d01 --- /dev/null +++ b/predict.py @@ -0,0 +1,35 @@ +from SuperTagger.SuperTagger import SuperTagger +from SuperTagger.Utils.helpers import categorical_accuracy_str + +#### DATA #### + +a_s = "( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres " \ + "du PCF . " +tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)', + 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)', + 'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np', + 'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']] + +#### MODEL #### +tagger = SuperTagger() + +model = "models/flaubert_super_98%_V2_50e.pt" + +tagger.load_weights(model) + +#### TEST #### +_, pred_convert, _ = tagger.predict(a_s) + +print("Model : ", model) + +print("\tLen Text : ", len(a_s.split())) +print("\tLen tags : ", len(tags_s[0])) +print("\tLen pred_convert : ", len(pred_convert[0])) +print() +print("\tText : ", a_s) +print() +print("\tTags : ", tags_s[0]) +print() +print("\tPred_convert : ", pred_convert[0]) +print() +print("\tScore :", categorical_accuracy_str(pred_convert, tags_s)) diff --git a/train.py b/train.py index ddcda4c..d3cb4ef 100644 --- a/train.py +++ b/train.py @@ -1,12 +1,11 @@ from SuperTagger.SuperTagger import SuperTagger -from SuperTagger.Utils.utils import read_csv_pgbar, load_obj +from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj +#### DATA #### file_path = 'Datasets/m2_dataset.csv' - df = read_csv_pgbar(file_path,1000) - texts = df['X'].tolist() tags = df['Y1'].tolist() @@ -16,9 +15,9 @@ tags_s = tags[:4] texts = texts[4:] tags = tags[4:] - index_to_super = load_obj('Datasets/index_to_pos1') +#### MODEL #### tagger = SuperTagger() tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) @@ -26,6 +25,8 @@ tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) tagger.train(texts,tags,batch_size=16,validation_rate=0.1,tensorboard=True,checkpoint=True) + +#### TEST #### pred = tagger.predict(test_s) print(test_s) -- GitLab