diff --git a/Linker/Linker.py b/Linker/Linker.py index 4eb4f75347efc497c88897648ddc425571980c60..167c923af4e2035cbeb3a2f3ff9aa27421363618 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -35,9 +35,6 @@ class Linker(Module): self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type']) self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size']) - batch_size = int(Configuration.modelTrainingConfig['batch_size']) - nb_sentences = batch_size * 10 - self.epochs = int(Configuration.modelTrainingConfig['epoch']) learning_rate = float(Configuration.modelTrainingConfig['learning_rate']) self.dropout = Dropout(0.1) self.device = "" @@ -73,7 +70,6 @@ class Linker(Module): atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"]) - torch.set_printoptions(edgeitems=20) truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"]) truth_links_batch = truth_links_batch.permute(1, 0, 2) @@ -147,16 +143,14 @@ class Linker(Module): return torch.stack(link_weights) - def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32): + def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True, validate=True): training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate) - epochs = epochs - self.epochs - self.train() for epoch_i in range(0, epochs): - epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader) + epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader) - def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True): + def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True): # Reset the total loss for this epoch. epoch_loss = 0 @@ -198,7 +192,6 @@ class Linker(Module): self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt')) if validate: - self.eval() with torch.no_grad(): accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss) @@ -215,8 +208,6 @@ class Linker(Module): ''' self.eval() - batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape - # get atoms atoms_batch = get_atoms_batch(categories) atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)