Commit ebcd90cd authored by Caroline DE POURTALES

adding class neuralproofnet and good config

parent 76b70555
1 merge request: !1 Add neural proof net class
@@ -24,8 +24,8 @@ sinkhorn_iters = 5
 [MODEL_TRAINING]
 batch_size = 32
-pretrain_linker_epochs = 1
-epoch = 1
+pretrain_linker_epochs = 10
+epoch = 20
 seed_val = 42
 learning_rate = 2e-3
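This bumps the linker pretraining from 1 to 10 epochs and the joint training from 1 to 20 epochs. As a rough sketch of how these `[MODEL_TRAINING]` values could be consumed (assuming the project's Configuration module wraps standard INI parsing; the file name `config.ini` is illustrative, not from the commit):

```python
# Minimal sketch (assumption: the project's Configuration module wraps
# standard INI parsing; the file name "config.ini" is illustrative).
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")

training = config["MODEL_TRAINING"]
pretrain_linker_epochs = training.getint("pretrain_linker_epochs")  # 10 after this change
epochs = training.getint("epoch")                                   # 20 after this change
batch_size = training.getint("batch_size")                          # 32
learning_rate = training.getfloat("learning_rate")                  # 2e-3
```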
@@ -67,7 +67,6 @@ class NeuralProofNet(Module):
         if linker_path_model is not None:
             linker.load_weights(linker_path_model)
         self.linker = linker
-        self.Supertagger = self.linker.Supertagger
         # Learning
         self.linker_loss = SinkhornLoss()
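The constructor no longer stores a second reference to the supertagger; every access now goes through `self.linker.Supertagger`, so there is a single source of truth. If the old attribute name were still wanted, one alternative (purely illustrative, not part of this commit) would be a read-only property that delegates to the linker:

```python
# Purely illustrative, not part of the commit: keeping the old attribute name
# as a read-only view onto the linker's supertagger instead of a second field.
from torch.nn import Module


class NeuralProofNetSketch(Module):  # hypothetical name, not the real class
    def __init__(self, linker):
        super().__init__()
        self.linker = linker

    @property
    def Supertagger(self):
        # always resolves to whatever supertagger the linker currently holds
        return self.linker.Supertagger
```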
@@ -96,7 +95,7 @@ class NeuralProofNet(Module):
         """
         print("Start preprocess Data")
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
-        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+        sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
         _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
         atoms_polarity_batch = pad_sequence(
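For context, `atoms_polarity_batch` is built with torch's `pad_sequence`, which stacks variable-length per-sentence tensors into one padded batch tensor. A tiny, self-contained example of that call (values made up):

```python
# Tiny pad_sequence example (values made up): variable-length per-sentence
# tensors are stacked into one right-padded batch tensor.
import torch
from torch.nn.utils.rnn import pad_sequence

polarities = [torch.tensor([1, 0, 1]), torch.tensor([0, 1])]
atoms_polarity_batch = pad_sequence(polarities, batch_first=True, padding_value=0)
print(atoms_polarity_batch)        # tensor([[1, 0, 1], [0, 1, 0]])
print(atoms_polarity_batch.shape)  # torch.Size([2, 3])
```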
@@ -126,10 +125,10 @@ class NeuralProofNet(Module):
     def forward(self, batch_sentences_tokens, batch_sentences_mask):
         # get sentence embedding from BERT which is already trained
-        output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
+        output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         last_hidden_state = output['logit']
         pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
-        pred_categories = self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
+        pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)
         # get information from tagger predictions
         atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
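One side note on the unchanged prediction line: `torch.argmax(torch.softmax(logits, dim=2), dim=2)` selects the same tag ids as taking `argmax` over the raw logits, because softmax is monotonic within each position, so the softmax is not strictly needed for the prediction itself. A quick check with made-up shapes:

```python
# Quick check with made-up shapes: softmax preserves the per-position ordering
# of the logits, so both ways of taking argmax give identical tag ids.
import torch

logits = torch.randn(2, 5, 7)  # (batch, sequence length, number of tags)
via_softmax = torch.argmax(torch.softmax(logits, dim=2), dim=2)
direct = torch.argmax(logits, dim=2)
assert torch.equal(via_softmax, direct)
```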
@@ -140,6 +139,11 @@ class NeuralProofNet(Module):
         batch_pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
         batch_neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        batch_num_atoms_per_word = batch_num_atoms_per_word.to(self.device)
+        atoms_batch_tokenized = atoms_batch_tokenized.to(self.device)
+        batch_pos_idx = batch_pos_idx.to(self.device)
+        batch_neg_idx = batch_neg_idx.to(self.device)
         logits_links = self.linker(batch_num_atoms_per_word, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx,
                                    output['word_embeding'])
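The added lines move every tensor the linker consumes onto `self.device` before the call, so the inputs live on the same device as the linker's parameters. The same pattern in isolation (tensor names and shapes are illustrative, and the device is chosen locally instead of read from the module):

```python
# Device-placement pattern in isolation; tensor names and shapes are
# illustrative, and the device is chosen here instead of read from the module.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_num_atoms_per_word = torch.randint(0, 4, (2, 10))
atoms_batch_tokenized = torch.randint(0, 50, (2, 30))
batch_pos_idx = torch.randint(0, 30, (2, 3, 8))
batch_neg_idx = torch.randint(0, 30, (2, 3, 8))

# every input must sit on the same device as the linker's parameters,
# otherwise the forward pass fails with a device-mismatch error
batch_num_atoms_per_word = batch_num_atoms_per_word.to(device)
atoms_batch_tokenized = atoms_batch_tokenized.to(device)
batch_pos_idx = batch_pos_idx.to(device)
batch_neg_idx = batch_neg_idx.to(device)
```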
@@ -6,9 +6,8 @@ from utils import read_csv_pgbar
 from find_config import configurate
 from Configuration import Configuration
 torch.cuda.empty_cache()
-nb_sentences = 100000000
+nb_sentences = 1000000000
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
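`nb_sentences` acts as an upper bound on how many rows are loaded from the axiom-links CSV, so raising it to 1000000000 effectively means "read the whole dataset". A hedged sketch of how such a cap could be applied with pandas (the real `read_csv_pgbar` lives in `utils.py` and may differ, e.g. by adding a progress bar):

```python
# Hedged sketch: one way a row cap like nb_sentences could be applied with
# pandas. The real read_csv_pgbar lives in utils.py and may differ.
import pandas as pd


def read_capped_csv(path: str, nb_sentences: int) -> pd.DataFrame:
    # nrows stops reading after nb_sentences rows; a huge cap such as
    # 1000000000 effectively loads the entire file
    return pd.read_csv(path, nrows=nb_sentences)
```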
@@ -29,7 +28,7 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 print("#" * 20)
 print("#" * 20)
 neural_proof_net = NeuralProofNet(model_tagger)
-neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.05, epochs=epochs, batch_size=batch_size,
-                                      checkpoint=True, tensorboard=True)
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size,
+                                      checkpoint=True, tensorboard=True)
 print("#" * 20)
 print("#" * 20)
 print("#" * 20)
\ No newline at end of file
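The training entry point now holds out 10% of the axiom-links dataframe for validation instead of 5%. As a rough illustration of what a `validation_rate` split can look like (the actual split is performed inside `train_neuralproofnet`, whose implementation is not shown in this diff):

```python
# Illustration only: the real split is done inside train_neuralproofnet and
# may be implemented differently.
import pandas as pd


def split_train_validation(df: pd.DataFrame, validation_rate: float = 0.1):
    # shuffle once, then hold out the last validation_rate fraction of rows
    shuffled = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
    n_val = int(len(shuffled) * validation_rate)
    split = len(shuffled) - n_val
    return shuffled.iloc[:split], shuffled.iloc[split:]
```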