import os
import sys
import datetime

import time

import torch
import torch.nn.functional as F
from torch.nn import Module, Sequential, LayerNorm, Dropout
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from Configuration import Configuration
from Linker.AtomEmbedding import AtomEmbedding
from Linker.AtomTokenizer import AtomTokenizer
from Linker.MHA import AttentionDecoderLayer
from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from Linker.atom_map import atom_map
from Linker.eval import mesure_accuracy, SinkhornLoss
from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN, get_axiom_links, get_pos_encoding_for_s_idx, \
    get_neg_encoding_for_s_idx
from Supertagger import *
from utils import pad_sequence


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round(elapsed))

    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))


def output_create_dir():
    """
    Create le output dir for tensorboard and checkpoint
    @return: output dir, tensorboard writter
    """
    from datetime import datetime
    outpout_path = 'TensorBoard'
    training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
    logs_dir = os.path.join(training_dir, 'logs')
    writer = SummaryWriter(log_dir=logs_dir)
    return training_dir, writer


class Linker(Module):
    def __init__(self, supertagger_path_model):
        super(Linker, self).__init__()

        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
        self.nhead = int(Configuration.modelDecoderConfig['nhead'])
        dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
        dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
        atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
        learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
        self.dropout = Dropout(0.1)
        self.device = "cpu"

        supertagger = SuperTagger()
        supertagger.load_weights(supertagger_path_model)
        self.Supertagger = supertagger

        self.atom_map = atom_map
        self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
        self.padding_id = self.atom_map['[PAD]']
        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
        self.inverse_map = self.atoms_tokenizer.inverse_atom_map
        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)

        self.linker_encoder = AttentionDecoderLayer()

        self.pos_transformation = Sequential(
            FFN(self.dim_embedding_atoms, dim_polarity_transfo, 0.1, d_out=dim_pre_sinkhorn_transfo),
            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
        )
        self.neg_transformation = Sequential(
            FFN(self.dim_embedding_atoms, dim_polarity_transfo, 0.1, d_out=dim_pre_sinkhorn_transfo),
            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-12)
        )

        self.cross_entropy_loss = SinkhornLoss()
        self.optimizer = AdamW(self.parameters(),
                               lr=learning_rate)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.to(self.device)

    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
        r"""
        Args:
            batch_size : int
            df_axiom_links pandas DataFrame
            validation_rate
        Returns:
            the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask
        """
        print("Start preprocess Data")
        sentences_batch = df_axiom_links["X"].tolist()
        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

        atoms_batch = get_atoms_batch(df_axiom_links["Z"])
        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

        atoms_polarity_batch = find_pos_neg_idexes(self.max_atoms_in_sentence, df_axiom_links["Z"])

        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                            df_axiom_links["Y"])
        truth_links_batch = truth_links_batch.permute(1, 0, 2)

        # Construction tensor dataset
        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
                                sentences_mask)

        if validation_rate > 0.0:
            train_size = int((1 - validation_rate) * len(dataset))
            val_size = len(dataset) - train_size
            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            validation_dataloader = None
            train_dataset = dataset

        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
        print("End preprocess Data")
        return training_dataloader, validation_dataloader

    def make_decoder_mask(self, atoms_token):
        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64, device=self.device)
        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
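        # expand to shape (batch_size * nhead, max_atoms, max_atoms) as expected by multi-head attention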
        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)

    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
        r"""
        Args:
            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
            sents_mask : mask from BERT tokenizer
        Returns:
            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
        """

        # atoms embedding
        atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)

        # MHA (or LSTM) over the BERT output
        sents_mask = sents_mask.unsqueeze(1).repeat(self.nhead, self.max_atoms_in_sentence, 1).to(torch.float64)
        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                             self.make_decoder_mask(atoms_batch_tokenized))

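        # For each atom type, gather the encodings of the positive and negative occurrences,
        # project them with the dedicated FFNs, and run Sinkhorn on their pairwise scores
        # to obtain a soft matching from positive to negative atoms.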
        link_weights = []
        for atom_type in self.sub_atoms_type_list:
            pos_encoding = pad_sequence(
                [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
                max_len=self.max_atoms_in_one_type // 2)

            neg_encoding = pad_sequence(
                [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_batch_tokenized,
                                            atoms_polarity_batch, atom_type, self.inverse_map, s_idx)
                 for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
                max_len=self.max_atoms_in_one_type // 2)

            pos_encoding = self.pos_transformation(pos_encoding)
            neg_encoding = self.neg_transformation(neg_encoding)

            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))

        total_link_weights = torch.stack(link_weights)
        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)

        return F.log_softmax(link_weights_per_batch, dim=3)

    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                     batch_size=32, checkpoint=True, tensorboard=False):
        r"""
        Args:
            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
            validation_rate : float
            epochs : int
            batch_size : int
            checkpoint : boolean
            tensorboard : boolean
        Returns:
            Final accuracy and final loss
        """
        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                            validation_rate)
        if checkpoint or tensorboard:
            checkpoint_dir, writer = output_create_dir()

        for epoch_i in range(epochs):
            print("")
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            print('Training...')
            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)

            print("")
            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
            if validation_rate > 0.0:
                loss_test, accuracy_test = self.eval_epoch(validation_dataloader, self.cross_entropy_loss)
                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')

            if checkpoint:
                self.__checkpoint_save(
                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))

            if tensorboard:
                writer.add_scalars(f'Accuracy', {
                    'Train': avg_accuracy_train}, epoch_i)
                writer.add_scalars(f'Loss', {
                    'Train': avg_train_loss}, epoch_i)
                if validation_rate > 0.0:
                    writer.add_scalars(f'Accuracy', {
                        'Validation': accuracy_test}, epoch_i)
                    writer.add_scalars(f'Loss', {
                        'Validation': loss_test}, epoch_i)

            print('\n')

    def train_epoch(self, training_dataloader):
        r""" Train epoch

        Args:
            training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
            validation_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
        Returns:
             accuracy on validation set
             loss on train set
        """
        self.train()

        # Reset the total loss for this epoch.
        epoch_loss = 0
        accuracy_train = 0
        t0 = time.time()

        # For each batch of training data...
        with tqdm(training_dataloader, unit="batch") as tepoch:
            for batch in tepoch:
                # Unpack this training batch from our dataloader
                batch_atoms = batch[0].to(self.device)
                batch_polarity = batch[1].to(self.device)
                batch_true_links = batch[2].to(self.device)
                batch_sentences_tokens = batch[3].to(self.device)
                batch_sentences_mask = batch[4].to(self.device)

                self.optimizer.zero_grad()

                # get sentence embedding from BERT which is already trained
                logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)

                # Run the linker on the predicted categories
                logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)

                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                # Perform a backward pass to calculate the gradients.
                epoch_loss += float(linker_loss)
                linker_loss.backward()

                # This is to help prevent the "exploding gradients" problem.
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

                # Update parameters and take a step using the computed gradient.
                self.optimizer.step()

                pred_axiom_links = torch.argmax(logits_predictions, dim=3)
                accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        avg_train_loss = epoch_loss / len(training_dataloader)
        avg_accuracy_train = accuracy_train / len(training_dataloader)

        return avg_train_loss, avg_accuracy_train, training_time

    def predict(self, categories, sents_embedding, sents_mask=None):
        r"""Prediction from categories output by BERT and hidden_state from BERT

        Args:
            categories : (batch_size, len_sentence)
            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
            sents_mask
        Returns:
            axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
        """
        self.eval()
        with torch.no_grad():
            # get atoms
            atoms_batch = get_atoms_batch(categories)
            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

            # get polarities
            polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)

            # atoms embedding
            atoms_embedding = self.atoms_embedding(atoms_tokenized)

            # MHA (or LSTM) over the BERT output
            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
                                                 self.make_decoder_mask(atoms_tokenized))

            link_weights = []
            for atom_type in self.sub_atoms_type_list:
                pos_encoding = pad_sequence(
                    [get_pos_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
                                                polarities, atom_type, self.inverse_map, s_idx)
                     for s_idx in range(len(polarities))], padding_value=0,
                    max_len=self.max_atoms_in_one_type // 2)

                neg_encoding = pad_sequence(
                    [get_neg_encoding_for_s_idx(self.dim_embedding_atoms, atoms_encoding, atoms_tokenized,
                                                polarities, atom_type, self.inverse_map, s_idx)
                     for s_idx in range(len(polarities))], padding_value=0,
                    max_len=self.max_atoms_in_one_type // 2)

                pos_encoding = self.pos_transformation(pos_encoding)
                neg_encoding = self.neg_transformation(neg_encoding)

                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
                link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))

            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
            return axiom_links

    def eval_batch(self, batch, cross_entropy_loss):
        batch_atoms = batch[0].to(self.device)
        batch_polarity = batch[1].to(self.device)
        batch_true_links = batch[2].to(self.device)
        batch_sentences_tokens = batch[3].to(self.device)
        batch_sentences_mask = batch[4].to(self.device)

        logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
        logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
                                       batch_sentences_mask)
        axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)

        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)

        return loss, accuracy

    def eval_epoch(self, dataloader, cross_entropy_loss):
        r"""Average the evaluation of all the batch.

        Args:
            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
        """
        self.eval()
        accuracy_average = 0
        loss_average = 0
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                loss, accuracy = self.eval_batch(batch, cross_entropy_loss)
                accuracy_average += accuracy
                loss_average += float(loss)

        return loss_average / len(dataloader), accuracy_average / len(dataloader)

    def load_weights(self, model_file):
        print("#" * 15)
        try:
            params = torch.load(model_file, map_location=self.device)
            args = params['args']
            self.atom_map = args['atom_map']
            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
            self.linker_encoder.load_state_dict(params['linker_encoder'])
            self.pos_transformation.load_state_dict(params['pos_transformation'])
            self.neg_transformation.load_state_dict(params['neg_transformation'])
            self.optimizer.load_state_dict(params['optimizer'])
            print("\n The loading checkpoint was successful ! \n")
        except Exception as e:
            print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
            raise e
        print("#" * 15)

    def __checkpoint_save(self, path='/linker.pt'):
        """
        @param path:
        """
        self.cpu()

        torch.save({
            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
            'atoms_embedding': self.atoms_embedding.state_dict(),
            'linker_encoder': self.linker_encoder.state_dict(),
            'pos_transformation': self.pos_transformation.state_dict(),
            'neg_transformation': self.neg_transformation.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, path)
        self.to(self.device)
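

# Minimal usage sketch (not part of the original module). The file paths below are
# hypothetical, and the DataFrame columns ("X" sentences, "Y" axiom links, "Z" categories)
# follow the docstrings above; adapt both to the actual dataset and checkpoints.
if __name__ == "__main__":
    import pandas as pd  # assumed available, since the training API expects a pandas DataFrame

    supertagger_model_path = "models/supertagger.pt"   # hypothetical path
    axiom_links_path = "Datasets/axiom_links.csv"      # hypothetical path

    # The exact serialization of the columns depends on the dataset; a plain read_csv
    # is only a placeholder for whatever loader produces the expected X / Y / Z columns.
    df_axiom_links = pd.read_csv(axiom_links_path)

    linker = Linker(supertagger_model_path)
    linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=20,
                        batch_size=32, checkpoint=True, tensorboard=True)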