Commit e6ae31ff authored by Caroline DE POURTALES

it runs, some corrections needed next

parent 4a11d176
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
@@ -6,7 +6,7 @@ symbols_vocab_size=26
 atom_vocab_size=20
 max_len_sentence=148
 max_atoms_in_sentence=1250
-max_atoms_in_one_type=160
+max_atoms_in_one_type=250
 [MODEL_ENCODER]
 dim_encoder = 768
...
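The bump of max_atoms_in_one_type from 160 to 250 widens the per-type padding the linker works with. As a minimal sketch of how such INI values get consumed, assuming the project's Configuration wrapper sits on Python's stdlib configparser (the file path and section name below are hypothetical):

import configparser

config = configparser.ConfigParser()
config.read('Configuration/config.ini')  # hypothetical path

# INI values come back as strings, hence the explicit int(...) casts in train.py.
max_atoms_in_one_type = int(config['DATASET']['max_atoms_in_one_type'])  # 250 after this commit
print(max_atoms_in_one_type)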
(One file's diff is collapsed and not shown.)
@@ -115,7 +115,6 @@ class Linker(Module):
         logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
         logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
         axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
-        print(axiom_links_pred.shape)
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
         loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
...
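Dropping the debug print leaves the prediction step as softmax followed by argmax over dim 3. Since softmax is strictly increasing along the dimension it normalizes, the softmax call is redundant for picking links; a toy check (the 4-D shape labels are only an assumption about this model):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 2, 4, 4)  # e.g. (types, batch, atoms, candidates)
via_softmax = torch.argmax(F.softmax(logits, dim=3), dim=3)
direct = torch.argmax(logits, dim=3)
assert torch.equal(via_softmax, direct)  # identical predictions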
(Two changed files have no preview for their file type.)
@@ -28,4 +28,4 @@ def mesure_accuracy(linking_plus_to_minus, axiom_links_pred):
     num_correct_links = correct_links.sum().item()
     # divide by the number of links
-    return num_correct_links
+    return num_correct_links / (axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2])
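The returned count is now normalized by the total number of link slots (batch × atom types × atoms per type), so mesure_accuracy yields a fraction in [0, 1] rather than a raw count. A self-contained toy check, assuming both tensors hold integer link indices of the same shape (names are illustrative):

import torch

true_links = torch.randint(0, 10, (2, 3, 5))
pred_links = true_links.clone()
pred_links[0, 0, 0] += 1  # introduce a single wrong link

num_correct = (true_links == pred_links).sum().item()
total = pred_links.size(0) * pred_links.size(1) * pred_links.size(2)
print(num_correct / total)  # 29/30 ~ 0.97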
@@ -31,7 +31,7 @@ atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
 # region ParamsTraining
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = 2
+nb_sentences = batch_size * 2
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 seed_val = int(Configuration.modelTrainingConfig['seed_val'])
 learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
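Tying nb_sentences to the configured batch_size instead of hard-coding 2 keeps the smoke-test dataset aligned with whole batches. For instance, under an even train/validation split (an assumption about this script), each side receives exactly one full batch:

batch_size = 16                       # illustrative value from the config
nb_sentences = batch_size * 2
train_size = nb_sentences // 2        # assumed 50/50 split
assert train_size % batch_size == 0   # no ragged final batch on either side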
@@ -41,11 +41,8 @@ learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
 # region Data loader
 file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
-torch.set_printoptions(threshold=1250)
 atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
 atom_tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence)
 atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
...
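The loader extracts atom sequences from the sub_tree column and maps them to padded id tensors. An illustrative stand-in for that tokenization step, assuming atom_map maps atom names to integer ids and sequences are right-padded to max_atoms_in_sentence (the real AtomTokenizer may differ in detail):

import torch

class ToyAtomTokenizer:
    def __init__(self, atom_map, max_len, pad_id=0):
        self.atom_map = atom_map
        self.max_len = max_len
        self.pad_id = pad_id

    def convert_batchs_to_ids(self, batch):
        # map each atom name to its id, then right-pad every sentence
        ids = [[self.atom_map[a] for a in atoms] for atoms in batch]
        return torch.tensor([seq + [self.pad_id] * (self.max_len - len(seq)) for seq in ids])

tok = ToyAtomTokenizer({'np': 1, 's': 2, 'n': 3}, max_len=5)
print(tok.convert_batchs_to_ids([['np', 's'], ['n', 'np', 's']]))
# tensor([[1, 2, 0, 0, 0],
#         [3, 1, 2, 0, 0]])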
@@ -145,13 +142,9 @@ def run_epochs(epochs):
             # Run the linker on the categories predictions
             logits_predictions = linker(batch_atoms, batch_polarity, [])
             print(logits_predictions.permute(1, 0, 2, 3).shape)
-            print(logits_predictions.permute(1, 0, 2, 3).dtype)
-            print(batch_true_links.dtype)
             linker_loss = cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
-            print("loss")
             # Perform a backward pass to calculate the gradients.
             total_train_loss += float(linker_loss)
             linker_loss.backward()
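Since eval_epoch takes its argmax over the last dimension while torch's built-in F.cross_entropy expects class scores at dim 1, cross_entropy_loss here is presumably a project-defined wrapper. One plausible shape handling, sketched under the assumption that the permuted logits are (batch, types, atoms, candidates) and the targets hold candidate indices:

import torch
import torch.nn.functional as F

def link_cross_entropy(logits, targets):
    # flatten all leading dims so the last dim acts as the class dimension
    num_candidates = logits.size(-1)
    return F.cross_entropy(logits.reshape(-1, num_candidates), targets.reshape(-1))

logits = torch.randn(2, 3, 4, 4)          # (batch, types, atoms, candidates)
targets = torch.randint(0, 4, (2, 3, 4))  # index of the linked candidate
print(link_cross_entropy(logits, targets))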
@@ -174,8 +167,8 @@ def run_epochs(epochs):
         print("Start eval")
         accuracy, loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss)
         print("")
-        print("    Average accuracy sents on epoch: {0:.2f}".format(accuracy))
-        print("    Average accuracy atom on epoch: {0:.2f}".format(loss))
+        print("    Average accuracy on epoch: {0:.2f}".format(accuracy))
+        print("    Average loss on epoch: {0:.2f}".format(loss))
         print("")
         print("    Average training loss: {0:.2f}".format(avg_train_loss))
...