Skip to content
Snippets Groups Projects
train.py 951 B
import torch
from Configuration import Configuration
from Linker import *
from Supertagger import *
from utils import read_csv_pgbar

# Training entry script: loads a pretrained SuperTagger, wraps it in a Linker,
# and trains the Linker on the axiom-links dataset.

# Ask PyTorch's CUDA caching allocator to release any unused cached memory
# before the (memory-heavy) model setup below.
torch.cuda.empty_cache()

# --- Hyper-parameters -------------------------------------------------------
# Values are passed through int(); presumably the config stores them as
# strings (e.g. a configparser section) — TODO confirm.
training_config = Configuration.modelTrainingConfig
batch_size = int(training_config['batch_size'])
epochs = int(training_config['epoch'])
# Number of rows requested from the dataset: 200 batches' worth.
nb_sentences = 200 * batch_size

# --- Data -------------------------------------------------------------------
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
# NOTE(review): presumably read_csv_pgbar loads at most nb_sentences rows with
# a progress bar — verify against utils.read_csv_pgbar.
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
sentences_batch = df_axiom_links["Sentences"].tolist()

# --- Pretrained supertagger ---------------------------------------------------
supertagger = SuperTagger()
supertagger.load_weights("models/model_supertagger.pt")
# NOTE(review): sents_tokenized / sents_mask are not used anywhere below.
# The call is kept in case fit_transform_tensors updates tokenizer state;
# confirm and drop it if it is pure.
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

# --- Linker construction and training ----------------------------------------
print("Linker")
linker = Linker(supertagger)
run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
linker = linker.to(run_device)

print("Linker Training")
# validation_rate=0.1 presumably holds out 10% of the rows for validation,
# and checkpoint=True presumably enables saving model checkpoints — confirm
# against Linker.train_linker.
linker.train_linker(
    df_axiom_links,
    validation_rate=0.1,
    epochs=epochs,
    batch_size=batch_size,
    checkpoint=True,
)