Commit a4a0a1b7 authored by Caroline DE POURTALES

adding training methods

parent df0c26ca
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -35,9 +35,6 @@ class Linker(Module):
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 10
self.epochs = int(Configuration.modelTrainingConfig['epoch'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.dropout = Dropout(0.1)
self.device = ""
@@ -73,7 +70,6 @@ class Linker(Module):
atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
torch.set_printoptions(edgeitems=20)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["sub_tree"])
truth_links_batch = truth_links_batch.permute(1, 0, 2)
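The permute above only swaps the first two dimensions so that the batch axis comes first. A minimal illustration of that reordering, assuming the layout is (atom types, batch, atoms per type); the dimension names and sizes below are placeholders, not taken from the commit:

import torch

links = torch.zeros(4, 32, 10)     # e.g. (atom_types, batch_size, max_atoms_in_one_type) -- placeholder sizes
links = links.permute(1, 0, 2)     # -> (batch_size, atom_types, max_atoms_in_one_type)
print(links.shape)                 # torch.Size([32, 4, 10])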
@@ -147,16 +143,14 @@ class Linker(Module):
return torch.stack(link_weights)
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32):
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True, validate=True):
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
epochs = epochs - self.epochs
self.train()
for epoch_i in range(0, epochs):
epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader)
epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader)
def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
# Reset the total loss for this epoch.
epoch_loss = 0
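This hunk extends `train_linker` with `checkpoint` and `validate` flags and renames the private `__train_epoch` to the public `train_epoch`. A minimal usage sketch of the new entry point, given an already constructed `Linker` instance `linker` and a pandas DataFrame `df_axiom_links` with a "sub_tree" column; only the signature and flag names come from the diff, the surrounding setup is illustrative:

linker.train_linker(df_axiom_links,
                    validation_rate=0.1,   # fraction of batches held out for evaluation
                    epochs=20,
                    batch_size=32,
                    checkpoint=True,       # save model_check.pt after each epoch
                    validate=True)         # run eval_epoch on the validation split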
@@ -198,7 +192,6 @@ class Linker(Module):
self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
if validate:
self.eval()
with torch.no_grad():
accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
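The body of `__checkpoint_save` is not part of this diff; the sketch below shows what such a helper typically looks like in PyTorch, and the function name, saved key, and default path are assumptions rather than the project's actual code:

import torch

def checkpoint_save(model, path='model_check.pt'):
    # Persist only the weights; reload later with
    # model.load_state_dict(torch.load(path)['model_state_dict']).
    torch.save({'model_state_dict': model.state_dict()}, path)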
@@ -215,8 +208,6 @@ class Linker(Module):
'''
self.eval()
batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
# get atoms
atoms_batch = get_atoms_batch(categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
......
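The prediction hunk tokenizes the atoms of each sentence with `convert_batchs_to_ids`, whose internals are not shown here. A rough sketch of that kind of step, assuming it maps each atom string to a vocabulary id and pads every sentence to `max_atoms_in_sentence`; the helper name, `atom_to_id` mapping, and padding id are placeholders:

import torch

def to_padded_ids(atoms_batch, atom_to_id, max_atoms_in_sentence, pad_id=0):
    # Map each atom string to its vocabulary id, then right-pad every sentence
    # so the whole batch has the same length.
    ids = [[atom_to_id[atom] for atom in sentence] for sentence in atoms_batch]
    ids = [row + [pad_id] * (max_atoms_in_sentence - len(row)) for row in ids]
    return torch.tensor(ids)    # shape: (batch_size, max_atoms_in_sentence)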