From 72199af19e38639caaf8f3730192edea450f6852 Mon Sep 17 00:00:00 2001
From: PNRIA - Julien <julien.rabault@irit.fr>
Date: Wed, 11 May 2022 15:47:04 +0200
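Subject: [PATCH] V0.9: pad from index 0, single-sample debug run, resume from checkpoint

- SuperTagger/Utils/SymbolTokenizer.py: pad_sequence now copies the input
  into the padded list starting at index 0 instead of index 1.
- main.py: lower the second argument of read_csv_pgbar from 1000 to 100,
  restrict texts/tags to the slice [98:99], drop the print of len(tags),
  and enable the prints of texts, tags, pred and pred_convert.
- train.py: comment out create_new_model(..., 'camembert-base', ...) and
  load weights from "models/model_check.pt" before calling train.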

---
 SuperTagger/Utils/SymbolTokenizer.py |  2 +-
 main.py                              | 21 ++++++++++-----------
 train.py                             |  3 ++-
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/SuperTagger/Utils/SymbolTokenizer.py b/SuperTagger/Utils/SymbolTokenizer.py
index 62228c9..eaf5ae4 100644
--- a/SuperTagger/Utils/SymbolTokenizer.py
+++ b/SuperTagger/Utils/SymbolTokenizer.py
@@ -37,7 +37,7 @@ class SymbolTokenizer():
 
 def pad_sequence(sequences, max_len=400):
     padded = [0] * max_len
-    padded[1:len(sequences)+1] = sequences
+    padded[:len(sequences)] = sequences
     return padded
 
 
diff --git a/main.py b/main.py
index bc66aec..f8833e5 100644
--- a/main.py
+++ b/main.py
@@ -17,13 +17,12 @@ def load_obj(name):
 
 file_path = 'Datasets/m2_dataset_V2.csv'
 
-df = read_csv_pgbar(file_path,1000)
+df = read_csv_pgbar(file_path,100)
 
 texts = df['X'].tolist()
 tags = df['Z'].tolist()
-# texts = texts[12650:12800]
-# tags = tags[12650:12800]
-print(len(tags))
+texts = texts[98:99]
+tags = tags[98:99]
 
 tagger = SuperTagger()
 
@@ -47,13 +46,13 @@ tagger.load_weights("models/model_check.pt")
 
 pred, pred_convert = tagger.predict(texts)
 #
-# print(texts)
-# print()
-# print(tags)
-# print()
-# print(pred)
-# print()
-# print(pred_convert)
+print(texts)
+print()
+print(tags)
+print()
+print(pred)
+print()
+print(pred_convert)
 
 
 def categorical_accuracy(preds, truth):
diff --git a/train.py b/train.py
index daeef4c..551368a 100644
--- a/train.py
+++ b/train.py
@@ -29,7 +29,8 @@ super_to_index = {v: int(k) for k, v in index_to_super.items()}
 
 tagger = SuperTagger()
 
-tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
+# tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
+tagger.load_weights("models/model_check.pt")
 
 tagger.train(texts,tags,validation_rate=0.1,tensorboard=True,checkpoint=True)
 
-- 
GitLab