Skip to content
Snippets Groups Projects
classifier_features_pytorch.py 8.03 KiB
#!/usr/bin/env python
# coding: utf-8

import torch
import numpy as np
from transformers import (
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm
import os
from time import sleep
from datetime import datetime
import sys
from sklearn.metrics import classification_report, accuracy_score
from configure import parse_args
from utils import *

args = parse_args()
# Timestamp of this run ("dd.mm.yy-HH:MM:SS"); not used in the code visible
# here — presumably meant for naming outputs. TODO confirm or remove.
now = datetime.now()
dt_string = now.strftime("%d.%m.%y-%H:%M:%S")

# Substrings of parameter names whose weights are frozen during training
# (";"-separated on the command line, see configure.py). Blank entries are
# dropped and whitespace is stripped: "" is a substring of EVERY parameter
# name, so a stray separator such as "embeddings;" would otherwise freeze the
# whole model.
layers_to_freeze = [
    layer.strip() for layer in args.freeze_layers.split(";") if layer.strip()
]

print("Datasets used: " + args.langs_to_use)
print("\nDirection: " + args.normalize_direction)
print("\nMappings file: " + args.mappings_file, flush=True)


# ===============
# Dataset class
# ===============


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with their labels.

    Every element of ``sentences`` is indexable: its last item is the label
    and its second-to-last item is the already-split word list, which is
    encoded with the module-level ``tokenizer`` (padded/truncated to 512
    tokens, returned as PyTorch tensors). All sentences are encoded eagerly
    in ``__init__``.
    """

    def __init__(self, sentences):
        self.labels = [sent[-1] for sent in sentences]
        self.texts = [self._encode(sent[-2]) for sent in sentences]

    @staticmethod
    def _encode(words):
        """Encode one pre-split sentence into a fixed-length (512) tensor dict."""
        return tokenizer(
            words,
            is_split_into_words=True,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )

    def classes(self):
        """Return the list of labels, in dataset order."""
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        """Return the label at ``idx`` wrapped in a NumPy array."""
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        """Return the precomputed tokenizer encoding at ``idx``."""
        return self.texts[idx]

    def __getitem__(self, idx):
        """Return ``(encoding, label)`` for the sentence at ``idx``."""
        return self.get_batch_texts(idx), self.get_batch_labels(idx)


# ===============
# Load datasets
# ===============

# ---- Label mappings and tokenizer ----
# `open_mappings` (from utils) returns a mapping and its inverse; presumably
# label -> id and id -> label. TODO confirm.
mappings, inv_mappings = open_mappings(args.mappings_file)
batch_size = args.batch_size
tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)

# ---- Sentences ----
# Training sentences form one collection; dev and test are dicts keyed by
# corpus name (they are iterated with .items() when building datasets).
sentence_data = open_sentences_with_feats(args.data_path, mappings)
(
    train_sentences,
    dev_dict_sentences,
    test_dict_sentences,
    framework_labels,
    disco_features,
) = sentence_data

# Register metadata strings as added tokens of the tokenizer: language names,
# corpus identifiers, framework names, and then the discourse features loaded
# with the sentences. The embedding matrix must be resized afterwards.
_LANGUAGE_TOKENS = [
    "German",
    "English",
    "Basque",
    "Farsi",
    "French",
    "Dutch",
    "Portuguese",
    "Russian",
    "Spanish",
    "Turkish",
    "Chinese",
]
_CORPUS_TOKENS = [
    "spa.rst.sctb",
    "rus.rst.rrt",
    "fra.sdrt.annodis",
    "por.rst.cstn",
    "eng.sdrt.stac",
    "eus.rst.ert",
    "eng.pdtb.pdtb",
    "deu.rst.pcc",
    "eng.rst.rstdt",
    "zho.rst.sctb",
    "nld.rst.nldt",
    "tur.pdtb.tdb",
    "spa.rst.rststb",
    "fas.rst.prstc",
    "zho.pdtb.cdtb",
    "eng.rst.gum",
]
_FRAMEWORK_TOKENS = ["rst", "pdtb", "sdrt"]

tokenizer.add_tokens(_LANGUAGE_TOKENS + _CORPUS_TOKENS + _FRAMEWORK_TOKENS)
tokenizer.add_tokens(disco_features)

# Width of the classifier head: distinct labels seen in the TRAINING
# sentences, plus one.
# NOTE(review): the reason for the +1 is not visible here, and labels that
# occur only in dev/test are not counted — if their ids reach num_labels the
# head is too small. TODO confirm against the mappings file.
num_labels = len({sent[-1] for sent in train_sentences}) + 1


def _loaders_per_corpus(datasets_by_corpus):
    """Wrap each per-corpus Dataset in an unshuffled DataLoader of `batch_size`."""
    return {
        name: DataLoader(data, batch_size)
        for name, data in datasets_by_corpus.items()
    }


# Datasets: a single training set, one dev/test set per corpus.
train_dataset = Dataset(train_sentences)
dev_dataset = {name: Dataset(sents) for name, sents in dev_dict_sentences.items()}
test_dataset = {name: Dataset(sents) for name, sents in test_dict_sentences.items()}

# Dataloaders: only the training data is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
dev_dict_dataloader = _loaders_per_corpus(dev_dataset)
test_dict_dataloader = _loaders_per_corpus(test_dataset)


# ===============
# Model setup
# ===============


class TransformerClassifier(nn.Module):
    """Pretrained transformer encoder with a linear classification head.

    The hidden state of the first token (position 0 of ``last_hidden_state``)
    goes through dropout and a linear layer that outputs one raw score (logit)
    per label, as expected by ``nn.CrossEntropyLoss``.
    """

    def __init__(self, dropout=None):
        """Load the encoder named by ``args.transformer_model`` and build the head.

        Args:
            dropout: dropout probability before the linear layer. ``None``
                (default) means ``args.dropout``, read when the model is built
                rather than when the class is defined.
        """
        super(TransformerClassifier, self).__init__()
        if dropout is None:
            dropout = args.dropout

        self.tr_model = AutoModel.from_pretrained(args.transformer_model)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder config rather than assuming 768
        # (base-size models); e.g. large models have a hidden size of 1024.
        self.linear = nn.Linear(self.tr_model.config.hidden_size, num_labels)
        # Kept only so the attribute still exists for external code. It is NOT
        # applied in forward(): ReLU on logits clamps every negative score to
        # 0, which CrossEntropyLoss (expecting raw logits) cannot train well.
        self.relu = nn.ReLU()

    def resize_token_embeddings(self, new_num_tokens):
        """Resize the encoder's input embeddings (delegates to ``tr_model``).

        Needed after tokens are added to the tokenizer; ``nn.Module`` itself
        has no such method.
        """
        return self.tr_model.resize_token_embeddings(new_num_tokens)

    def forward(self, input_id, mask):
        """Return raw class logits for a batch of token ids and attention mask."""
        encoded = self.tr_model(
            input_ids=input_id, attention_mask=mask, return_dict=True
        )
        first_token = encoded["last_hidden_state"][:, 0, :]
        return self.linear(self.dropout(first_token))


model = TransformerClassifier()
# Grow the encoder's embedding matrix to cover the tokens added to the
# tokenizer above. `resize_token_embeddings` belongs to the Hugging Face
# encoder, not to the wrapping nn.Module, so it is called on `tr_model`.
model.tr_model.resize_token_embeddings(len(tokenizer))


def _freeze_layers(model, name_patterns):
    """Disable gradients for parameters whose name contains any of `name_patterns`.

    Blank patterns are ignored: "" is a substring of every parameter name, so
    a stray separator in the option would otherwise freeze the whole model.
    """
    patterns = [p.strip() for p in name_patterns if p.strip()]
    if not patterns:
        return
    for name, param in model.named_parameters():
        if any(p in name for p in patterns):
            param.requires_grad = False


def _train_one_epoch(
    model, dataloader, criterion, optimizer, scheduler, device, accumulation_steps
):
    """Run one training pass over `dataloader` with gradient accumulation.

    Each batch loss is divided by `accumulation_steps` before backward. After
    every `accumulation_steps` batches, and once more for a trailing partial
    group, gradients are clipped to norm 1.0. The optimizer and scheduler then
    step and gradients are cleared. No gradient therefore leaks into the next
    epoch. The trailing group is still divided by `accumulation_steps`, so it
    is slightly down-weighted.
    """
    model.train()
    optimizer.zero_grad()
    num_batches = len(dataloader)

    for batch_idx, (batch_input, batch_label) in enumerate(tqdm(dataloader), start=1):
        batch_label = batch_label.to(device)
        # Each Dataset encoding is built per sentence with return_tensors="pt",
        # so batches presumably arrive as (batch, 1, seq_len); drop axis 1 from
        # both ids and mask so they match.
        input_id = batch_input["input_ids"].squeeze(1).to(device)
        mask = batch_input["attention_mask"].squeeze(1).to(device)

        output = model(input_id, mask)
        # Normalize so the accumulated gradient is an average over the group.
        loss = criterion(output, batch_label.long()) / accumulation_steps
        loss.backward()

        if batch_idx % accumulation_steps == 0 or batch_idx == num_batches:
            # Clip BEFORE stepping: clipping after zero_grad() is a no-op.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()


def _evaluate_corpora(model, dataloaders, epoch):
    """Run both prediction reports for every corpus in `dataloaders`.

    The results are discarded; only the side effects of `get_predictions` and
    `get_better_predictions` (from utils) remain.
    """
    for corpus, dataloader in dataloaders.items():
        get_predictions(model, corpus, dataloader)
        get_better_predictions(
            model,
            corpus,
            dataloader,
            # Corpus names look like "eng.rst.gum"; field 1 is the framework.
            framework_labels[corpus.split(".")[1]],
            inv_mappings,
            epoch,
            save_conf_matrix=False,
        )


def train(
    model,
    train_dataloader,
    dev_dict_dataloader,
    test_dict_sentences,
    test_dict_dataloader,
    epochs,
    # specific_results
):
    """Fine-tune `model` and report dev and test results after every epoch.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``.
        train_dataloader: DataLoader yielding ``(encoding, label)`` batches.
        dev_dict_dataloader: dict of corpus name -> dev DataLoader.
        test_dict_sentences: unused; kept for call compatibility.
        test_dict_dataloader: dict of corpus name -> test DataLoader.
        epochs: number of passes over the training data.
    """
    device = torch.device("cuda" if args.use_cuda else "cpu")

    # Freeze (see configure.py) and move the model BEFORE building the
    # optimizer, so it is created over the final, device-resident parameters.
    _freeze_layers(model, layers_to_freeze)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

    # A value < 1 would divide the loss by zero; treat it as 1.
    accumulation_steps = max(1, args.gradient_accumulation_steps)
    # The scheduler advances once per optimizer step. There are
    # ceil(batches / accumulation_steps) such steps per epoch, so the linear
    # decay reaches zero exactly at the end of training.
    steps_per_epoch = -(-len(train_dataloader) // accumulation_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=steps_per_epoch * epochs
    )

    # Fixed seed for reproducibility. transformers.set_seed seeds Python's
    # `random`, NumPy and torch (CPU and CUDA). It runs after the model was
    # built, so it governs shuffling and dropout, not weight initialisation.
    seed_val = 42
    set_seed(seed_val)

    for epoch_num in range(epochs):
        epoch = epoch_num + 1
        print("\n=== Epoch {:} / {:} ===".format(epoch, epochs), flush=True)

        _train_one_epoch(
            model,
            train_dataloader,
            criterion,
            optimizer,
            scheduler,
            device,
            accumulation_steps,
        )

        # Disable dropout while evaluating; the next epoch re-enables train mode.
        model.eval()

        # ------ Validation --------
        print("\nValidation for epoch:", epoch)
        # Dev and test results for each corpus. We don't need to save the results.
        _evaluate_corpora(model, dev_dict_dataloader, epoch)

        # ------ Test --------
        print("\nTest results for epoch:", epoch)
        _evaluate_corpora(model, test_dict_dataloader, epoch)


# ------- Start the training -------
# The reported batch size is the effective one: per-step batch size times the
# number of gradient-accumulation steps.

print("\nModel: ", args.transformer_model)
effective_batch_size = args.batch_size * args.gradient_accumulation_steps
print("Batch size: ", effective_batch_size)
print("\nStart training...\n")

train(
    model,
    train_dataloader,
    dev_dict_dataloader,
    test_dict_sentences,
    test_dict_dataloader,
    epochs=args.num_epochs,
)

print("\nTraining Done!")