import time

import torch
from torch.nn import Module
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, random_split
from tqdm import tqdm

from Configuration import Configuration
from Linker import Linker
from Linker.eval import measure_accuracy, SinkhornLoss
from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx
from NeuralProofNet.utils_proofnet import get_info_for_tagger
from utils import pad_sequence, format_time, output_create_dir


class NeuralProofNet(Module):
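    r"""End-to-end proof-net construction model.

    Wraps a pretrained supertagger and a Linker: the supertagger predicts a category
    for every word, and the Linker predicts the axiom links between the resulting
    atoms. Only the Linker parameters are optimized by this module.
    """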
    
    def __init__(self, supertagger_path_model, linker_path_model=None):
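        r"""
        Args:
            supertagger_path_model : path to the pretrained supertagger weights
            linker_path_model : optional path to pretrained Linker weights to load
        """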
        super(NeuralProofNet, self).__init__()
        config = Configuration.read_config()
        datasetConfig = config["DATASET_PARAMS"]

        # settings
        self.max_len_sentence = int(datasetConfig['max_len_sentence'])
        self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence'])
        self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        linker = Linker(supertagger_path_model)
        if linker_path_model is not None:
            linker.load_weights(linker_path_model)
        self.linker = linker

        # Learning
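        # Only the Linker parameters are optimized here; the loss is a Sinkhorn-based
        # matching loss over the predicted axiom links.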
        self.linker_loss = SinkhornLoss()
        self.linker_optimizer = AdamW(self.linker.parameters(),
                                      lr=0.001)
        self.linker_scheduler = StepLR(self.linker_optimizer, step_size=2, gamma=0.5)

        self.to(self.device)

    def __pretrain_linker__(self, df_axiom_links, pretrain_linker_epochs, batch_size, checkpoint=False, tensorboard=True):
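        r"""Pre-train the Linker on ``df_axiom_links`` before training it on the supertagger's own predictions."""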
        print("\nLinker Pre-Training\n")
        self.linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=pretrain_linker_epochs,
                                 batch_size=batch_size, checkpoint=checkpoint, tensorboard=tensorboard)
        print("\nEND Linker Pre-Training\n")

    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
        r"""
        Args:
            batch_size : int
            df_axiom_links pandas DataFrame
            validation_rate
        Returns:
            the training dataloader and the validation dataloader. They contain the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
        """
        print("Start preprocess Data")
        sentences_batch = df_axiom_links["X"].str.strip().tolist()
        sentences_tokens, sentences_mask = self.linker.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

        _, polarities, _ = get_GOAL(self.max_len_sentence, df_axiom_links)
        atoms_polarity_batch = pad_sequence(
            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
            max_len=self.max_atoms_in_sentence, padding_value=0)

        # ground-truth axiom links, permuted so that the batch dimension comes first
        # (TensorDataset below expects batch-first tensors)
        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                            df_axiom_links["Y"])
        truth_links_batch = truth_links_batch.permute(1, 0, 2)

        # Construction tensor dataset
        dataset = TensorDataset(truth_links_batch, sentences_tokens, sentences_mask)

        if validation_rate > 0.0:
            train_size = int((1 - validation_rate) * len(dataset))
            val_size = len(dataset) - train_size
            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            validation_dataloader = None
            train_dataset = dataset

        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        print("End preprocess Data")
        return training_dataloader, validation_dataloader

    def forward(self, batch_sentences_tokens, batch_sentences_mask):
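        r"""Predict axiom links from tokenized sentences.

        The supertagger first predicts a category for every word; the atoms and polarities
        extracted from these predictions are then fed to the Linker.

        Args:
            batch_sentences_tokens : tensor of sentence token ids
            batch_sentences_mask : attention mask for the tokens
        Returns:
            log-probabilities of the axiom links predicted by the Linker
        """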

        # run the (already trained) supertagger to get category logits and word embeddings
        output = self.linker.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
        category_logits = output['logit']
        # softmax is monotonic, so taking the argmax of the logits directly gives the predicted categories
        pred_categories = torch.argmax(category_logits, dim=2)
        pred_categories = self.linker.Supertagger.tags_tokenizer.convert_ids_to_tags(pred_categories)

        # get information from tagger predictions
        atoms_batch, polarities, batch_num_atoms_per_word = get_info_for_tagger(self.max_len_sentence, pred_categories)
        atoms_polarity_batch = pad_sequence(
            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
            max_len=self.max_atoms_in_sentence, padding_value=0)
        atoms_batch_tokenized = self.linker.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
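        # for each atom type, collect the indices of its positive and negative occurrences;
        # the Linker predicts a matching between these two sets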
        batch_pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
        batch_neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)

        batch_num_atoms_per_word = batch_num_atoms_per_word.to(self.device)
        atoms_batch_tokenized = atoms_batch_tokenized.to(self.device)
        batch_pos_idx = batch_pos_idx.to(self.device)
        batch_neg_idx = batch_neg_idx.to(self.device)

        logits_links = self.linker(batch_num_atoms_per_word, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx,
                                   output['word_embedding'])

        return torch.log_softmax(logits_links, dim=3)

    def train_neuralproofnet(self, df_axiom_links, validation_rate=0.1, epochs=20, pretrain_linker_epochs=0, 
                             batch_size=32, checkpoint=True, tensorboard=False):
        r"""
        Args:
            df_axiom_links : pandas dataFrame containing the atoms anoted with _i
            validation_rate : float
            epochs : int
            batch_size : int
            checkpoint : boolean
            tensorboard : boolean
        Returns:
            Final accuracy and final loss
        """
        # Pretrain the linker (skipped when pretrain_linker_epochs is 0)
        if pretrain_linker_epochs > 0:
            self.__pretrain_linker__(df_axiom_links, pretrain_linker_epochs, batch_size)

        # Start learning with output from tagger
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                            validation_rate)
        if checkpoint or tensorboard:
            checkpoint_dir, writer = output_create_dir()

        for epoch_i in range(epochs):
            print("")
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            print('Training...')
            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)

            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')

            if validation_rate > 0.0:
                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
                self.__checkpoint_save(path='Output/linker.pt')

            if tensorboard:
                writer.add_scalars('Accuracy', {'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars('Loss', {'Train': avg_train_loss}, epoch_i)
                if validation_rate > 0.0:
                    writer.add_scalars('Accuracy', {'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars('Loss', {'Validation': loss_test}, epoch_i)

            print('\n')

    def train_epoch(self, training_dataloader):
        r""" Train epoch

        Args:
            training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
        Returns:
             accuracy on validation set
             loss on train set
        """
        self.train()

        # Reset the total loss for this epoch.
        epoch_loss = 0
        accuracy_train = 0
        t0 = time.time()

        # For each batch of training data...
        with tqdm(training_dataloader, unit="batch") as tepoch:
            for batch in tepoch:
                # Unpack this training batch from our dataloader
                batch_true_links = batch[0].to(self.device)
                batch_sentences_tokens = batch[1].to(self.device)
                batch_sentences_mask = batch[2].to(self.device)

                self.linker_optimizer.zero_grad()

                # Run the Linker on the atoms
                logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)

                linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)
                # Perform a backward pass to calculate the gradients.
                epoch_loss += float(linker_loss)
                linker_loss.backward()

                # This is to help prevent the "exploding gradients" problem.
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

                # Update parameters and take a step using the computed gradient.
                self.linker_optimizer.step()

                pred_axiom_links = torch.argmax(logits_predictions_links, dim=3)
                accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links)

        self.linker_scheduler.step()

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        avg_train_loss = epoch_loss / len(training_dataloader)
        avg_accuracy_train = accuracy_train / len(training_dataloader)

        return avg_train_loss, avg_accuracy_train, training_time

    def eval_batch(self, batch):
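        r"""Evaluate a single batch and return its Sinkhorn loss and axiom-link accuracy."""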
        batch_true_links = batch[0].to(self.device)
        batch_sentences_tokens = batch[1].to(self.device)
        batch_sentences_mask = batch[2].to(self.device)

        logits_predictions_links = self(batch_sentences_tokens, batch_sentences_mask)
        axiom_links_pred = torch.argmax(logits_predictions_links,
                                        dim=3)  # atom_vocab, batch_size, max atoms in one type

        print('\n')
        print("Les vrais liens de la catégorie n : ", batch_true_links[0][2][:100])
        print("Les prédictions : ", axiom_links_pred[2][0][:100])
        print('\n')

        accuracy = measure_accuracy(batch_true_links, axiom_links_pred)
        linker_loss = self.linker_loss(logits_predictions_links, batch_true_links)

        return linker_loss, accuracy

    def eval_epoch(self, dataloader):
        r"""Average the evaluation of all the batch.

        Args:
            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
        """
        self.eval()
        accuracy_average = 0
        loss_average = 0
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                loss, accuracy = self.eval_batch(batch)
                accuracy_average += accuracy
                loss_average += float(loss)

        return loss_average / len(dataloader), accuracy_average / len(dataloader)

    def __checkpoint_save(self, path='/linker.pt'):
        """
        @param path:
        """
        self.cpu()

        torch.save({
            'atom_encoder': self.linker.atom_encoder.state_dict(),
            'position_encoder': self.linker.position_encoder.state_dict(),
            'transformer': self.linker.transformer.state_dict(),
            'linker_encoder': self.linker.linker_encoder.state_dict(),
            'pos_transformation': self.linker.pos_transformation.state_dict(),
            'neg_transformation': self.linker.neg_transformation.state_dict(),
            'cross_entropy_loss': self.linker_loss.state_dict(),
            'optimizer': self.linker_optimizer,
        }, path)
        self.to(self.device)
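

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline: the file paths and model
    # locations below are assumptions and must be adapted to the actual data and checkpoints.
    import pandas as pd

    # The DataFrame is expected to provide the sentences in column "X" and the annotated
    # axiom links in column "Y", as consumed by train_neuralproofnet above.
    df_axiom_links = pd.read_csv("Datasets/axiom_links.csv")

    proofnet = NeuralProofNet(supertagger_path_model="models/supertagger.pt")
    proofnet.train_neuralproofnet(df_axiom_links,
                                  validation_rate=0.1,
                                  epochs=20,
                                  pretrain_linker_epochs=5,
                                  batch_size=32,
                                  checkpoint=True,
                                  tensorboard=False)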