import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoTokenizer, CamembertModel, get_cosine_schedule_with_warmup

from Configuration import Configuration
from SuperTagger.Encoder.EncoderInput import EncoderInput
from SuperTagger.EncoderDecoder import EncoderDecoder
from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
from SuperTagger.Symbol.symbol_map import symbol_map
from SuperTagger.eval import NormCrossEntropy
from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load

from torch.utils.tensorboard import SummaryWriter

# Hugging Face tokenizers read their parallelism setting from this environment variable.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.cuda.empty_cache()

# region ParamsModel

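# Decoder hyperparameters, read from the project's configuration file.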
max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])

# endregion ParamsModel

# region ParamsTraining

file_path = 'Datasets/m2_dataset.csv'
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
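# Cap the dataset at 50 batches' worth of sentences.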
nb_sentences = batch_size * 50
epochs = int(Configuration.modelTrainingConfig['epoch'])
seed_val = int(Configuration.modelTrainingConfig['seed_val'])
learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
loss_scaled_by_freq = True

# endregion ParamsTraining

# region OutputTraining

output_path = str(Configuration.modelTrainingConfig['output_path'])

training_dir = os.path.join(output_path, 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
logs_dir = os.path.join(training_dir, 'logs')

checkpoint_dir = training_dir
writer = SummaryWriter(log_dir=logs_dir)

use_checkpoint_SAVE = Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE')

# endregion OutputTraining

# region InputTraining

input_path = str(Configuration.modelTrainingConfig['input_path'])
model_to_load = str(Configuration.modelTrainingConfig['model_to_load'])
model_to_load_path = os.path.join(input_path, model_to_load)
use_checkpoint_LOAD = Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD')

# endregion InputTraining

# region Print config

print("##" * 15 + "\nConfiguration : \n")

print("ParamsModel\n")

print("\tmax_symbols_in_sentence :", max_symbols_in_sentence)
print("\tsymbol_vocab_size :", symbol_vocab_size)
print("\tbidirectional : ", False)
print("\tnum_gru_layers : ", num_gru_layers)

print("\n ParamsTraining\n")

print("\tDataset :", file_path)
print("\tb_sentences :", nb_sentences)
print("\tbatch_size :", batch_size)
print("\tepochs :", epochs)
print("\tseed_val :", seed_val)

print("\n Output\n")
print("\tuse checkpoint save :", use_checkpoint_SAVE)
print("\tcheckpoint_dir :", checkpoint_dir)
print("\tlogs_dir :", logs_dir)

print("\n Input\n")
print("\tModel to load :", model_to_load_path)
print("\tLoad checkpoint :", use_checkpoint_LOAD)

print("\nLoss and optimizer : ")

print("\tlearning_rate :", learning_rate)
print("\twith loss scaled by freq :", loss_scaled_by_freq)

print("\n Device\n")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\t", device)

print()
print("##" * 15)

# endregion Print config

# region Model

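# Assemble the model: a pretrained CamemBERT encoder paired with an RNN-based
# symbol decoder (cf. num_gru_layers above) inside the EncoderDecoder wrapper.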
BASE_TOKENIZER = AutoTokenizer.from_pretrained(
    'camembert-base',
    do_lower_case=True)
BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
symbols_tokenizer = SymbolTokenizer(symbol_map, max_symbols_in_sentence, max_len_sentence)
sents_tokenizer = EncoderInput(BASE_TOKENIZER)
model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map)
model = model.to(device)

# endregion Model

# region Data loader
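# Load up to nb_sentences rows; the CSV is expected to provide a 'Sentences'
# column (raw text) and a 'sub_tree' column (target symbol sequences).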
df = read_csv_pgbar(file_path, nb_sentences)

symbols_tokenized = symbols_tokenizer.convert_batchs_to_ids(df['sub_tree'])
sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentences'].tolist())

dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized)

# Calculate the number of samples to include in each set.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Divide the dataset by randomly selecting samples.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))

training_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# endregion Data loader

# region Fit tuning

# Optimizer
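# Two separate SGD optimizers: a small fixed learning rate for the pretrained
# encoder, and the configured learning rate for the decoder.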
optimizer_encoder = SGD(model.encoder.parameters(),
                        lr=5e-5)
optimizer_decoder = SGD(model.decoder.parameters(),
                        lr=learning_rate)

# Total number of training steps is [number of batches] x [number of epochs].
# (Note that this is not the same as the number of training samples).
total_steps = len(training_dataloader) * epochs

# Create the learning rate scheduler.
scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder,
                                                    num_warmup_steps=0,
                                                    num_training_steps=5)
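# Note: the encoder schedule covers only 5 steps, whereas the decoder schedule
# below spans the full training run.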
scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder,
                                                    num_warmup_steps=0,
                                                    num_training_steps=total_steps)

# Loss
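# Per-symbol weights compensate for class frequency (rare symbols weigh more);
# the trailing 0 presumably corresponds to the padding symbol.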
if loss_scaled_by_freq:
    weights = torch.as_tensor(
        [6.9952, 1.0763, 1.0317, 43.274, 16.5276, 11.8821, 28.2416, 2.7548, 1.0728, 3.1847, 8.4521, 6.77, 11.1887,
         6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0],
        device=device)
    cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id,
                                          weights=weights)
else:
    cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id)

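# Seed every RNG for reproducibility. Anomaly detection helps track down NaN
# gradients but noticeably slows training.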
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
torch.autograd.set_detect_anomaly(True)

# endregion Fit tuning

# region Train

# Measure the total training time for the whole run.
total_t0 = time.time()

validate = True

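# Optionally resume from a saved checkpoint, skipping the epochs already done.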
if use_checkpoint_LOAD:
    model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path)
    epochs = epochs - last_epoch


def run_epochs(epochs):
    # For each epoch...
    for epoch_i in range(epochs):
        # ========================================
        #               Training
        # ========================================

        # Perform one full pass over the training set.

        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        print('Training...')

        # Measure how long the training epoch takes.
        t0 = time.time()

        # Reset the total loss for this epoch.
        total_train_loss = 0

        model.train()

        # For each batch of training data...
        for step, batch in enumerate(training_dataloader):

            # if epoch_i == 0 and step == 0:
            #     writer.add_graph(model, input_to_model=batch[0], verbose=False)

            # Progress update every 40 batches.
            if step % 40 == 0 and step != 0:
                # Calculate elapsed time in minutes.
                elapsed = format_time(time.time() - t0)
                # Report progress.
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(training_dataloader), elapsed))

            # Unpack this training batch from our dataloader.
            b_sents_tokenized = batch[0].to(device)
            b_sents_mask = batch[1].to(device)
            b_symbols_tokenized = batch[2].to(device)

            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()

            logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized)

            # Decode the first sequence of the batch back to symbol strings for logging.
            inv_symbol_map = {v: k for k, v in symbol_map.items()}
            predict_trad = [inv_symbol_map[int(i)] for i in
                            torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]]
            true_trad = [inv_symbol_map[int(i)] for i in b_symbols_tokenized[0]]
            true_len = len([i for i in true_trad if i != '[PAD]'])
            if step % 40 == 0 and step != 0:
                writer.add_text("Sample", "\ntrain true (" + str(true_len) + ") : " + str(
                    [token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str(
                    len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str(
                    [token for token in predict_trad[:true_len] if token != '[PAD]']))

            loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized)
            total_train_loss += float(loss)
            # Perform a backward pass to calculate the gradients.
            loss.backward()

            # This is to help prevent the "exploding gradients" problem.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)

            # Update parameters and take a step using the computed gradient.
            optimizer_encoder.step()
            optimizer_decoder.step()

            scheduler_encoder.step()
            scheduler_decoder.step()

        # checkpoint

        if use_checkpoint_SAVE:
            checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss)

        avg_train_loss = total_train_loss / len(training_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        if validate:
            model.eval()
            with torch.no_grad():
                print("Start eval")
                accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss)
                print("")
                print("  Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents))
                print("  Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol))
                writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1)
                writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1)

        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epcoh took: {:}".format(training_time))

        # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1)

        # v_loss only exists when validation ran this epoch.
        if validate:
            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_train_loss, 'Validation': v_loss},
                               epoch_i + 1)
        writer.flush()


run_epochs(epochs)
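# Flush any pending events and release the TensorBoard writer.
writer.close()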
# endregion Train