Skip to content
Snippets Groups Projects
Commit c25dcea8 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

utils

parent 4a477cae
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -68,7 +68,7 @@ class Linker(Module):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
r"""
Args:
batch_size : int
......@@ -79,6 +79,9 @@ class Linker(Module):
Returns:
the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
"""
sentences_batch = df_axiom_links["Sentences"].tolist()
sentences_tokens, sentences_mask = self.supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
......@@ -154,7 +157,7 @@ class Linker(Module):
return F.log_softmax(link_weights_per_batch, dim=3)
def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
batch_size=32, checkpoint=True, validate=True):
r"""
Args:
......@@ -170,7 +173,6 @@ class Linker(Module):
Final accuracy and final loss
"""
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
sentences_tokens, sentences_mask,
validation_rate)
self.to(self.device)
for epoch_i in range(0, epochs):
......
......@@ -21,4 +21,4 @@ print("Linker")
linker = Linker(supertagger)
linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Linker Training")
linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment