from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import torch, json
from datasets import Dataset
from trl import SFTTrainer
import pandas as pd
import importlib.util
import transformers
from alive_progress import alive_bar

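# Load shared helpers (e.g. the zero-shot prompt template) from ../utils.py via its file path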
spec = importlib.util.spec_from_file_location("utils", "../utils.py")
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)


def fine_tune(base_model, new_model):

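    # Load the instruction/response pairs, strip stray '<s> ' markers, and
    # concatenate input and output into the single text column used for training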
    path = '../../../data/finetuned_dataset.csv'
    dataframe = pd.read_csv(path)
    dataframe = dataframe.replace('<s> ', '', regex=True)
    dataframe['concat'] = dataframe['input'].astype(str) + dataframe['output'].astype(str)
    dataset = Dataset.from_pandas(dataframe, split="train")

    # Load base model
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map={"": 0}
    )
    model.config.use_cache = False # silence the warnings. Please re-enable for inference!
    model.config.pretraining_tp = 1
    model.gradient_checkpointing_enable()
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_eos_token = True  # append EOS so the model learns where a completion ends

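    # Prepare the quantized model for k-bit training and attach LoRA adapters
    # to the attention and MLP projection layers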
    model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
            r=16,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj"]
        )
    model = get_peft_model(model, peft_config)

    # Training Arguments
    # Hyperparameters should be adjusted based on the hardware you are using
    training_arguments = TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-4,
        logging_steps=2,
        optim="adamw_torch",
        save_strategy="steps",
        output_dir="./results"
    )

    # Set up the supervised fine-tuning (SFT) trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        max_seq_length=None,
        dataset_text_field="concat",
        tokenizer=tokenizer,
        args=training_arguments,
        packing=False,
    )


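    # Run supervised fine-tuning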
    trainer.train()

    # Save the fine-tuned model
    trainer.model.save_pretrained(new_model)
    model.config.use_cache = True
    model.eval()

def generate(base_model, new_model):

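    # Reload the base model in 4-bit, apply the fine-tuned LoRA adapter,
    # and merge the adapter weights into the base model for inference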
    base_model_reload = transformers.AutoModelForCausalLM.from_pretrained(
        base_model,
        #torch_dtype=torch.float16,
        device_map="auto",
        load_in_8bit=False,
        load_in_4bit=True,
        #attn_implementation="flash_attention_2"
    )
    model = PeftModel.from_pretrained(base_model_reload, new_model)
    model = model.merge_and_unload()

    tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

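    # Text-generation pipeline used to answer the evaluation questions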
    generate_text = transformers.pipeline(
        model=model, tokenizer=tokenizer,
        return_full_text=False,  # set to True if using LangChain
        task="text-generation",
        # we pass model parameters here too
        do_sample=True,
        temperature=0.5,  # 'randomness' of outputs; lower values are more deterministic
        top_p=0.15,  # sample only from the smallest set of tokens whose probabilities sum to 15%
        top_k=0,  # 0 disables top-k filtering, so sampling relies on top_p
        max_new_tokens=2048,  # max number of tokens to generate in the output
        repetition_penalty=1.0  # increase above 1.0 if the output begins repeating
    )

    def instruction_format(sys_message: str, query: str):
        # note: don't append "</s>" to the end
        return f'<s> [INST] {sys_message} [/INST]\nUser: {query}\nAssistant: '


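    # Load the evaluation sentences and build a zero-shot prompt for each one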
    with open('../../../data/evalQS.json', 'r', encoding='utf-8') as file:
        loaded = json.load(file)

    inputs = []
    output = {}

    with alive_bar(len(loaded)) as bar:
        for sentence in loaded:
            inputs.append(instruction_format(utils.get_pre_prompt_zero_shot(), sentence))
            bar()
    print("Input creation finished")

    res = generate_text(inputs)

    for i, sentence in enumerate(loaded):
        output[sentence] = res[i][0]["generated_text"]


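    # Write the raw generated answers to the results directory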
    with open('../../../results/LLM/Mistral-7b/MISTRAL_fine_tuned_raw_answers.json', 'w', encoding='utf-8') as file:
        json.dump(output, file)
    

#######################################################################################################################

base_model = "../../../models/Mistral-7B-Instruct-v0.2"
new_model = "../../../models/Fine-tuned_Mistral-7B"

fine_tune(base_model, new_model)
generate(base_model, new_model)

print("========== Program finished ==========")