diff --git a/SuperTagger/Utils/SymbolTokenizer.py b/SuperTagger/Utils/SymbolTokenizer.py index 62228c94621372a0ed19e8e8d8ebf41516df5552..eaf5ae4ba7526eaaa3c1f52d0caa8fc7439646c5 100644 --- a/SuperTagger/Utils/SymbolTokenizer.py +++ b/SuperTagger/Utils/SymbolTokenizer.py @@ -37,7 +37,7 @@ class SymbolTokenizer(): def pad_sequence(sequences, max_len=400): padded = [0] * max_len - padded[1:len(sequences)+1] = sequences + padded[:len(sequences)] = sequences return padded diff --git a/main.py b/main.py index bc66aec3e7b48ee90d5ce83b52955f2004ce363b..f8833e517a5f8fdee2691a544a94e4b352f49134 100644 --- a/main.py +++ b/main.py @@ -17,13 +17,12 @@ def load_obj(name): file_path = 'Datasets/m2_dataset_V2.csv' -df = read_csv_pgbar(file_path,1000) +df = read_csv_pgbar(file_path,100) texts = df['X'].tolist() tags = df['Z'].tolist() -# texts = texts[12650:12800] -# tags = tags[12650:12800] -print(len(tags)) +texts = texts[98:99] +tags = tags[98:99] tagger = SuperTagger() @@ -47,13 +46,13 @@ tagger.load_weights("models/model_check.pt") pred, pred_convert = tagger.predict(texts) # -# print(texts) -# print() -# print(tags) -# print() -# print(pred) -# print() -# print(pred_convert) +print(texts) +print() +print(tags) +print() +print(pred) +print() +print(pred_convert) def categorical_accuracy(preds, truth): diff --git a/train.py b/train.py index daeef4c30f968b6f9d9db7c595fe93e1fb3363ec..551368aa3e16edabf7932d35b9165880f1c4e6f3 100644 --- a/train.py +++ b/train.py @@ -29,7 +29,8 @@ super_to_index = {v: int(k) for k, v in index_to_super.items()} tagger = SuperTagger() -tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) +# tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) +tagger.load_weights("models/model_check.pt") tagger.train(texts,tags,validation_rate=0.1,tensorboard=True,checkpoint=True)