From d891f64e8b978f5bea953f51cfc8c9c85505dab6 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <caroline.de-pourtales@irit.fr> Date: Fri, 24 Mar 2023 18:04:57 +0100 Subject: [PATCH] contrect code and simplify --- Linker/Linker.py | 249 +++---------------------------- Linker/eval.py | 15 +- NeuralProofNet/NeuralProofNet.py | 230 ++++++++++++++++++++++------ predict_links.py | 15 +- train_neuralproofnet.py | 6 +- 5 files changed, 223 insertions(+), 292 deletions(-) diff --git a/Linker/Linker.py b/Linker/Linker.py index 5002f9c..80fbfa1 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -28,7 +28,7 @@ class Linker(Module): # region initialization - def __init__(self, supertagger_path_model): + def __init__(self): super(Linker, self).__init__() # region parameters @@ -58,12 +58,6 @@ class Linker(Module): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # endregion - # SuperTagger for categories - supertagger = SuperTagger() - supertagger.load_weights(supertagger_path_model) - self.Supertagger = supertagger - self.Supertagger.model.to(self.device) - # Atoms embedding self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence) self.atom_map_redux = atom_map_redux @@ -118,53 +112,6 @@ class Linker(Module): #endregion - # region data - - def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1): - r""" - Args: - batch_size : int - df_axiom_links pandas DataFrame - validation_rate - Returns: - the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask - """ - print("Start preprocess Data") - sentences_batch = df_axiom_links["X"].str.strip().tolist() - sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) - - atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links) - atoms_polarity_batch = pad_sequence( - [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], - max_len=self.max_atoms_in_sentence, padding_value=0) - atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch) - - pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type) - neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type) - - truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, - df_axiom_links["Y"]) - truth_links_batch = truth_links_batch.permute(1, 0, 2) - - # Construction tensor dataset - dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, - sentences_tokens, sentences_mask) - - if validation_rate > 0.0: - train_size = int(0.9 * len(dataset)) - val_size = len(dataset) - train_size - train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) - validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False) - else: - validation_dataloader = None - train_dataset = dataset - - training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False) - print("End preprocess Data") - return training_dataloader, validation_dataloader - - #endregion - # region training def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type): @@ -229,56 +176,7 @@ class Linker(Module): return F.log_softmax(link_weights, dim=3) - def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, - batch_size=32, 
checkpoint=True, tensorboard=False): - r""" - Args: - df_axiom_links : pandas dataFrame containing the atoms anoted with _i - validation_rate : float - epochs : int - batch_size : int - checkpoint : boolean - tensorboard : boolean - Returns: - Final accuracy and final loss - """ - training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, - validation_rate) - if checkpoint or tensorboard: - checkpoint_dir, writer = output_create_dir() - - for epoch_i in range(epochs): - print("") - print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)) - print('Training...') - avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader) - - print("") - print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}') - print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%') - - if validation_rate > 0.0: - loss_test, accuracy_test = self.eval_epoch(validation_dataloader) - print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%') - - if checkpoint: - self.__checkpoint_save( - path=os.path.join("Output", 'linker.pt')) - - if tensorboard: - writer.add_scalars(f'Accuracy', { - 'Train': avg_accuracy_train}, epoch_i) - writer.add_scalars(f'Loss', { - 'Train': avg_train_loss}, epoch_i) - if validation_rate > 0.0: - writer.add_scalars(f'Accuracy', { - 'Validation': accuracy_test}, epoch_i) - writer.add_scalars(f'Loss', { - 'Validation': loss_test}, epoch_i) - - print('\n') - - def train_epoch(self, training_dataloader): + def train_epoch(self, training_dataloader, Supertagger): r""" Train epoch Args: @@ -309,12 +207,11 @@ class Linker(Module): self.optimizer.zero_grad() # get sentence embedding from BERT which is already trained - output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) + output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) # Run the Linker on the atoms logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output['word_embedding']) - linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) # Perform a backward pass to calculate the gradients. epoch_loss += float(linker_loss) @@ -342,33 +239,7 @@ class Linker(Module): # region evaluation - def eval_batch(self, batch): - batch_num_atoms = batch[0].to(self.device) - batch_atoms_tok = batch[1].to(self.device) - batch_pos_idx = batch[2].to(self.device) - batch_neg_idx = batch[3].to(self.device) - batch_true_links = batch[4].to(self.device) - batch_sentences_tokens = batch[5].to(self.device) - batch_sentences_mask = batch[6].to(self.device) - - output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) - - logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[ - 'word_embedding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type - axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type - - print('\n') - print(batch_true_links) - print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100]) - print("Les prédictions : ", axiom_links_pred[2][0][:100]) - print('\n') - - accuracy = measure_accuracy(batch_true_links, axiom_links_pred) - loss = self.cross_entropy_loss(logits_predictions, batch_true_links) - - return loss, accuracy - - def eval_epoch(self, dataloader): + def eval_epoch(self, dataloader, Supertagger): r"""Average the evaluation of all the batch. 
Args: @@ -379,107 +250,25 @@ class Linker(Module): loss_average = 0 with torch.no_grad(): for step, batch in enumerate(dataloader): - loss, accuracy = self.eval_batch(batch) - accuracy_average += accuracy - loss_average += float(loss) - - return loss_average / len(dataloader), accuracy_average / len(dataloader) - - #endregion - - #region prediction - - def predict_with_categories(self, sentence, categories): - r""" Predict the links from a sentence and its categories - - Args : - sentence : list of words composing the sentence - categories : list of categories (tags) of each word - - Return : - links : links prediction - """ - self.eval() - with torch.no_grad(): - self.cpu() - self.device = torch.device("cpu") - sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence]) - nb_sentence, len_sentence = sentences_tokens.shape - - atoms = get_atoms_batch([categories]) - atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms) - - polarities = find_pos_neg_idexes([categories]) - polarities = pad_sequence( - [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], - max_len=self.max_atoms_in_sentence, padding_value=0) - - num_atoms_per_word = get_num_atoms_batch([categories], len_sentence) - - pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type) - neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) - - output = self.Supertagger.forward(sentences_tokens, sentences_mask) - - logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding']) - axiom_links_pred = torch.argmax(logits_predictions, dim=3) - - return axiom_links_pred - - def predict_without_categories(self, sentence): - r""" Predict the links from a sentence - - Args : - sentence : list of words composing the sentence - - Return : - categories : the supertags predicted - links : links prediction - """ - self.eval() - with torch.no_grad(): - self.cpu() - self.device = torch.device("cpu") - sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence) - nb_sentence, len_sentence = sentences_tokens.shape - - hidden_state, categories = self.Supertagger.predict(sentence) + batch_num_atoms = batch[0].to(self.device) + batch_atoms_tok = batch[1].to(self.device) + batch_pos_idx = batch[2].to(self.device) + batch_neg_idx = batch[3].to(self.device) + batch_true_links = batch[4].to(self.device) + batch_sentences_tokens = batch[5].to(self.device) + batch_sentences_mask = batch[6].to(self.device) - output = self.Supertagger.forward(sentences_tokens, sentences_mask) - atoms = get_atoms_batch(categories) - atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms) + output = Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) - polarities = find_pos_neg_idexes(categories) - polarities = pad_sequence( - [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], - max_len=self.max_atoms_in_sentence, padding_value=0) + logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output['word_embedding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type + axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type - num_atoms_per_word = get_num_atoms_batch(categories, len_sentence) + accuracy = measure_accuracy(batch_true_links, axiom_links_pred) + loss = self.cross_entropy_loss(logits_predictions, batch_true_links) - pos_idx 
= get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
-            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
+                accuracy_average += accuracy
+                loss_average += float(loss)
-            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding'])
-            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+        return loss_average / len(dataloader), accuracy_average / len(dataloader)
-            return categories, axiom_links_pred
-    #endregion
-
-    def __checkpoint_save(self, path='/linker.pt'):
-        """
-        @param path:
-        """
-        self.cpu()
-
-        torch.save({
-            'atom_encoder': self.atom_encoder.state_dict(),
-            'position_encoder': self.position_encoder.state_dict(),
-            'transformer': self.transformer.state_dict(),
-            'linker_encoder': self.linker_encoder.state_dict(),
-            'pos_transformation': self.pos_transformation.state_dict(),
-            'neg_transformation': self.neg_transformation.state_dict(),
-            'cross_entropy_loss': self.cross_entropy_loss.state_dict(),
-            'optimizer': self.optimizer,
-        }, path)
-        self.to(self.device)
diff --git a/Linker/eval.py b/Linker/eval.py
index b252e5e..fe7f347 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -1,6 +1,9 @@
+import numpy as np
+
 import torch
 from torch.nn import Module
 from torch.nn.functional import nll_loss
+
 from Linker.atom_map import atom_map, atom_map_redux


@@ -12,8 +15,16 @@ class SinkhornLoss(Module):
         super(SinkhornLoss, self).__init__()

     def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
+        total = 0
+        # for each category of atom (txt, np, ...)
+        for link, perm in zip(predictions, truths.permute(1, 0, 2)):
+            # only count categories that actually contain true links
+            if 0 in perm.flatten():
+                # mean NLL loss of the current category, computed over the whole batch
+                it = nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
+                # add it to the running total
+                total += it
+        return total


 def measure_accuracy(batch_true_links, axiom_links_pred):
diff --git a/NeuralProofNet/NeuralProofNet.py b/NeuralProofNet/NeuralProofNet.py
index ab12693..41ee516 100644
--- a/NeuralProofNet/NeuralProofNet.py
+++ b/NeuralProofNet/NeuralProofNet.py
@@ -8,13 +8,14 @@ from torch.utils.data import TensorDataset, random_split
 from tqdm import tqdm

 from Configuration import Configuration
+from NeuralProofNet.utils_proofnet import get_info_for_tagger
+from SuperTagger import SuperTagger
 from Linker import Linker
 from Linker.eval import measure_accuracy, SinkhornLoss
-from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
-from NeuralProofNet.utils_proofnet import get_info_for_tagger
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
+    find_pos_neg_idexes, get_num_atoms_batch, generate_square_subsequent_mask
 from utils import pad_sequence, format_time, output_create_dir

-
 class NeuralProofNet(Module):

     def __init__(self, supertagger_path_model, linker_path_model=None):
@@ -28,7 +29,13 @@ class NeuralProofNet(Module):
         self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-        linker = Linker(supertagger_path_model)
+        # SuperTagger for categories
+        supertagger = SuperTagger()
+        supertagger.load_weights(supertagger_path_model)
+        self.Supertagger = supertagger
+        self.Supertagger.model.to(self.device)
+ + linker = Linker() if linker_path_model is not None: linker.load_weights(linker_path_model) self.linker = linker @@ -41,12 +48,6 @@ class NeuralProofNet(Module): self.to(self.device) - def __pretrain_linker__(self, df_axiom_links, pretrain_linker_epochs, batch_size, checkpoint=False, tensorboard=True): - print("\nLinker Pre-Training\n") - self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=pretrain_linker_epochs, - batch_size=batch_size, checkpoint=checkpoint, tensorboard=tensorboard) - print("\nEND Linker Pre-Training\n") - def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1): r""" Args: @@ -54,26 +55,31 @@ class NeuralProofNet(Module): df_axiom_links pandas DataFrame validation_rate Returns: - the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask + the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask """ print("Start preprocess Data") sentences_batch = df_axiom_links["X"].str.strip().tolist() - sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) - _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links) + atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links) atoms_polarity_batch = pad_sequence( [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], max_len=self.max_atoms_in_sentence, padding_value=0) + atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch) - truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, + pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.linker.max_atoms_in_one_type) + + truth_links_batch = get_axiom_links(self.linker.max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["Y"]) truth_links_batch = truth_links_batch.permute(1, 0, 2) # Construction tensor dataset - dataset = TensorDataset(truth_links_batch, sentences_tokens, sentences_mask) + dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, + sentences_tokens, sentences_mask) if validation_rate > 0.0: - train_size = int(0.9 * len(dataset)) + train_size = int((1-validation_rate) * len(dataset)) val_size = len(dataset) - train_size train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False) @@ -85,13 +91,15 @@ class NeuralProofNet(Module): print("End preprocess Data") return training_dataloader, validation_dataloader + # region training + def forward(self, batch_sentences_tokens, batch_sentences_mask): # get sentence embedding from BERT which is already trained - output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) + output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) last_hidden_state = output['logit'] pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2) - pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories) + pred_categories = 
self.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)

         # get information from tagger predictions
         atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
@@ -112,6 +120,49 @@

         return torch.log_softmax(logits_links, dim=3)

+    def pretrain_linker(self, training_dataloader, validation_dataloader, pretrain_linker_epochs, checkpoint=None, writer=None):
+        r"""
+        Pre-train the linker on the gold categories.
+        Args:
+            training_dataloader : DataLoader over the preprocessed axiom-link batches
+            validation_dataloader : DataLoader used for validation, or None to skip it
+            pretrain_linker_epochs : int, number of pre-training epochs
+            checkpoint : if truthy, save the linker weights after each epoch
+            writer : tensorboard SummaryWriter, or None
+        Returns:
+            None (metrics are printed and optionally logged to tensorboard)
+        """
+
+        for epoch_i in range(pretrain_linker_epochs):
+            print("")
+            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, pretrain_linker_epochs))
+            print('Training...')
+            avg_train_loss, avg_accuracy_train, training_time = self.linker.train_epoch(training_dataloader, self.Supertagger)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
+            if validation_dataloader:
+                loss_test, accuracy_test = self.linker.eval_epoch(validation_dataloader, self.Supertagger)
+                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
+
+            if checkpoint:
+                self.__checkpoint_save(path='Output/linker.pt')
+
+            if writer:
+                writer.add_scalars(f'Accuracy', {
+                    'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars(f'Loss', {
+                    'Train': avg_train_loss}, epoch_i)
+                if validation_dataloader:
+                    writer.add_scalars(f'Accuracy', {
+                        'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars(f'Loss', {
+                        'Validation': loss_test}, epoch_i)
+
+            print('\n')
+
     def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0, batch_size=32, checkpoint=True, tensorboard=False):
         r"""
@@ -125,15 +176,20 @@
         Returns:
             Final accuracy and final loss
         """
-        # Pretrain the linker
-        self.__pretrain_linker__(df_axiom_links, pretrain_linker_epochs, batch_size)
-
         # Start learning with output from tagger
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             validation_rate)
         if checkpoint or tensorboard:
             checkpoint_dir, writer = output_create_dir()

+        # Pretrain the linker with the right (gold) categories
+        if pretrain_linker_epochs > 0:
+            print("\nLinker Pre-Training\n")
+            self.pretrain_linker(training_dataloader, validation_dataloader, \
+                                 pretrain_linker_epochs, checkpoint, writer)
+            print("\nEND Linker Pre-Training\n")
+
+        # Train Linker with predicted categories from supertagger
         for epoch_i in range(epochs):
             print("")
             print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
@@ -153,14 +209,14 @@

             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
+                    'Train': avg_accuracy_train}, pretrain_linker_epochs + epoch_i)
                 writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
+                    'Train': avg_train_loss}, pretrain_linker_epochs + epoch_i)
                 if validation_rate > 0.0:
                     writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
+                        'Validation': accuracy_test}, pretrain_linker_epochs + epoch_i)
                     writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
+                        'Validation': loss_test}, pretrain_linker_epochs + epoch_i)

             print('\n')

@@ -184,9 +240,9 @@
         with tqdm(training_dataloader, unit="batch") 
as tepoch: for batch in tepoch: # Unpack this training batch from our dataloader - batch_true_links = batch[0].to(self.device) - batch_sentences_tokens = batch[1].to(self.device) - batch_sentences_mask = batch[2].to(self.device) + batch_true_links = batch[4].to(self.device) + batch_sentences_tokens = batch[5].to(self.device) + batch_sentences_mask = batch[6].to(self.device) self.linker_optimizer.zero_grad() @@ -215,25 +271,10 @@ class NeuralProofNet(Module): avg_accuracy_train = accuracy_train / len(training_dataloader) return avg_train_loss, avg_accuracy_train, training_time + + #endregion - def eval_batch(self, batch): - batch_true_links = batch[0].to(self.device) - batch_sentences_tokens = batch[1].to(self.device) - batch_sentences_mask = batch[2].to(self.device) - - logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask) - axiom_links_pred = torch.argmax(logits_predictions_links, - dim=3) # atom_vocab, batch_size, max atoms in one type - - print('\n') - print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100]) - print("Les prédictions : ", axiom_links_pred[2][0][:100]) - print('\n') - - accuracy = measure_accuracy(batch_true_links, axiom_links_pred) - linker_loss = self.linker_loss(logits_predictions_links, batch_true_links) - - return linker_loss, accuracy + # region evaluation def eval_epoch(self, dataloader): r"""Average the evaluation of all the batch. @@ -246,12 +287,24 @@ class NeuralProofNet(Module): loss_average = 0 with torch.no_grad(): for step, batch in enumerate(dataloader): - loss, accuracy = self.eval_batch(batch) + batch_true_links = batch[4].to(self.device) + batch_sentences_tokens = batch[5].to(self.device) + batch_sentences_mask = batch[6].to(self.device) + + logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask) + axiom_links_pred = torch.argmax(logits_predictions_links, + dim=3) # atom_vocab, batch_size, max atoms in one type + + accuracy = measure_accuracy(batch_true_links, axiom_links_pred) + linker_loss = self.linker_loss(logits_predictions_links, batch_true_links) + accuracy_average += accuracy - loss_average += float(loss) + loss_average += float(linker_loss) return loss_average / len(dataloader), accuracy_average / len(dataloader) + #endregion + def __checkpoint_save(self, path='/linker.pt'): """ @param path: @@ -268,4 +321,83 @@ class NeuralProofNet(Module): 'cross_entropy_loss': self.linker_loss.state_dict(), 'optimizer': self.linker_optimizer, }, path) - self.to(self.device) \ No newline at end of file + self.to(self.device) + + #region prediction + + def predict_with_categories(self, sentence, categories): + r""" Predict the links from a sentence and its categories + + Args : + sentence : list of words composing the sentence + categories : list of categories (tags) of each word + + Return : + links : links prediction + """ + self.eval() + with torch.no_grad(): + self.cpu() + self.device = torch.device("cpu") + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence) + nb_sentence, len_sentence = sentences_tokens.shape + + atoms = get_atoms_batch(categories) + atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms) + + polarities = find_pos_neg_idexes(categories) + polarities = pad_sequence( + [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], + max_len=self.max_atoms_in_sentence, padding_value=0) + + num_atoms_per_word = get_num_atoms_batch(categories, len_sentence) + + pos_idx = get_pos_idx(atoms, 
polarities, self.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) + + output = self.Supertagger.forward(sentences_tokens, sentences_mask) + + logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding']) + axiom_links_pred = torch.argmax(logits_predictions, dim=3) + + return axiom_links_pred + + def predict_without_categories(self, sentence): + r""" Predict the links from a sentence + + Args : + sentence : list of words composing the sentence + + Return : + categories : the supertags predicted + links : links prediction + """ + self.eval() + with torch.no_grad(): + self.cpu() + self.device = torch.device("cpu") + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentence) + nb_sentence, len_sentence = sentences_tokens.shape + + hidden_state, categories = self.Supertagger.predict(sentence) + + output = self.Supertagger.forward(sentences_tokens, sentences_mask) + atoms = get_atoms_batch(categories) + atoms_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms) + + polarities = find_pos_neg_idexes(categories) + polarities = pad_sequence( + [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], + max_len=self.max_atoms_in_sentence, padding_value=0) + + num_atoms_per_word = get_num_atoms_batch(categories, len_sentence) + + pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) + + logits_predictions = self.linker(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embedding']) + axiom_links_pred = torch.argmax(logits_predictions, dim=3) + + return categories, axiom_links_pred + + #endregion \ No newline at end of file diff --git a/predict_links.py b/predict_links.py index 4e87a1f..b52982a 100644 --- a/predict_links.py +++ b/predict_links.py @@ -4,22 +4,21 @@ from postprocessing import draw_sentence_output if __name__== '__main__': # region data a_s = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. 
Lacombe , Koehler et Laroze ne sont pas membres du PCF ."] - tags_s = ['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)', + tags_s = [['let', 'dr(0,s,s)', 'let', 'dr(0,dr(0,s,s),np)', 'dr(0,np,n)', 'dr(0,n,n)', 'let', 'n', 'let', 'dl(0,n,n)', 'dr(0,dl(0,dl(0,n,n),dl(0,n,n)),dl(0,n,n))', 'dl(0,n,n)', 'let', 'dr(0,np,np)', 'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,np),np)', 'np', 'dr(0,dl(0,np,s),dl(0,np,s))', 'dr(0,dl(0,np,s),np)', 'dl(1,s,s)', 'np', - 'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)'] + 'dr(0,dl(0,np,np),n)', 'n', 'dl(0,s,txt)']] # endregion - # region model model_tagger = "models/flaubert_super_98_V2_50e.pt" neuralproofnet = NeuralProofNet(model_tagger) - model = "Output/linker.pt" + model = "Output/saved_linker.pt" neuralproofnet.linker.load_weights(model) # endregion - linker = neuralproofnet.linker - categories, links = linker.predict_without_categories(a_s) - #links = linker.predict_with_categories(a_s, tags_s) + #categories, links = neuralproofnet.predict_without_categories(a_s) + links = neuralproofnet.predict_with_categories(a_s, tags_s) + idx=0 - draw_sentence_output(a_s[idx].split(" "), categories[idx], links[:,idx,:].numpy()) + draw_sentence_output(a_s[idx].split(" "), tags_s[idx], links[:,idx,:].numpy()) diff --git a/train_neuralproofnet.py b/train_neuralproofnet.py index ce393ad..f3c2284 100644 --- a/train_neuralproofnet.py +++ b/train_neuralproofnet.py @@ -6,8 +6,8 @@ torch.cuda.empty_cache() # region data -file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv' -df_axiom_links = read_links_csv(file_path_axiom_links) +file_path_axiom_links = 'Datasets/gold_dataset_links.csv' +df_axiom_links = read_links_csv(file_path_axiom_links)[:32] # endregion @@ -16,7 +16,7 @@ print("#" * 20) print("#" * 20) model_tagger = "models/flaubert_super_98_V2_50e.pt" neural_proof_net = NeuralProofNet(model_tagger) -neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=20, batch_size=16, +neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0, epochs=5, pretrain_linker_epochs=5, batch_size=16, checkpoint=True, tensorboard=True) print("#" * 20) print("#" * 20) -- GitLab
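For context, a minimal usage sketch of the API after this patch, mirroring predict_links.py and train_neuralproofnet.py. The import path and checkpoint file names follow the repository layout and scripts and are assumptions, not part of the patch:

# Sketch only: the import line and the model/linker paths are assumed.
from NeuralProofNet.NeuralProofNet import NeuralProofNet

# The SuperTagger is now owned by NeuralProofNet; Linker() takes no tagger path
# and receives the tagger explicitly in train_epoch / eval_epoch.
neuralproofnet = NeuralProofNet("models/flaubert_super_98_V2_50e.pt")
neuralproofnet.linker.load_weights("Output/saved_linker.pt")

# Prediction entry points moved from Linker to NeuralProofNet.
sentence = ["( 1 ) parmi les huit \" partants \" acquis ou potentiels , MM. Lacombe , Koehler et Laroze ne sont pas membres du PCF ."]
categories, links = neuralproofnet.predict_without_categories(sentence)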