Commit c25dcea8 authored by Caroline DE POURTALES

utils

parent 4a477cae

2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -68,7 +68,7 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
         r"""
         Args:
             batch_size : int
@@ -79,6 +79,9 @@ class Linker(Module):
         Returns:
             the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
         """
+        sentences_batch = df_axiom_links["Sentences"].tolist()
+        sentences_tokens, sentences_mask = self.supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
         atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
@@ -154,7 +157,7 @@ class Linker(Module):
         return F.log_softmax(link_weights_per_batch, dim=3)

-    def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
         r"""
         Args:
@@ -170,7 +173,6 @@ class Linker(Module):
             Final accuracy and final loss
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            sentences_tokens, sentences_mask,
                                                                             validation_rate)
         self.to(self.device)
         for epoch_i in range(0, epochs):
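The caller-facing effect of the change above: sentence tokenization moves from the call site into __preprocess_data, which now reads the raw sentences from the DataFrame's "Sentences" column and runs them through the supertagger's own sentence tokenizer. A minimal before/after sketch of the training entry point, using only names that appear in this diff (the keyword-argument values are illustrative, not mandated by the code):

    # Before this commit: the caller tokenized sentences itself and passed
    # the resulting tensors into train_linker.
    sentences_batch = df_axiom_links["Sentences"].tolist()
    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
    linker.train_linker(df_axiom_links, sentences_tokens, sentences_mask,
                        validation_rate=0.1, epochs=20, batch_size=32)

    # After this commit: the Linker tokenizes internally via its supertagger,
    # so the caller only supplies the axiom-links DataFrame.
    linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32)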
@@ -21,4 +21,4 @@ print("Linker")
 linker = Linker(supertagger)
 linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 print("Linker Training")
-linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
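A plausible motivation for this refactor, inferred from the diff: computing sentences_tokens and sentences_mask inside the Linker guarantees they come from the same sent_tokenizer the supertagger uses, and it removes two positional arguments that every training script previously had to build by hand, as the simplified call above shows.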