#!/usr/bin/env python
# coding: utf-8

import os
from datetime import datetime

import datasets
import pandas as pd
# AutoAdapterModel is provided by the adapter-transformers fork of transformers
from transformers import (AutoTokenizer, AutoAdapterModel, TrainingArguments,
                          Trainer, set_seed)

from configure import parse_args
# utils provides open_mappings, open_sentences, encode_batch,
# get_predictions_huggingface and print_results_to_file
from utils import *

args = parse_args()
now = datetime.now()
dt_string = now.strftime("%d.%m.%y-%H:%M:%S")
adapter_name = args.adapter_name
mappings, inv_mappings = open_mappings(args.mappings_file)
substitutions_file = 'mappings/substitutions.txt'
tokenizer = AutoTokenizer.from_pretrained(args.transformer_model)

# we are saving the test results of specific epochs
# specific_results = open_specific_results('mappings/specific_results.txt')
# if '1-2-3' in adapter_name or 'layer1;layer2;layer3' in adapter_name:
#     specific_results = list(specific_results['A1_3'][args.num_epochs])
# else:
#     specific_results = list(specific_results['A1'][args.num_epochs])

set_seed(42)

print('Train classifier with adapter\n')
print('Adapter name:', adapter_name)
print('Model:', args.transformer_model)
print('Effective batch size:', args.batch_size * args.gradient_accumulation_steps)
print('Num epochs:', args.num_epochs)


# Open sentences
train_sentences, dev_dict_sentences, test_dict_sentences = open_sentences(args.data_path, mappings)
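# open_sentences is assumed to return, per example, a tuple ending in
# (token_list, label_id); the dataframe construction below relies on
# x[-2] (tokens) and x[-1] (label).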

# Build pandas dataframes
file_header = ['text', 'labels']

train_df = pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in train_sentences],
                        columns=file_header)
# Shuffle the training set; reset the index so from_pandas does not keep it
# as an extra column
train_df = train_df.sample(frac=1).reset_index(drop=True)

dev_dict_df = {corpus: pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents],
                                    columns=file_header)
               for corpus, sents in dev_dict_sentences.items()}

test_dict_df = {corpus: pd.DataFrame([[' '.join(x[-2]), x[-1]] for x in sents],
                                     columns=file_header)
                for corpus, sents in test_dict_sentences.items()}

# Make datasets from dataframes
train_dataset = datasets.Dataset.from_pandas(train_df)
dev_dict_dataset  = {corpus: datasets.Dataset.from_pandas(dev_df)
                     for corpus, dev_df in dev_dict_df.items()}
test_dict_dataset = {corpus: datasets.Dataset.from_pandas(test_df)
                     for corpus, test_df in test_dict_df.items()}

# Get the number of distinct labels (plus one)
num_labels = train_df['labels'].nunique() + 1
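# NB: num_labels is not used further in this script; the label space is
# presumably defined by the prediction head saved with the adapter.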

# Encode the data
train_dataset = train_dataset.map(encode_batch, batched=True)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

encoded_dev_dataset = {}
for corpus in dev_dict_dataset:
    temp = dev_dict_dataset[corpus].map(encode_batch, batched=True)
    temp.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    encoded_dev_dataset[corpus] = temp

encoded_test_dataset = {}
for corpus in test_dict_dataset:
    temp = test_dict_dataset[corpus].map(encode_batch, batched=True)
    temp.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    encoded_test_dataset[corpus] = temp
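# `encode_batch` is provided by utils; a minimal sketch of what it is assumed
# to do (tokenize the text column with the tokenizer created above):
#
#     def encode_batch(batch):
#         return tokenizer(batch['text'], truncation=True,
#                          padding='max_length')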

# ===============================
# Training params
# ===============================

model = AutoAdapterModel.from_pretrained(args.transformer_model)
active_adapter = model.load_adapter(adapter_name,
                                    config=adapter_name + "/adapter_config.json")
model.set_active_adapters(active_adapter)
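# load_adapter restores the saved adapter weights (and, if one was saved with
# it, the matching prediction head) from the adapter_name directory;
# set_active_adapters only routes the forward pass through the adapter.
# Unlike model.train_adapter(...), it does NOT freeze the base model, so base
# weights stay trainable unless frozen explicitly below.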


training_args = TrainingArguments(
    learning_rate    = 2e-5,  # alternative: 1e-4
    num_train_epochs = args.num_epochs,
    per_device_train_batch_size = args.batch_size,
    per_device_eval_batch_size  = args.batch_size,
    gradient_accumulation_steps = args.gradient_accumulation_steps,
    # log roughly once per epoch; logging_steps must be an integer
    logging_steps = max(1, len(train_sentences)
                        // (args.batch_size * args.gradient_accumulation_steps)),
    output_dir = "./training_output",
    overwrite_output_dir = True,
    # do not let the Trainer drop dataset columns the adapter head needs
    remove_unused_columns = False,
)


trainer = Trainer(
    model = model,
    args  = training_args,
    train_dataset = train_dataset
)
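# The Trainer is used for training only; dev/test predictions are computed
# explicitly below. One could instead attach evaluation directly, e.g. with a
# hypothetical accuracy metric:
#
#     import numpy as np
#     def compute_accuracy(p):  # p is a transformers.EvalPrediction
#         preds = np.argmax(p.predictions, axis=1)
#         return {'accuracy': (preds == p.label_ids).mean()}
#
#     trainer = Trainer(model=model, args=training_args,
#                       train_dataset=train_dataset,
#                       eval_dataset=encoded_dev_dataset[corpus],  # one corpus
#                       compute_metrics=compute_accuracy)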

# Freeze selected layers: any parameter whose name contains one of the
# ';'-separated patterns in args.freeze_layers
if args.freeze_layers != '':
    layers_to_freeze = args.freeze_layers.split(';')
    for name, param in model.named_parameters():
        if any(x in name for x in layers_to_freeze):
            param.requires_grad = False
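# Freezing after the Trainer is constructed still takes effect: the optimizer
# is only built inside trainer.train(), and parameters with
# requires_grad=False receive no gradient updates.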


# ===============================
# Start the training 🚀
# ===============================

print('Start training...')
trainer.train()

# Dev results

print('\nDev results:')
for corpus in encoded_dev_dataset:
    print()
    dev_results = get_predictions_huggingface(trainer, corpus,
                                              encoded_dev_dataset[corpus])

    path_results = os.path.join('results', 'dev',
                                adapter_name + '_' + str(args.num_epochs))
    os.makedirs(path_results, exist_ok=True)

    print_results_to_file(corpus,
                          dev_dict_sentences[corpus],
                          dev_results,
                          inv_mappings,
                          substitutions_file,
                          path_results)

# Test results

print('\nTest results:')
for corpus in encoded_test_dataset:
    print()
    test_results = get_predictions_huggingface(trainer,
                                               corpus,
                                               encoded_test_dataset[corpus])

    path_results = os.path.join('results', 'test',
                                adapter_name + '_' + str(args.num_epochs))
    os.makedirs(path_results, exist_ok=True)

    print_results_to_file(corpus,
                          test_dict_sentences[corpus],
                          test_results,
                          inv_mappings,
                          substitutions_file,
                          path_results)


