# Skip to content
# Snippets Groups Projects
# pytorch_classifier.py 7.22 KiB
#!/usr/bin/env python
# coding: utf-8

import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm
import os
from time import sleep
from datetime import datetime
import sys
from sklearn.metrics import classification_report, accuracy_score
from configure import parse_args
from utils import *

# Run configuration, as returned by configure.parse_args().
args = parse_args()

# Start time of this run, formatted as day.month.year-hour:minute:second.
now = datetime.now()
dt_string = now.strftime("%d.%m.%y-%H:%M:%S")

# Name fragments of the parameters to freeze, given as one ';'-separated
# string. Empty fragments (from an empty argument, a trailing ';' or ';;')
# are dropped: the freezing code tests `fragment in parameter_name`, and the
# empty string is a substring of every name, so keeping one would silently
# freeze the whole model.
layers_to_freeze = [
    fragment.strip()
    for fragment in args.freeze_layers.split(";")
    if fragment.strip()
]

# File handed to print_results_to_file() when predictions are written out.
substitutions_file = 'mappings/substitutions.txt'
# Entry 'B' of the specific-results file; the training loop indexes it by
# epoch number to decide which corpora get their predictions saved.
specific_results = open_specific_results('mappings/specific_results.txt')['B']
# Fixed seed (transformers.set_seed) so that runs are repeatable.
set_seed(42)

# ===============
# Dataset class
# ===============

class Dataset(torch.utils.data.Dataset):
    """Tokenized sentences paired with their labels.

    Each input record is indexed from the end: ``record[-1]`` is the label
    and ``record[-2]`` is the already-split list of words that is handed to
    the module-level ``tokenizer``.
    """

    def __init__(self, sentences):
        """Store the labels and tokenize every sentence up front.

        Args:
            sentences: Records whose last item is the label and whose
                second-to-last item is the word list to encode.
        """
        # Every sentence is padded / truncated to exactly 512 positions and
        # returned as PyTorch tensors.
        encode_options = dict(
            is_split_into_words=True,
            padding='max_length',
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )

        self.labels = [record[-1] for record in sentences]
        self.texts = [tokenizer(record[-2], **encode_options)
                      for record in sentences]

    def classes(self):
        """Return the list of stored labels."""
        return self.labels

    def __len__(self):
        """Return the number of examples (one label per example)."""
        return len(self.labels)

    def get_batch_labels(self, idx):
        """Return the label(s) selected by ``idx`` as a NumPy array."""
        selected = self.labels[idx]
        return np.array(selected)

    def get_batch_texts(self, idx):
        """Return the tokenizer output(s) selected by ``idx``."""
        return self.texts[idx]

    def __getitem__(self, idx):
        """Return the ``(encoded_text, label)`` pair for ``idx``."""
        encoded = self.get_batch_texts(idx)
        label = self.get_batch_labels(idx)
        return encoded, label

# ===============
# Load datasets
# ===============

# Read the mappings file, then set up the batch size and the tokenizer
mappings, inv_mappings = open_mappings(args.mappings_file)
batch_size = args.batch_size
tokenizer  = AutoTokenizer.from_pretrained(args.transformer_model)

# Training sentences, plus one collection of sentences per dev / test corpus
train_sentences, dev_dict_sentences, test_dict_sentences = open_sentences(args.data_path, mappings)

# Width of the output layer: number of distinct training labels, plus one
num_labels = len({sentence[-1] for sentence in train_sentences}) + 1

# Training split: tokenized dataset and a shuffling dataloader
train_dataset = Dataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Dev split: one dataset and one (unshuffled) dataloader per corpus
dev_dataset = {name: Dataset(sents)
               for name, sents in dev_dict_sentences.items()}
dev_dict_dataloader = {name: DataLoader(data, batch_size=batch_size)
                       for name, data in dev_dataset.items()}

# Test split: one dataset and one (unshuffled) dataloader per corpus
test_dataset = {name: Dataset(sents)
                for name, sents in test_dict_sentences.items()}
test_dict_dataloader = {name: DataLoader(data, batch_size=batch_size)
                        for name, data in test_dataset.items()}


# ===============
# Model setup
# ===============

