import os
import time
from datetime import datetime

import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, random_split
from transformers import get_cosine_schedule_with_warmup

from Configuration import Configuration
from Linker.AtomTokenizer import AtomTokenizer
from Linker.Linker import Linker
from Linker.atom_map import atom_map
from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
from Linker.eval import SinkhornLoss
from utils import format_time, read_csv_pgbar
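
# Release unused cached GPU memory held by PyTorch's caching allocator before training starts.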

torch.cuda.empty_cache()

# region ParamsModel

max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])

# endregion ParamsModel

# region ParamsTraining

batch_size = int(Configuration.modelTrainingConfig['batch_size'])
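# Only a slice of the corpus is used: nb_sentences caps how many rows read_csv_pgbar loads below
# (assuming its second argument is a row limit).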
nb_sentences = batch_size * 10
epochs = int(Configuration.modelTrainingConfig['epoch'])
seed_val = int(Configuration.modelTrainingConfig['seed_val'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])

# endregion ParamsTraining

# region Data loader

file_path_axiom_links = 'Datasets/aa1_links_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)

sentences_batch = df_axiom_links["Sentences"]

atoms_batch = get_atoms_batch(df_axiom_links["sub_tree"])
atom_tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence)
atoms_batch_tokenized = atom_tokenizer.convert_batchs_to_ids(atoms_batch)
print("atoms_tokens", atoms_batch_tokenized.shape)

atoms_polarity_batch = find_pos_neg_idexes(max_atoms_in_sentence, df_axiom_links["sub_tree"])
print("atoms_polarity_batch", atoms_polarity_batch.shape)

torch.set_printoptions(edgeitems=20)
truth_links_batch = get_axiom_links(max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"])
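# Put the batch dimension first (assuming get_axiom_links returns its output with the batch on dim 1).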
truth_links_batch = truth_links_batch.permute(1, 0, 2)
print("truth_links_batch", truth_links_batch.shape)
print("sentence", sentences_batch[14])
print("categories ", df_axiom_links["sub_tree"][14])
print("atoms_batch", atoms_batch[14])
print("atoms_polarity_batch", atoms_polarity_batch[14])
print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])

# Build the tensor dataset
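# (TensorDataset requires every tensor to have the same size along the first, per-sample dimension.)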
dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)

# Calculate the number of samples to include in each set.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Divide the dataset by randomly selecting samples.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
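
# Note: shuffle=False keeps batch composition identical across epochs; the DataLoader can
# reshuffle the training set each epoch with shuffle=True.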

training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# endregion Data loader


# region Models

# supertagger = SuperTagger()
# supertagger.load_weights("models/model_check.pt")
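
# The supertagger that would supply BERT sentence embeddings is disabled for now;
# the linker is trained on the atom/polarity inputs alone (see the empty list passed to linker() below).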

linker = Linker()

# endregion Models


# region Fit tuning

# Optimizer
optimizer_linker = AdamW(linker.parameters(),
                         weight_decay=1e-5,
                         lr=learning_rate)

# Create the learning rate scheduler.
scheduler_linker = get_cosine_schedule_with_warmup(optimizer_linker,
                                                   num_warmup_steps=0,
                                                   num_training_steps=100)
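# NB: num_training_steps is fixed at 100 here; a common choice is len(training_dataloader) * epochs
# so that the cosine decay spans the whole run.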

# Loss
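# SinkhornLoss (from Linker.eval) scores the predicted links against the gold axiom links.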
cross_entropy_loss = SinkhornLoss()

np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
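# set_detect_anomaly(True) helps pinpoint NaN/Inf gradients but noticeably slows training;
# it is usually disabled outside debugging.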
torch.autograd.set_detect_anomaly(True)

# endregion Fit tuning

# region Train

# Measure the total training time for the whole run.
total_t0 = time.time()
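
# Toggles for end-of-epoch validation and checkpoint saving.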

validate = True
checkpoint = True


def run_epochs(epochs):
    # For each epoch...
    for epoch_i in range(0, epochs):
        # ========================================
        #               Training
        # ========================================

        # Perform one full pass over the training set.

        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        print('Training...')

        # Measure how long the training epoch takes.
        t0 = time.time()

        # Reset the total loss for this epoch.
        total_train_loss = 0

        linker.train()

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # For each batch of training data...
        for step, batch in enumerate(training_dataloader):
            # Unpack this training batch from our dataloader.
            batch_atoms = batch[0].to(device)
            batch_polarity = batch[1].to(device)
            batch_true_links = batch[2].to(device)
            # batch_sentences = batch[3].to(device)

            optimizer_linker.zero_grad()

            # get sentence embedding from BERT which is already trained
            # sentences_embedding = supertagger(batch_sentences)

            # Run the linker on the category predictions
            # (the empty list stands in for the sentence embeddings the disabled supertagger would provide)
            logits_predictions = linker(batch_atoms, batch_polarity, [])

            # Permute the logits so the batch dimension comes first, lining up with batch_true_links.
            linker_loss = cross_entropy_loss(logits_predictions.permute(1, 0, 2, 3), batch_true_links)
            total_train_loss += float(linker_loss)

            # Perform a backward pass to calculate the gradients.
            linker_loss.backward()

            # This is to help prevent the "exploding gradients" problem.
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

            # Update parameters and take a step using the computed gradient.
            optimizer_linker.step()
            scheduler_linker.step()

        avg_train_loss = total_train_loss / len(training_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        if checkpoint:
            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
            os.makedirs(checkpoint_dir, exist_ok=True)
            linker.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))

        if validate:
            linker.eval()
            with torch.no_grad():
                print("Start eval")
                accuracy, loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss)
                print("")
                print("  Average accuracy on epoch: {0:.2f}".format(accuracy))
                print("  Average loss on epoch: {0:.2f}".format(loss))

        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epcoh took: {:}".format(training_time))


run_epochs(epochs)
# endregion Train