Skip to content
Snippets Groups Projects
Commit d7a164e5 authored by Julien Rabault
Browse files

remove bug

parent a76e9380
Branches
No related tags found
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
......@@ -6,3 +6,4 @@ Utils/gold
Linker/__pycache__
Configuration/__pycache__
__pycache__
TensorBoard
......@@ -81,7 +81,7 @@ class Linker(Module):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
r"""
Args:
batch_size : int
......@@ -106,7 +106,7 @@ class Linker(Module):
dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
sentences_mask)
if validation_rate > 0:
if validation_rate > 0.0:
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
......
......@@ -6,12 +6,9 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 20
nb_sentences = batch_size * 400
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
nb_sentences = batch_size * 20
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
......@@ -21,10 +18,8 @@ supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98%_V2_50e.pt")
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
print("Linker")
linker = Linker(supertagger)
linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Linker Training")
linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, validate=True)
linker.train_linker(df_axiom_links, batch_size=batch_size, checkpoint=False, tensorboard=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment