#!/usr/bin/env python
# coding: utf-8

import os, io
import torch
from transformers import AutoConfig, AutoTokenizer
from configure import parse_args
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    ConfusionMatrixDisplay,
)
import matplotlib.pyplot as plt
import seaborn as sns
from time import sleep
from datetime import datetime

# Timestamp taken once, when this module is imported; make_confusion_matrices
# uses dt_string to name its per-run output directory (conf_matrix/<dt_string>).
now = datetime.now()
# NOTE(review): the ":" characters in this format are not allowed in Windows
# file/directory names, so the conf_matrix directory likely cannot be created
# there — confirm this only runs on POSIX systems.
dt_string = now.strftime("%d.%m.%y-%H:%M:%S")


# Command-line options, parsed at import time by configure.parse_args(); the
# functions below read them through this module-level global.
args = parse_args()


def open_mappings(mappings_file):
    """Open the mappings file into a dictionary.

    The file is read as UTF-8 text. Its first line is a header and is skipped;
    every following non-blank line is tab-separated, with the label string in
    the first column and its integer id in the last column.

    Arguments:
    - mappings_file: path to the tab-separated mappings file

    Returns a tuple ``(mappings, inv_mappings)``:
    - mappings: dict mapping every label string to its integer id
    - inv_mappings: dict mapping each integer id back to the FIRST label
      string listed for it (several labels may share one id)
    """

    mappings = {}
    with open(mappings_file, "r", encoding="utf-8") as f:
        next(f, None)  # skip the header line (an empty file yields empty dicts)
        for line in f:
            fields = line.strip().split("\t")
            if fields == [""]:
                # Blank line (e.g. a trailing newline at EOF): int("") would raise.
                continue
            mappings[fields[0]] = int(fields[-1])

    # Reverse the mapping. When several labels share an id, keep the first one
    # seen so that id -> label is deterministic.
    inv_mappings = {}
    for label, label_id in mappings.items():
        inv_mappings.setdefault(label_id, label)

    return mappings, inv_mappings


def encode_label(og_label, mappings_dict):
    """Encode the label.

    The raw label is lower-cased and stripped of surrounding whitespace, then
    looked up in ``mappings_dict``. Labels that are not present fall back to
    the id stored under the ``"unk"`` key (a KeyError is raised if that key is
    missing too).
    """

    normalized = og_label.lower().strip()
    if normalized not in mappings_dict:
        normalized = "unk"
    return mappings_dict[normalized]


def open_file(filename, mappings_dict):
    """Read a .rels file into a list of rows ready for encoding.

    Arguments:
    - filename: the path to a .rels file (read as UTF-8, tab-separated; the
      first line is a header and is skipped)
    - mappings_dict: a dictionary of mappings of unique labels to integers

    Lines with fewer than two tab-separated fields are ignored. For every
    other line, the fields at index 3 and 4 are split on single spaces and
    each truncated to its first 254 items. When ``args.normalize_direction``
    is "yes" and the field at index 9 is anything but "1>2", the two units are
    swapped. Returns a list of lists, each being:
    the line's fields + [unit A + ["[SEP]"] + unit B, encoded last field]
    """

    # 512 (max bert len) / 2 (2 sents) - 2 (special tokens). NOTE(review):
    # this caps whitespace-separated words, not subword tokens.
    word_cap = 254
    separator = "[SEP]"
    rows = []

    with open(filename, "r", encoding="utf-8") as handle:
        next(handle)  # header line
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            if len(fields) <= 1:
                continue

            unit_a = fields[3].split(" ")[:word_cap]
            unit_b = fields[4].split(" ")[:word_cap]

            # Direction normalisation: keep the file order only for "1>2".
            if args.normalize_direction == "yes" and fields[9] != "1>2":
                unit_a, unit_b = unit_b, unit_a

            rows.append(
                fields
                + [
                    unit_a + [separator] + unit_b,
                    encode_label(fields[-1], mappings_dict),
                ]
            )

    return rows


def open_file_with_lang(filename, mappings_dict):
    """Same as open_file(), but prepend language/corpus tokens to the sequence.

    The corpus folder name is the parent directory of ``filename`` (e.g.
    "eng.rst.gum" for ".../eng.rst.gum/eng.rst.gum_dev.rels"). Three tokens
    are placed before the first unit:
    - the language name looked up from the folder name's first dot-separated
      part (an unknown code is used verbatim instead of raising KeyError)
    - the full folder name
    - the folder name's second dot-separated part (the framework)

    Arguments:
    - filename: the path to a .rels file inside its corpus folder
    - mappings_dict: a dictionary of mappings of unique labels to integers
    Returns a list of lists, where each list is:
    the line + [[lang, folder name, framework] + unit A + ["[SEP]"] + unit B,
    encoded label from the field at index 11]
    """

    # 512 (max bert len) / 2 (2 sents) - 2 (special tokens). NOTE(review):
    # this caps whitespace-separated words, not subword tokens.
    max_len = 254
    lines = []
    SEP_token = "[SEP]"

    langs = {
        "deu": "German",
        "eng": "English",
        "eus": "Basque",
        "fas": "Farsi",
        "fra": "French",
        "ita": "Italian",
        "nld": "Dutch",
        "por": "Portuguese",
        "rus": "Russian",
        "spa": "Spanish",
        "tur": "Turkish",
        "tha": "Thai",
        "zho": "Chinese",
    }

    # Corpus folder name = parent directory of the file; os.path handles both
    # "/"-joined paths and the platform's native separator.
    fullname = os.path.basename(os.path.dirname(filename))
    name_parts = fullname.split(".")
    lang = langs.get(name_parts[0], name_parts[0])
    framework = name_parts[1]
    prefix = [lang, fullname, framework]

    with open(filename, "r", encoding="utf-8") as f:
        next(f)  # skip the header line
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) <= 1:
                continue

            sent_1 = fields[3].split(" ")[:max_len]
            sent_2 = fields[4].split(" ")[:max_len]

            # With direction normalisation on, unit 2 goes first unless the
            # relation is marked "1>2".
            if args.normalize_direction == "yes" and fields[9] != "1>2":
                sent_1, sent_2 = sent_2, sent_1

            lines.append(
                fields
                + [
                    prefix + sent_1 + [SEP_token] + sent_2,
                    encode_label(fields[11], mappings_dict),
                ]
            )

    return lines


# ===============
# OPENING FILES FUNCTIONS
# ===============


def open_sentences(path_to_corpora, mappings_dict):
    """Opens all the corpora and the surprise corpora in train/dev/test sets.
    Uses the open_file() function defined in this module.

    Every entry of ``path_to_corpora`` whose name does not contain ".md",
    "DS_", "utils" or "ipynb" is treated as a corpus folder holding
    ``<corpus>_train.rels`` (optional), ``<corpus>_dev.rels`` and
    ``<corpus>_test.rels``. If ``args.langs_to_use`` is not "@", it is split
    on ";" and only train files whose name contains one of the resulting
    strings are loaded; dev and test files are always loaded.

    Returns:
    - list of sentences for TRAIN: all the corpora and surprise corpora together
    - dict of sentences for DEV: each dev set categorized per corpus
    - dict of sentences for TEST: each test set categorized per corpus
    - dict of labels per framework: framework name (second-to-last
      dot-separated part of the corpus folder name) -> set of encoded label
      ids seen in that framework's data, with "unk" removed
    """
    langs_to_use = False

    if args.langs_to_use != "@":
        langs_to_use = args.langs_to_use.split(";")

    corpora = [
        folder
        for folder in os.listdir(path_to_corpora)
        if not any(i in folder for i in [".md", "DS_", "utils", "ipynb"])
    ]

    # ---------------------
    train_sentences = []
    dev_dict_sentences = {}
    test_dict_sentences = {}

    all_labels = {}

    for corpus in corpora:
        framework = corpus.split(".")[-2]
        labels = all_labels.setdefault(framework, [])
        # All paths are built from path_to_corpora, the directory listed above.
        corpus_dir = os.path.join(path_to_corpora, corpus)

        # ===== open train ====
        # Some corpora have no train file: skip those, but let real read
        # errors in an existing train file propagate instead of hiding them.
        if langs_to_use:
            # if we only train with certain corpora, we only load them
            candidates = [
                os.path.join(corpus_dir, x)
                for x in os.listdir(corpus_dir)
                if "train" in x
                and "rels" in x
                and any(lang in x for lang in langs_to_use)
            ]
            train_file = candidates[0] if candidates else None
        else:
            train_file = os.path.join(corpus_dir, corpus + "_train.rels")
            if not os.path.isfile(train_file):
                train_file = None

        if train_file is not None:
            temp = open_file(train_file, mappings_dict)
            train_sentences += temp
            labels.extend(row[-1] for row in temp)

        # ======== open dev ========
        temp = open_file(os.path.join(corpus_dir, corpus + "_dev.rels"), mappings_dict)
        dev_dict_sentences[corpus] = temp
        labels.extend(row[-1] for row in temp)

        # ======== open test ========
        temp = open_file(os.path.join(corpus_dir, corpus + "_test.rels"), mappings_dict)
        test_dict_sentences[corpus] = temp
        labels.extend(row[-1] for row in temp)

    # Labels were encoded through mappings_dict, so "unk" appears as its id,
    # not as the string; drop both so "unk" is never an allowed prediction.
    unk_values = {"unk", mappings_dict.get("unk")}
    corpus_labels = {
        framework: set(labels) - unk_values
        for framework, labels in all_labels.items()
    }

    return train_sentences, dev_dict_sentences, test_dict_sentences, corpus_labels


def open_sentences_with_lang(path_to_corpora, mappings_dict):
    """Opens all the corpora and the surprise corpora in train/dev/test sets.
    Uses the open_file_with_lang() function defined in this module, so every
    sequence is prefixed with language/corpus/framework tokens.

    Every entry of ``path_to_corpora`` whose name does not contain ".md",
    "DS_", "utils" or "ipynb" is treated as a corpus folder. It must contain a
    file whose name includes both "dev" and "rels" and one including "test"
    and "rels" (IndexError otherwise); a train file (name including "train"
    and "rels") is optional. If ``args.langs_to_use`` is not "@", it is split
    on ";" and only train files whose name contains one of the resulting
    strings are loaded.

    Returns:
    - list of sentences for TRAIN: all the corpora and surprise corpora together
    - dict of sentences for DEV: each dev set categorized per corpus
    - dict of sentences for TEST: each test set categorized per corpus
    - dict of labels per framework: framework name (second-to-last
      dot-separated part of the corpus folder name) -> set of encoded label
      ids seen in that framework's data, with "unk" removed
    """
    langs_to_use = False

    if args.langs_to_use != "@":
        langs_to_use = args.langs_to_use.split(";")

    corpora = [
        folder
        for folder in os.listdir(path_to_corpora)
        if not any(i in folder for i in [".md", "DS_", "utils", "ipynb"])
    ]

    # ---------------------
    train_sentences = []
    dev_dict_sentences = {}
    test_dict_sentences = {}

    all_labels = {}

    for corpus in corpora:
        framework = corpus.split(".")[-2]
        labels = all_labels.setdefault(framework, [])
        # "/"-joined on purpose: open_file_with_lang reads the corpus name
        # from the path's parent directory.
        corpus_dir = "/".join([path_to_corpora, corpus])
        file_names = os.listdir(corpus_dir)

        # ===== open train ====
        # Some corpora have no train file: skip those, but let real read
        # errors in an existing train file propagate instead of hiding them.
        train_files = [
            "/".join([corpus_dir, x])
            for x in file_names
            if "train" in x
            and "rels" in x
            # if we only train with certain corpora, we only load them
            and (not langs_to_use or any(lang in x for lang in langs_to_use))
        ]
        if train_files:
            temp = open_file_with_lang(train_files[0], mappings_dict)
            train_sentences += temp
            labels.extend(row[-1] for row in temp)

        # ======== open dev ========
        dev_file = [
            "/".join([corpus_dir, x])
            for x in file_names
            if "dev" in x and "rels" in x
        ][0]
        temp = open_file_with_lang(dev_file, mappings_dict)
        dev_dict_sentences[corpus] = temp
        labels.extend(row[-1] for row in temp)

        # ======== open test ========
        test_file = [
            "/".join([corpus_dir, x])
            for x in file_names
            if "test" in x and "rels" in x
        ][0]
        temp = open_file_with_lang(test_file, mappings_dict)
        test_dict_sentences[corpus] = temp
        labels.extend(row[-1] for row in temp)

    # Labels were encoded through mappings_dict, so "unk" appears as its id,
    # not as the string; drop both so "unk" is never an allowed prediction.
    unk_values = {"unk", mappings_dict.get("unk")}
    corpus_labels = {
        framework: set(labels) - unk_values
        for framework, labels in all_labels.items()
    }

    return train_sentences, dev_dict_sentences, test_dict_sentences, corpus_labels


