From 72199af19e38639caaf8f3730192edea450f6852 Mon Sep 17 00:00:00 2001 From: PNRIA - Julien <julien.rabault@irit.fr> Date: Wed, 11 May 2022 15:47:04 +0200 Subject: [PATCH] V0.9 --- SuperTagger/Utils/SymbolTokenizer.py | 2 +- main.py | 21 ++++++++++----------- train.py | 3 ++- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/SuperTagger/Utils/SymbolTokenizer.py b/SuperTagger/Utils/SymbolTokenizer.py index 62228c9..eaf5ae4 100644 --- a/SuperTagger/Utils/SymbolTokenizer.py +++ b/SuperTagger/Utils/SymbolTokenizer.py @@ -37,7 +37,7 @@ class SymbolTokenizer(): def pad_sequence(sequences, max_len=400): padded = [0] * max_len - padded[1:len(sequences)+1] = sequences + padded[:len(sequences)] = sequences return padded diff --git a/main.py b/main.py index bc66aec..f8833e5 100644 --- a/main.py +++ b/main.py @@ -17,13 +17,12 @@ def load_obj(name): file_path = 'Datasets/m2_dataset_V2.csv' -df = read_csv_pgbar(file_path,1000) +df = read_csv_pgbar(file_path,100) texts = df['X'].tolist() tags = df['Z'].tolist() -# texts = texts[12650:12800] -# tags = tags[12650:12800] -print(len(tags)) +texts = texts[98:99] +tags = tags[98:99] tagger = SuperTagger() @@ -47,13 +46,13 @@ tagger.load_weights("models/model_check.pt") pred, pred_convert = tagger.predict(texts) # -# print(texts) -# print() -# print(tags) -# print() -# print(pred) -# print() -# print(pred_convert) +print(texts) +print() +print(tags) +print() +print(pred) +print() +print(pred_convert) def categorical_accuracy(preds, truth): diff --git a/train.py b/train.py index daeef4c..551368a 100644 --- a/train.py +++ b/train.py @@ -29,7 +29,8 @@ super_to_index = {v: int(k) for k, v in index_to_super.items()} tagger = SuperTagger() -tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) +# tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) +tagger.load_weights("models/model_check.pt") tagger.train(texts,tags,validation_rate=0.1,tensorboard=True,checkpoint=True) -- GitLab