# train.py
# Author: Julien Rabault
from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.utils import read_csv_pgbar
def load_obj(name):
    """Unpickle and return the object stored in ``<name>.pkl``.

    Args:
        name: Path to the pickle file *without* its ``.pkl`` extension, as a
            ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``). The
            ``.pkl`` extension is always appended.

    Returns:
        The object deserialised from the file.

    Raises:
        FileNotFoundError: If ``<name>.pkl`` does not exist.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserialising; only
        call this on files from a trusted source.
    """
    import os
    import pickle

    # os.fspath accepts both str and PathLike objects; plain ``name + '.pkl'``
    # raised TypeError for pathlib.Path arguments.
    path = os.fspath(name) + ".pkl"
    with open(path, "rb") as f:
        return pickle.load(f)
# --- Configuration ------------------------------------------------------------
DATASET_PATH = 'Datasets/m2_dataset_V2.csv'
# Base path without extension: load_obj appends '.pkl' itself.
INDEX_TO_SUPER_PATH = 'Datasets/index_to_super'
BERT_MODEL_NAME = 'camembert-base'
# Number of leading examples held out of training and used for a quick
# prediction sanity check at the end of the run.
N_HELDOUT = 4

# --- Data ---------------------------------------------------------------------
# The second argument (100) is forwarded unchanged to read_csv_pgbar.
# NOTE(review): presumably a chunk/row count for the progress bar -- confirm
# against SuperTagger.Utils.utils.read_csv_pgbar.
df = read_csv_pgbar(DATASET_PATH, 100)
texts = df['X'].tolist()  # model inputs: column 'X'
tags = df['Z'].tolist()   # training targets: column 'Z'

# Split off the first N_HELDOUT examples; train on the remainder.
# The right-hand sides are evaluated before rebinding, so both slices see the
# full lists.
test_s, texts = texts[:N_HELDOUT], texts[N_HELDOUT:]
tags_s, tags = tags[:N_HELDOUT], tags[N_HELDOUT:]

# Label mapping loaded from the pickle; by its name, index -> supertag.
index_to_super = load_obj(INDEX_TO_SUPER_PATH)

# --- Model --------------------------------------------------------------------
tagger = SuperTagger()
# First argument is len(index_to_super) -- presumably the number of output
# labels; confirm against SuperTagger.create_new_model.
tagger.create_new_model(len(index_to_super), BERT_MODEL_NAME, index_to_super)
tagger.train(texts, tags, tensorboard=True, checkpoint=True)

# --- Sanity check on held-out sentences ---------------------------------------
# Gold tags for these sentences are kept in tags_s for manual comparison.
pred = tagger.predict(test_s)
print(test_s)
print()
print(pred)