diff --git a/Linker/Linker.py b/Linker/Linker.py
index ce256406be8271ec9a951af3f08f7cd7995c7a5a..396639195b6381ad699457dbfe78899af4777e9c 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -68,7 +68,7 @@ class Linker(Module):
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
+    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.0):
         r"""
         Args:
             batch_size : int
@@ -79,6 +79,9 @@ class Linker(Module):
         Returns:
             the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
         """
+        sentences_batch = df_axiom_links["Sentences"].tolist()
+        sentences_tokens, sentences_mask = self.supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+
         atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
         atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
@@ -154,7 +157,7 @@ class Linker(Module):
 
         return F.log_softmax(link_weights_per_batch, dim=3)
 
-    def train_linker(self, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.1, epochs=20,
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, validate=True):
         r"""
         Args:
@@ -170,7 +173,6 @@ class Linker(Module):
             Final accuracy and final loss
         """
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            sentences_tokens, sentences_mask,
                                                                             validation_rate)
         self.to(self.device)
         for epoch_i in range(0, epochs):
diff --git a/train.py b/train.py
index 8505431ecabdd47b68b8cf232f58f92ea5697ea4..4684d3b7fb816ceb1b04b37aaa40add606db6871 100644
--- a/train.py
+++ b/train.py
@@ -21,4 +21,4 @@ print("Linker")
 linker = Linker(supertagger)
 linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 print("Linker Training")
-linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)