Skip to content
Snippets Groups Projects
Commit a4a0a1b7 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

adding training methods

parent df0c26ca
Branches
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -35,9 +35,6 @@ class Linker(Module): ...@@ -35,9 +35,6 @@ class Linker(Module):
self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type']) self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size']) self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
nb_sentences = batch_size * 10
self.epochs = int(Configuration.modelTrainingConfig['epoch'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate']) learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
self.dropout = Dropout(0.1) self.dropout = Dropout(0.1)
self.device = "" self.device = ""
...@@ -73,7 +70,6 @@ class Linker(Module): ...@@ -73,7 +70,6 @@ class Linker(Module):
atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"]) atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["sub_tree"])
torch.set_printoptions(edgeitems=20)
truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
df_axiom_links["sub_tree"]) df_axiom_links["sub_tree"])
truth_links_batch = truth_links_batch.permute(1, 0, 2) truth_links_batch = truth_links_batch.permute(1, 0, 2)
...@@ -147,16 +143,14 @@ class Linker(Module): ...@@ -147,16 +143,14 @@ class Linker(Module):
return torch.stack(link_weights) return torch.stack(link_weights)
def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32): def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True, validate=True):
training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate) training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, validation_rate)
epochs = epochs - self.epochs
self.train()
for epoch_i in range(0, epochs): for epoch_i in range(0, epochs):
epoch_acc, epoch_loss = self.__train_epoch(training_dataloader, validation_dataloader) epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader)
def __train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True): def train_epoch(self, training_dataloader, validation_dataloader, checkpoint=True, validate=True):
# Reset the total loss for this epoch. # Reset the total loss for this epoch.
epoch_loss = 0 epoch_loss = 0
...@@ -198,7 +192,6 @@ class Linker(Module): ...@@ -198,7 +192,6 @@ class Linker(Module):
self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt')) self.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
if validate: if validate:
self.eval()
with torch.no_grad(): with torch.no_grad():
accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss) accuracy, average_test_loss = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
...@@ -215,8 +208,6 @@ class Linker(Module): ...@@ -215,8 +208,6 @@ class Linker(Module):
''' '''
self.eval() self.eval()
batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
# get atoms # get atoms
atoms_batch = get_atoms_batch(categories) atoms_batch = get_atoms_batch(categories)
atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch) atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment