#!/usr/bin/env python
# coding: utf-8

import os
import torch
import numpy as np
from transformers import AutoConfig, AutoTokenizer
from sklearn.metrics import accuracy_score
from configure import parse_args

args = parse_args()


def open_specific_results(specific_results_file):
    ''' Open a tab-separated results file into a nested dictionary:
    {column 0 ('A1', 'A1_3' or 'B'): {column 1 (int): [column 2 values]}}. '''
    specific_results = {'A1': {}, 'A1_3': {}, 'B': {}}
    with open(specific_results_file, 'r') as f:
        next(f)  # skip header
        for line in f:
            l = line.strip().split('\t')
            if not int(l[1]) in specific_results[l[0]]:
                specific_results[l[0]][int(l[1])] = []
            specific_results[l[0]][int(l[1])].append(l[2])
    return specific_results


def open_mappings(mappings_file):
    ''' Open the mappings file into a dictionary.
    Returns the label-to-integer mapping and its inverse. '''
    mappings = {}
    with open(mappings_file, 'r') as f:
        for l in f:
            mappings[l.split('\t')[0]] = int(l.strip().split('\t')[1])
    inv_mappings = {v: k for k, v in mappings.items()}
    return mappings, inv_mappings


def open_file(filename, mappings_dict):
    ''' Open a .rels file.
    Arguments:
    - filename: the path to a .rels file
    - mappings_dict: a dictionary mapping unique labels to integers
    Returns a list of lists, where each inner list is:
    the original line + [the two sentences joined with the special BERT token,
    the encoded label]
    '''
    max_len = 254  # 512 (max BERT length) / 2 (two sentences) - 2 (special tokens)
    lines = []
    SEP_token = '[SEP]'

    with open(filename, 'r', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            l = line.strip().split('\t')
            if len(l) > 1:
                # chop the sentences to max_len if too long
                sent_1 = l[3].split(' ')
                sent_2 = l[4].split(' ')
                if len(sent_1) > max_len:
                    sent_1 = sent_1[:max_len]
                if len(sent_2) > max_len:
                    sent_2 = sent_2[:max_len]

                # flip the sentences if the direction is not 1>2
                if args.normalize_direction == 'yes':
                    if l[9] == '1>2':
                        lines.append(l + [sent_1 + [SEP_token] + sent_2,
                                          mappings_dict[l[11].lower()]])
                    else:
                        lines.append(l + [sent_2 + [SEP_token] + sent_1,
                                          mappings_dict[l[11].lower()]])
                else:
                    lines.append(l + [sent_1 + [SEP_token] + sent_2,
                                      mappings_dict[l[11].lower()]])
    return lines


def encode_batch(batch):
    ''' Encode a batch of input data using the model tokenizer.
    Works on a dict-like batch with a "text" field (e.g. a pandas DataFrame
    column or a HuggingFace Dataset batch), instead of a plain list. '''
    tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)
    return tokenizer(batch["text"],
                     max_length=512,
                     truncation=True,
                     padding="max_length")
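
# Usage sketch (not part of the pipeline): one way the helpers above could be
# combined. The file paths below are hypothetical placeholders, and building a
# HuggingFace Dataset is only one option for feeding encode_batch(), which
# expects a dict-like batch with a "text" field.
#
# from datasets import Dataset
#
# mappings, inv_mappings = open_mappings('mappings.tsv')         # hypothetical path
# train_lines = open_file('eng.rst.gum_train.rels', mappings)    # hypothetical path
# train_ds = Dataset.from_dict({
#     'text': [' '.join(l[-2]) for l in train_lines],   # the [SEP]-joined token list
#     'labels': [l[-1] for l in train_lines],            # the integer-encoded label
# })
# train_ds = train_ds.map(encode_batch, batched=True)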
def open_sentences(path_to_corpora, mappings_dict):
    ''' Open all the corpora and the surprise corpora in train/dev/test sets.
    Uses the open_file() function from utils.
    Returns:
    - list of sentences for TRAIN: all the corpora and surprise corpora together
    - dict of sentences for DEV: each dev set categorized per corpus
    - dict of sentences for TEST: each test set categorized per corpus
    '''
    corpora = [folder for folder in os.listdir(path_to_corpora)
               if not any(i in folder for i in ['.md', 'DS_', 'utils', 'ipynb'])]

    train_sentences = []
    dev_dict_sentences = {}
    test_dict_sentences = {}

    for corpus in corpora:
        # open the train file, if the corpus has one
        try:
            train_file = ['/'.join([path_to_corpora, corpus, x])
                          for x in os.listdir(path_to_corpora + '/' + corpus)
                          if 'train' in x and 'rels' in x][0]
            train_sentences += open_file(train_file, mappings_dict)
        except IndexError:
            # some corpora don't have a train set
            pass

        # open each dev set separately
        dev_dict_sentences[corpus] = []
        dev_file = ['/'.join([path_to_corpora, corpus, x])
                    for x in os.listdir(path_to_corpora + '/' + corpus)
                    if 'dev' in x and 'rels' in x][0]
        dev_dict_sentences[corpus] += open_file(dev_file, mappings_dict)

        # open each test set separately
        test_dict_sentences[corpus] = []
        test_file = ['/'.join([path_to_corpora, corpus, x])
                     for x in os.listdir(path_to_corpora + '/' + corpus)
                     if 'test' in x and 'rels' in x][0]
        test_dict_sentences[corpus] += open_file(test_file, mappings_dict)

    return train_sentences, dev_dict_sentences, test_dict_sentences


# ===============
# Testing functions
# ===============

def get_predictions(model, corpus, test_dataloader, print_results=True):
    ''' Get the model's predictions for one corpus' test set.
    Can print accuracy using scikit-learn.
    Also works with dev sets -- just don't save the outputs.
    Returns: list of predictions that match the test file's lines.
    '''
    device = torch.device("cuda" if args.use_cuda else "cpu")
    if args.use_cuda:
        model = model.cuda()

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            mask = test_input['attention_mask'].to(device)
            input_id = test_input['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask)
            logits = output[0].detach().cpu().numpy()
            label_ids = test_label.to('cpu').numpy()

            all_labels += label_ids.tolist()
            all_preds += np.argmax(logits, axis=1).tolist()

    assert len(all_labels) == len(all_preds)

    test_acc = round(accuracy_score(all_labels, all_preds), 4)
    if print_results:
        print(corpus, '\tAccuracy:\t', test_acc)
    return all_preds


def get_predictions_huggingface(trainer, corpus, test_set, print_results=True):
    ''' SPECIFIC FUNCTION FOR THE HUGGINGFACE TRAINER.
    Get the model's predictions for one corpus' test set.
    Can print accuracy using scikit-learn.
    Also works with dev sets -- just don't save the outputs.
    Returns: list of predictions that match the test file's lines.
    '''
    results = trainer.predict(test_set)
    preds = np.argmax(results.predictions, axis=1)
    labels = results.label_ids
    test_acc = round(accuracy_score(labels, preds), 4)
    if print_results:
        print(corpus, '\tAccuracy:\t', test_acc, '\n')
    return preds
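
# Usage sketch (hypothetical names): evaluating every dev corpus with an
# already-built HuggingFace Trainer. `trainer` and `dev_dict_datasets` (a dict
# of tokenized Dataset objects keyed by corpus) are assumed to exist.
#
# dev_predictions = {}
# for corpus, dev_set in dev_dict_datasets.items():
#     dev_predictions[corpus] = get_predictions_huggingface(trainer, corpus, dev_set)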
def print_results_to_file(corpus, test_sentences, test_results,
                          inv_mappings_dict, substitutions_file):
    ''' Write a new file with the test predictions per the specifications
    of the shared task.
    Returns: one file per corpus with predictions.
    '''
    # create a dict of all the label substitutions that were made, keyed by corpus
    revert_substitutions = {}
    with open(substitutions_file, 'r', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            l = line.strip().split('\t')
            if not l[1] in revert_substitutions:
                revert_substitutions[l[1]] = {}
            revert_substitutions[l[1]][l[2]] = l[0]

    # save the results in a separate folder, one file per corpus
    if not os.path.exists('test_results_ST3'):
        os.makedirs('test_results_ST3')

    results_to_write = []
    for n, sent in enumerate(test_sentences):
        label = test_results[n]
        label = inv_mappings_dict[label]
        if corpus in revert_substitutions:
            if label in revert_substitutions[corpus]:
                label = revert_substitutions[corpus][label]
        temp = sent[:-3] + [label]
        assert len(temp) == 12
        results_to_write.append(temp)

    with open('test_results_ST3/' + corpus + '.tsv', 'a+', encoding='utf-8') as f:
        for line in results_to_write:
            f.write('\t'.join([str(x) for x in line]) + '\n')
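
# End-to-end usage sketch (hypothetical paths, names and corpus): how the
# helpers above fit together for one test corpus. `model` and `test_loader`
# are assumed to be a trained classifier and a DataLoader over that corpus'
# test set; 'data/mappings.tsv' and 'data/substitutions.tsv' are placeholder paths.
#
# mappings, inv_mappings = open_mappings('data/mappings.tsv')
# train_sents, dev_dict, test_dict = open_sentences('data', mappings)
# preds = get_predictions(model, 'eng.rst.gum', test_loader)
# print_results_to_file('eng.rst.gum', test_dict['eng.rst.gum'], preds,
#                       inv_mappings, 'data/substitutions.tsv')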