# train.py
# Trains a SuperTagger model on Datasets/m2_dataset_V2.csv and prints its
# predictions for the first four (held-out) sentences.
from SuperTagger.SuperTagger import SuperTagger
from SuperTagger.Utils.utils import read_csv_pgbar


def load_obj(name):
    """Load and return the object pickled in ``<name>.pkl``.

    :param name: path of the pickle file, given WITHOUT the ``.pkl`` suffix
        (the suffix is appended here).
    :return: the unpickled object.
    """
    import pickle

    pickle_path = name + '.pkl'
    with open(pickle_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


# --- Data -------------------------------------------------------------------
# Column 'X' is read as the input texts and column 'Z' as the matching tag
# sequences (presumably supertag strings -- TODO confirm against the dataset).
file_path = 'Datasets/m2_dataset_V2.csv'

# NOTE(review): the role of the second argument (100) is defined by
# SuperTagger.Utils.utils.read_csv_pgbar -- confirm (row count? chunk size?).
df = read_csv_pgbar(file_path, 100)

texts, tags = df['X'].tolist(), df['Z'].tolist()

# Hold out the first four examples as a tiny smoke-test set and train on the
# rest. Each right-hand side is evaluated before rebinding, so both slices
# are taken from the full lists.
test_s, texts = texts[:4], texts[4:]
tags_s, tags = tags[:4], tags[4:]

# --- Tag vocabulary ---------------------------------------------------------
# Loaded from 'Datasets/index_to_super.pkl'. Its keys go through int() when
# inverted, so they may be stored as strings.
index_to_super = load_obj('Datasets/index_to_super')
# NOTE(review): neither super_to_index nor tags_s is referenced again in this
# script; kept as module-level names.
super_to_index = {tag: int(idx) for idx, tag in index_to_super.items()}

# --- Model ------------------------------------------------------------------
tagger = SuperTagger()
# Arguments: vocabulary size, the 'camembert-base' name, and the vocabulary
# itself (see SuperTagger.create_new_model for how each is used).
tagger.create_new_model(len(index_to_super), 'camembert-base', index_to_super)

tagger.train(texts, tags, tensorboard=True, checkpoint=True)

# --- Smoke test -------------------------------------------------------------
# Print the held-out inputs, a blank line, then the model's predictions.
pred = tagger.predict(test_s)

print(test_s, pred, sep='\n\n')