import time

import torch
from torch.nn import Module
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, random_split
from tqdm import tqdm

from Configuration import Configuration
from Linker import Linker
from Linker.eval import measure_accuracy, SinkhornLoss
from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
from NeuralProofNet.utils_proofnet import get_info_for_tagger
from utils import pad_sequence, format_time, output_create_dir


class NeuralProofNet(Module):

    def __init__(self, supertagger_path_model, linker_path_model=None):
        super(NeuralProofNet, self).__init__()
        config = Configuration.read_config()
        datasetConfig = config["DATASET_PARAMS"]

        # Settings read from the dataset configuration
        self.max_len_sentence = int(datasetConfig['max_len_sentence'])
        self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence'])
        self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Linker built on top of the (already trained) supertagger; optionally warm-started
        linker = Linker(supertagger_path_model)
        if linker_path_model is not None:
            linker.load_weights(linker_path_model)
        self.linker = linker

        # Learning
        self.linker_loss = SinkhornLoss()
        self.linker_optimizer = AdamW(self.linker.parameters(), lr=0.001)
        self.linker_scheduler = StepLR(self.linker_optimizer, step_size=2, gamma=0.5)

        self.to(self.device)

    def __pretrain_linker__(self, df_axiom_links, pretrain_linker_epochs, batch_size, checkpoint=False,
                            tensorboard=True):
        print("\nLinker Pre-Training\n")
        self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=pretrain_linker_epochs,
                                 batch_size=batch_size, checkpoint=checkpoint, tensorboard=tensorboard)
        print("\nEND Linker Pre-Training\n")

    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
        r"""
        Args:
            batch_size : int
            df_axiom_links : pandas DataFrame
            validation_rate : float
        Returns:
            the training dataloader and the validation dataloader. They contain the axiom links,
            the tokenized sentences and the sentence masks.
        """
        print("Start preprocess Data")
        sentences_batch = df_axiom_links["X"].str.strip().tolist()
        sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(
            sentences_batch)

        _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
        atoms_polarity_batch = pad_sequence(
            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
            max_len=self.max_atoms_in_sentence, padding_value=0)

        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["Y"])
        truth_links_batch = truth_links_batch.permute(1, 0, 2)

        # Build the tensor dataset
        dataset = TensorDataset(truth_links_batch, sentences_tokens, sentences_mask)

        if validation_rate > 0.0:
            # Split the dataset according to validation_rate
            train_size = int((1 - validation_rate) * len(dataset))
            val_size = len(dataset) - train_size
            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            validation_dataloader = None
            train_dataset = dataset

        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        print("End preprocess Data")
        return training_dataloader, validation_dataloader

    def forward(self, batch_sentences_tokens, batch_sentences_mask):
        # Get sentence embeddings and tag logits from the already trained BERT-based supertagger
        output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
        last_hidden_state = output['logit']
        pred_categories = torch.argmax(torch.softmax(last_hidden_state, dim=2), dim=2)
        pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)

        # Get atoms, polarities and atom counts from the tagger predictions
        atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence,
                                                                                pred_categories)
        atoms_polarity_batch = pad_sequence(
            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
            max_len=self.max_atoms_in_sentence, padding_value=0)
        atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
        batch_pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
        batch_neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)

        batch_num_atoms_per_word = batch_num_atoms_per_word.to(self.device)
        atoms_batch_tokenized = atoms_batch_tokenized.to(self.device)
        batch_pos_idx = batch_pos_idx.to(self.device)
        batch_neg_idx = batch_neg_idx.to(self.device)

        logits_links = self.linker(batch_num_atoms_per_word, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx,
                                   output['word_embedding'])

        return torch.log_softmax(logits_links, dim=3)

    def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0,
                             batch_size=32, checkpoint=True, tensorboard=False):
        r"""
        Pretrain the Linker on the gold annotations, then train it end-to-end on the supertagger output.
        Args:
            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
            validation_rate : float
            epochs : int
            pretrain_linker_epochs : int
            batch_size : int
            checkpoint : boolean
            tensorboard : boolean
        Returns:
            None; loss and accuracy are printed (and logged to TensorBoard if requested) after each epoch.
        """
        # Pretrain the linker
        self.__pretrain_linker__(df_axiom_links, pretrain_linker_epochs, batch_size)

        # Start learning with output from tagger
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                            validation_rate)
        if checkpoint or tensorboard:
            checkpoint_dir, writer = output_create_dir()

        for epoch_i in range(epochs):
            print("")
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            print('Training...')
            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)

            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')

            if validation_rate > 0.0:
                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
                self.__checkpoint_save(path='Output/linker.pt')

            if tensorboard:
                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
                if validation_rate > 0.0:
                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)

        print('\n')

    def train_epoch(self, training_dataloader):
        r"""
        Train for one epoch.
        Args:
            training_dataloader : torch DataLoader containing the axiom links, tokenized sentences and sentence masks
        Returns:
            average loss on the training set, average accuracy on the training set, and the epoch duration
        """
        self.train()

        # Reset the total loss for this epoch.
        epoch_loss = 0
        accuracy_train = 0
        t0 = time.time()

        # For each batch of training data...
        with tqdm(training_dataloader, unit="batch") as tepoch:
            for batch in tepoch:
                # Unpack this training batch from our dataloader
                batch_true_links = batch[0].to(self.device)
                batch_sentences_tokens = batch[1].to(self.device)
                batch_sentences_mask = batch[2].to(self.device)

                self.linker_optimizer.zero_grad()

                # Run the Linker on the atoms
                logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)

                linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
                epoch_loss += float(linker_loss)

                # Perform a backward pass to calculate the gradients.
                linker_loss.backward()

                # This is to help prevent the "exploding gradients" problem.
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

                # Update parameters and take a step using the computed gradient.
                self.linker_optimizer.step()

                pred_axiom_links = torch.argmax(logits_predictions_links, dim=3)
                accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links)

        self.linker_scheduler.step()

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        avg_train_loss = epoch_loss / len(training_dataloader)
        avg_accuracy_train = accuracy_train / len(training_dataloader)

        return avg_train_loss, avg_accuracy_train, training_time

    def eval_batch(self, batch):
        batch_true_links = batch[0].to(self.device)
        batch_sentences_tokens = batch[1].to(self.device)
        batch_sentences_mask = batch[2].to(self.device)

        logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
        axiom_links_pred = torch.argmax(logits_predictions_links, dim=3)  # atom_vocab, batch_size, max atoms in one type

        print('\n')
        print("True links for the atom type n : ", batch_true_links[0][2][:100])
        print("Predictions : ", axiom_links_pred[2][0][:100])
        print('\n')

        accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
        linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)

        return linker_loss, accuracy

    def eval_epoch(self, dataloader):
        r"""
        Average the evaluation over all the batches.
        Args:
            dataloader : contains all the batches with the tokenized sentences, their masks and the true axiom links
        """
        self.eval()
        accuracy_average = 0
        loss_average = 0
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                loss, accuracy = self.eval_batch(batch)
                accuracy_average += accuracy
                loss_average += float(loss)

        return loss_average / len(dataloader), accuracy_average / len(dataloader)

    def __checkpoint_save(self, path='/linker.pt'):
        """
        Save the Linker submodules, the loss and the optimizer.
        @param path: destination file for the checkpoint
        """
        self.cpu()
        torch.save({
            'atom_encoder': self.linker.atom_encoder.state_dict(),
            'position_encoder': self.linker.position_encoder.state_dict(),
            'transformer': self.linker.transformer.state_dict(),
            'linker_encoder': self.linker.linker_encoder.state_dict(),
            'pos_transformation': self.linker.pos_transformation.state_dict(),
            'neg_transformation': self.linker.neg_transformation.state_dict(),
            'cross_entropy_loss': self.linker_loss.state_dict(),
            'optimizer': self.linker_optimizer,
        }, path)
        self.to(self.device)
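

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline above).
# It shows how the class is meant to be driven: build a NeuralProofNet from a
# trained supertagger, then pretrain and fine-tune the Linker on the axiom-link
# DataFrame. The paths "models/supertagger.pt", "Datasets/axiom_links.csv" and
# the use of pandas.read_csv are placeholders/assumptions; the project may load
# its data differently. The DataFrame is assumed to expose the "X" (sentences)
# and "Y" (categories annotated with _i) columns used by train_neuralproofnet.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    df_axiom_links = pd.read_csv("Datasets/axiom_links.csv")  # placeholder path
    proofnet = NeuralProofNet("models/supertagger.pt")        # placeholder path

    proofnet.train_neuralproofnet(df_axiom_links,
                                  validation_rate=0.1,
                                  epochs=20,
                                  pretrain_linker_epochs=5,
                                  batch_size=32,
                                  checkpoint=True,
                                  tensorboard=False)

    # A checkpoint written to 'Output/linker.pt' during training can be reloaded
    # into a fresh NeuralProofNet through the linker_path_model argument.
    restored = NeuralProofNet("models/supertagger.pt", linker_path_model="Output/linker.pt")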