# ===============
# Testing functions
# ===============


def get_predictions(model, corpus, test_dataloader, print_results=True):
    """Function to get the model's predictions for one corpus' test set.
    Can print accuracy using scikit-learn.
    Also works with dev sets -- just don't save the outputs.

    Arguments:
    - model: called as ``model(input_ids, attention_mask)``; its output must
      support ``.argmax(dim=1)`` (presumably (batch, num_labels) scores —
      TODO confirm)
    - corpus: name printed next to the accuracy
    - test_dataloader: yields ``(inputs, labels)`` where ``inputs`` has
      "input_ids" and "attention_mask" tensors and ``labels`` is a tensor
    - print_results: whether to print the accuracy
    Returns: list of predicted label ids in dataloader order (they match the
    test file's lines only if the dataloader does not shuffle).
    """

    device = torch.device("cuda" if args.use_cuda else "cpu")

    if args.use_cuda:
        model = model.cuda()

    model.eval()

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            mask = test_input["attention_mask"].to(device)
            # squeeze(1) removes dimension 1 only when it has size 1
            input_id = test_input["input_ids"].squeeze(1).to(device)
            output = model(input_id, mask)

            all_labels += test_label.tolist()
            all_preds += output.argmax(dim=1).tolist()

    assert len(all_labels) == len(all_preds)

    if print_results:
        test_acc = round(accuracy_score(all_labels, all_preds), 4)
        print(corpus, "\tAccuracy:\t", test_acc)

    return all_preds


def make_confusion_matrices(y_test, y_pred, corpus_name, inv_mappings, epoch):
    """Print a classification report and save a confusion-matrix heatmap.

    Arguments:
    - y_test: gold label ids
    - y_pred: predicted label ids
    - corpus_name: used in the plot title and output file name
    - inv_mappings: dict mapping label id -> label string; its key order fixes
      the row/column order of the matrix and its values are the tick labels
    - epoch: used in the plot title and output file name

    The PNG is written to conf_matrix/<dt_string>/<corpus_name>_<epoch>.png,
    where dt_string is the module-level run timestamp. Any ValueError raised
    by the scikit-learn or plotting calls propagates to the caller; the
    figure is always closed.
    """
    save_path = os.path.join("conf_matrix", dt_string)
    os.makedirs(save_path, exist_ok=True)

    print(
        classification_report(
            y_test,
            y_pred,
        )
    )

    label_ids = list(inv_mappings.keys())
    label_names = list(inv_mappings.values())
    # sklearn's confusion_matrix puts gold labels on rows, predictions on columns.
    cm = confusion_matrix(y_test, y_pred, labels=label_ids)
    print(cm)

    fig, ax = plt.subplots()
    try:
        sns.heatmap(
            cm,
            cmap=sns.color_palette("cubehelix", as_cmap=True),
            xticklabels=label_names,
            yticklabels=label_names,
            ax=ax,
        )
        # heatmap draws matrix rows along the y axis: y = gold, x = predicted.
        ax.set_ylabel("Corpus label")
        ax.set_xlabel("Predicted label")
        ax.tick_params(axis="both", labelsize=2)
        fig.subplots_adjust(bottom=0.5, left=0.5)
        ax.set_title("Confusion Matrix: " + corpus_name + " (epoch:" + str(epoch) + ")")
        fig.savefig(
            os.path.join(save_path, corpus_name + "_" + str(epoch) + ".png"), dpi=300
        )
    finally:
        # Close (not just clear) the figure so figures do not pile up across
        # corpora and epochs.
        plt.close(fig)


def get_better_predictions(
    model,
    corpus,
    test_dataloader,
    corpus_labels,
    inv_mappings,
    epoch,
    print_results=True,
    save_conf_matrix=False,
):
    """Predict one corpus, restricting each prediction to the corpus' labels.

    The model's softmax distribution is filtered to the label ids in
    ``corpus_labels``, and the most probable remaining id is the prediction
    (on an exact probability tie, the larger id wins).

    Arguments:
    - model: called as ``model(input_ids, attention_mask)``; its output is
      softmaxed over the last dimension (presumably (batch, num_labels)
      scores — TODO confirm)
    - corpus: corpus name, forwarded to make_confusion_matrices
    - test_dataloader: yields ``(inputs, labels)`` where ``inputs`` has
      "input_ids" and "attention_mask" tensors and ``labels`` is a tensor
    - corpus_labels: collection of label ids allowed for this corpus
    - inv_mappings: label id -> label string, forwarded to make_confusion_matrices
    - epoch: forwarded to make_confusion_matrices
    - print_results: print the filtered accuracy and a classification report
    - save_conf_matrix: also save a confusion-matrix plot

    Returns ``(all_labels, all_preds)``: the gold label ids, and for each
    example a dict {allowed label id: probability} ordered by ascending
    probability.

    Raises ValueError if, for some example, none of the model's label ids is
    in ``corpus_labels``.
    """
    device = torch.device("cuda" if args.use_cuda else "cpu")

    if args.use_cuda:
        model = model.cuda()

    model.eval()
    allowed = set(corpus_labels)  # O(1) membership per label id
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            mask = test_input["attention_mask"].to(device)
            input_id = test_input["input_ids"].squeeze(1).to(device)
            output = model(input_id, mask)

            for row in output.softmax(dim=-1).tolist():
                # Ascending by probability; sorted() is stable, so ties keep
                # increasing-id order.
                ranked = sorted(enumerate(row), key=lambda item: item[1])
                all_preds.append(
                    {label: prob for label, prob in ranked if label in allowed}
                )

            all_labels += test_label.tolist()

    # get the top predictions in order to get the acc
    top_preds = []
    for probs in all_preds:
        if not probs:
            raise ValueError(
                "no predicted label id is in corpus_labels for corpus " + str(corpus)
            )
        # max over (prob, id) tuples: highest probability, larger id on ties
        top_preds.append(max((prob, label) for label, prob in probs.items())[1])

    if print_results:
        test_acc = round(accuracy_score(all_labels, top_preds), 4)
        print("After label filtering:\t" + str(test_acc), flush=True)

        print(classification_report(all_labels, top_preds))

    if save_conf_matrix:
        try:
            make_confusion_matrices(all_labels, top_preds, corpus, inv_mappings, epoch)
        except ValueError:
            print("matrix failed to print")

    print()
    print("----")

    return all_labels, all_preds