From a7008a57b0f2b5faa3706086e509295c2aaec1e8 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Wed, 18 May 2022 10:18:40 +0200
Subject: [PATCH] Add predict script and update train script

---
 SuperTagger/SuperTagger.py                 |  9 +++---
 SuperTagger/Utils/{utils.py => helpers.py} |  0
 predict.py                                 | 35 ++++++++++++++++++++++
 train.py                                   |  9 +++---
 4 files changed, 45 insertions(+), 8 deletions(-)
 rename SuperTagger/Utils/{utils.py => helpers.py} (100%)
 create mode 100644 predict.py

diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index b49485e..8e211f7 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -105,7 +105,8 @@ class SuperTagger:
         """
         self.trainable = False
 
-        print("#" * 15)
+        print("#" * 20)
+        print("\n Loading...")
         try:
             params = torch.load(model_file, map_location=self.device)
             args = params['args']
@@ -127,7 +128,7 @@ class SuperTagger:
         except Exception as e:
             print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
             raise e
-        print("#" * 15)
+        print("#" * 20)
 
         self.model_load = True
         self.trainable = True
@@ -224,7 +225,7 @@ class SuperTagger:
 
         for epoch_i in range(0, epochs):
             print("")
-            print('======== Epoch {:} / {:} ========'.format(epoch_i, epochs))
+            print('======== Epoch {:} / {:} ========'.format(epoch_i+1, epochs))
             print('Training...')
 
             # Train
@@ -310,7 +311,7 @@ class SuperTagger:
                 targets = batch[2].to(self.device)
                 self.optimizer.zero_grad()
 
-                loss, logit = self.model((b_sents_tokenized, b_sents_mask, targets))
+                loss, logit, _ = self.model((b_sents_tokenized, b_sents_mask, targets))
 
                 predictions = torch.argmax(logit, dim=2).detach().cpu().numpy()
                 label_ids = targets.cpu().numpy()
diff --git a/SuperTagger/Utils/utils.py b/SuperTagger/Utils/helpers.py
similarity index 100%
rename from SuperTagger/Utils/utils.py
rename to SuperTagger/Utils/helpers.py
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..5834d01
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,35 @@
+from SuperTagger.SuperTagger import SuperTagger
+from SuperTagger.Utils.helpers import categorical_accuracy_str
+# Script: load saved weights, tag one sample sentence, and print the prediction next to the given tags.
+#### DATA ####
+# a_s is the sentence to tag; tags_s[0] is the tag list the prediction is scored against below.
+a_s = "( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres " \
+      "du PCF . "
+tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)',
+           'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)',
+           'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np',
+           'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']]
+
+#### MODEL ####
+tagger = SuperTagger()
+# Path of the weights file handed to load_weights; it is also echoed in the report below.
+model = "models/flaubert_super_98%_V2_50e.pt"
+
+tagger.load_weights(model)
+
+#### TEST ####
+_, pred_convert, _ = tagger.predict(a_s)  # predict returns three values; only the second is used here
+
+print("Model : ", model)
+
+print("\tLen Text           : ", len(a_s.split()))
+print("\tLen tags           : ", len(tags_s[0]))
+print("\tLen pred_convert   : ", len(pred_convert[0]))
+print()
+print("\tText               : ", a_s)
+print()
+print("\tTags               : ", tags_s[0])
+print()
+print("\tPred_convert       : ", pred_convert[0])
+print()
+print("\tScore              :", categorical_accuracy_str(pred_convert, tags_s))
diff --git a/train.py b/train.py
index ddcda4c..d3cb4ef 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,11 @@
 from SuperTagger.SuperTagger import SuperTagger
-from SuperTagger.Utils.utils import read_csv_pgbar, load_obj
+from SuperTagger.Utils.helpers import read_csv_pgbar, load_obj
 
+#### DATA ####
 file_path = 'Datasets/m2_dataset.csv'
 
-
 df = read_csv_pgbar(file_path,1000)
 
-
 texts = df['X'].tolist()
 tags = df['Y1'].tolist()
 
@@ -16,9 +15,9 @@ tags_s = tags[:4]
 texts = texts[4:]
 tags = tags[4:]
 
-
 index_to_super = load_obj('Datasets/index_to_pos1')
 
+#### MODEL ####
 tagger = SuperTagger()
 
 tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
@@ -26,6 +25,8 @@ tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
 
 tagger.train(texts,tags,batch_size=16,validation_rate=0.1,tensorboard=True,checkpoint=True)
 
+
+#### TEST ####
 pred = tagger.predict(test_s)
 
 print(test_s)
-- 
GitLab