class TransformerClassifier(nn.Module):
    """Pretrained transformer encoder followed by a linear classification head.

    The first-position vector of the encoder's last hidden state goes through
    dropout, a linear layer with ``num_labels`` outputs, and a ReLU.
    """

    def __init__(self, dropout=args.dropout):
        """Load the encoder named by ``args.transformer_model`` and build the head.

        Args:
            dropout: Dropout probability applied to the encoder output before
                the linear layer.
        """
        super().__init__()

        self.tr_model = AutoModel.from_pretrained(args.transformer_model)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder's own configuration rather than
        # assuming a 768-wide model, so encoders of other widths work too.
        # 768 stays as the fallback when the config exposes no hidden_size.
        hidden_size = getattr(self.tr_model.config, "hidden_size", 768)
        self.linear = nn.Linear(hidden_size, num_labels)  # encoder width x num of classes
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return one score per class for every sequence in the batch.

        Args:
            input_id: Token ids, passed to the encoder as ``input_ids``.
            mask: Attention mask, passed to the encoder as ``attention_mask``.

        Returns:
            The output of the final ReLU, with ``num_labels`` values per
            input sequence.
        """
        # Keep only position 0 of the last hidden state as the sequence
        # representation.
        first_token_state = self.tr_model(input_ids=input_id,
                                          attention_mask=mask,
                                          return_dict=True)['last_hidden_state'][:, 0, :]
        dropout_output = self.dropout(first_token_state)
        linear_output = self.linear(dropout_output)
        # NOTE(review): the ReLU clamps negative scores to zero before they
        # reach the cross-entropy loss used in training. It is kept so that
        # existing results stay reproducible -- confirm whether it is intended.
        final_layer = self.relu(linear_output)

        return final_layer


# Build the classifier; with no argument it uses the dropout rate from args.dropout.
model = TransformerClassifier()


def train(model, train_dataloader, dev_dict_dataloader, test_dict_sentences, epochs, specific_results):
    """Fine-tune ``model`` and evaluate it on every dev corpus after each epoch.

    Besides its parameters, this function reads the module-level ``args``,
    ``layers_to_freeze``, ``test_dict_dataloader``, ``inv_mappings`` and
    ``substitutions_file``.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)`` and
            returning one score per class for each sequence.
        train_dataloader: Yields ``(encoded_inputs, labels)`` batches, where
            ``encoded_inputs`` has ``'input_ids'`` and ``'attention_mask'``.
        dev_dict_dataloader: Maps each corpus name to its dev dataloader.
        test_dict_sentences: Maps each corpus name to its test sentences;
            used when the predictions of selected epochs are written to file.
        epochs: Number of passes over the training data.
        specific_results: Indexed by 1-based epoch number; each entry holds
            the corpora whose test predictions are saved after that epoch.
    """
    device = torch.device("cuda" if args.use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(),
                      lr=2e-5,
                      eps=1e-8)

    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Number of batches whose gradients are summed before one optimizer
    # update. With a value of 1 every batch triggers an update.
    accumulation_steps = max(1, int(args.gradient_accumulation_steps or 1))
    num_batches = len(train_dataloader)
    # The scheduler counts optimizer updates, not batches (ceil division so
    # that the trailing, possibly smaller, group of each epoch is included).
    updates_per_epoch = (num_batches + accumulation_steps - 1) // accumulation_steps
    total_steps = updates_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)

    seed_val = 42
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    # freeze layers, see argument in configure.py
    if args.freeze_layers != '':
        for name, param in model.named_parameters():
            if any(x in name for x in layers_to_freeze):
                param.requires_grad = False

    for epoch_num in range(epochs):
        print('\n=== Epoch {:} / {:} ==='.format(epoch_num + 1, epochs))
        model.train()

        total_acc_train = 0
        total_loss_train = 0
        seen_examples = 0

        model.zero_grad()
        for step, (train_input, train_label) in enumerate(tqdm(train_dataloader), start=1):
            train_label = train_label.to(device)
            # The collated batch is expected to carry a singleton dimension
            # at index 1 (each sentence is tokenized on its own). Drop it
            # from both tensors so the encoder receives 2-D ids and mask.
            mask = train_input['attention_mask'].squeeze(1).to(device)
            input_id = train_input['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask)

            batch_loss = criterion(output, train_label.long())
            total_loss_train += batch_loss.item()

            total_acc_train += (output.argmax(dim=1) == train_label).sum().item()
            seen_examples += train_label.size(0)

            # Scale the loss so the accumulated gradient is an average over
            # the group of batches rather than a sum.
            (batch_loss / accumulation_steps).backward()

            # Update once per full group, and once more for the leftover
            # batches at the end of the epoch.
            if step % accumulation_steps == 0 or step == num_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

        if num_batches and seen_examples:
            print('Train loss: {:.4f} | Train accuracy: {:.4f}'.format(
                total_loss_train / num_batches,
                total_acc_train / seen_examples))

        # ------ Validation --------
        print('\nValidation for epoch:', epoch_num + 1)

        # Dev results for each corpus. We don't need to save the results.
        for corpus in dev_dict_dataloader:
            _ = get_predictions(model, corpus, dev_dict_dataloader[corpus])

        # We want the results of specific epochs for specific corpora: for
        # the epochs listed in specific_results, save the test predictions
        # of every requested corpus. Each corpus is checked explicitly (not
        # just the last one left over from the dev loop), and predictions
        # come from the test dataloader so that they line up with the test
        # sentences they are written next to.
        if epoch_num + 1 in specific_results:
            requested = specific_results[epoch_num + 1]
            for corpus in test_dict_sentences:
                if corpus not in requested:
                    continue
                test_results = get_predictions(model, corpus, test_dict_dataloader[corpus],
                                               print_results=False)
                print_results_to_file(corpus, test_dict_sentences[corpus], test_results,
                                      inv_mappings, substitutions_file)

                
# ------- Launch the training run -------

# Reported batch size = per-step batch size x gradient accumulation steps
effective_batch_size = args.batch_size * args.gradient_accumulation_steps

print('\nModel: ', args.transformer_model)
print('Batch size: ', effective_batch_size)
print('\nStart training...\n')
train(
    model,
    train_dataloader,
    dev_dict_dataloader,
    test_dict_sentences,
    args.num_epochs,
    specific_results,
)
print('\nTraining Done!')


# ------- Final evaluation: predict and save every test corpus -------

for corpus, test_loader in test_dict_dataloader.items():
    test_results = get_predictions(model, corpus, test_loader)
    print_results_to_file(
        corpus,
        test_dict_sentences[corpus],
        test_results,
        inv_mappings,
        substitutions_file,
    )