#!/usr/bin/env python
"""
Entry point for the time series pipeline: generate synthetic datasets, train algorithms on them and compute the results.
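
Example invocation (a sketch; the script filename "main.py" and the output name "my-run" are
illustrative, the config file and algorithm are the defaults):

    python main.py all -c config/2d-mix.json -a kmeans -o my-run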
"""
import json
import os

from argparse import ArgumentParser
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from generator import DatasetGenerator
from split import BaseSplitter
from trainers import BaseTrainer
from results import ResultExtractor, ROCResults

plt.rcParams["figure.figsize"] = (20, 10)


if __name__ == "__main__":
    # =================================================================================================================
    #                                           Fetch arguments from CLI
    # =================================================================================================================
    # Create the parser
    parser = ArgumentParser(
        prog="Time Series Generator",
        description="Generate time series, train algorithms on them and compute the results."
    )

    # Add arguments
    parser.add_argument(
        "task",
        help="Either generate data, test an algorithm, compute the results of an execution or do everything.",
        choices=["generate", "train", "results", "all"]
    )
    parser.add_argument(
        "-c", "--config", 
        help="Path to config file(s) to use to generate data.",
        default=["config/2d-mix.json"],
        nargs="+"
    )
    parser.add_argument("-a", "--algorithms", help="Which algorithm to train.", default=["kmeans"], nargs="+")
    parser.add_argument(
        "-m", "--cluster-method",
        help="Which cluster algorithm to use.",
        default="HDBSCAN",
        dest="method"
    )
    parser.add_argument("-i", "--input", help="Input directory. Only to be used when no data will be generated")
    parser.add_argument("-o", "--output", help="Output directory")
    parser.add_argument(
        "-s", "--split",
        help="Automatically split dataset into subsystems and treat them once at a time.",
        action="store_true"
    )
    parser.add_argument(
        "-optim",
        help="Optimize the different hyper parameters according to what's wirtten in the algorithm_params.json file",
        action="store_true"
    )
    parser.add_argument(
        "-d", "--docker",
        help="Use Docker containers directly to run the algorithm. Allow to run algorithms without cloning repo",
        action="store_true"
    )
    parser.add_argument(
        "--merge",
        help="Whether the method should merge the correlation matrices or not.",
        action="store_true"
    )

    # Load args
    args = parser.parse_args()

    # Prepare output directory
    if args.output is None:
        OUTPUT_DIR = f"output/{datetime.now().strftime('%Y-%m-%d.%H-%M-%S')}"
    else:
        OUTPUT_DIR = f"output/{args.output}"

    # Prepare input directory if needed
    if args.task not in ["generate", "all"] and args.input is None:
        raise ValueError("Cannot skip data generation without providing an input directory (-i/--input)")

    if args.task not in ["generate", "all"]:
        INPUT_DIR = f"output/{args.input}"
    else:
        INPUT_DIR = OUTPUT_DIR
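
    # Both directories share the same layout: output/<name>/<config name>/ holds the generated CSV files
    # and plots for each configuration, and presumably the outputs written by the trainers as well.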

    # =================================================================================================================
    #                                                Generate Data
    # =================================================================================================================
    # Load config file
    for config_file in args.config:
        # Compute config name
        config_name = os.path.splitext(os.path.basename(config_file))[0]

        if args.task in ["generate", "all"]:
            # Create output dir
            os.makedirs(f"{OUTPUT_DIR}/{config_name}", exist_ok=True)

            # Read config file
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            # Data generation
            generator = DatasetGenerator(config)

            dataset, train_dataset, labels, variables_labels = generator.generate()
            subsystems, splitted_data, splitted_train, splitted_labels = generator.get_splitted_data()

            # Save data to disk
            # Prepare the data
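            # The generator's arrays are laid out as (n_variables, n_timesteps); transpose them so that
            # DataFrame rows correspond to timestamps and columns to variables.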
            columns = list(range(0, dataset.shape[0]))
            indices = list(range(0, dataset.shape[1]))
            t_indices = list(range(0, train_dataset.shape[1]))
            df_test = pd.DataFrame(data=dataset.T, index=indices, columns=columns)
            df_test["is_anomaly"] = labels
            df_train = pd.DataFrame(data=train_dataset.T, index=t_indices, columns=columns)
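            # The training set is assumed to be anomaly-free, hence the all-zero labels.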
            df_train["is_anomaly"] = np.zeros(train_dataset.shape[1])

            df_test.to_csv(f"{OUTPUT_DIR}/{config_name}/dataset.csv", index_label="Timestamp")
            df_train.to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_train.csv", index_label="Timestamp")
            pd.DataFrame(data=labels).to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_labels.csv", index=False)
            pd.DataFrame(data=variables_labels.T)\
                .to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_variables_labels.csv", index=False)

            # Plot each variable of the test and train datasets and save the figures to disk
            for dimension in dataset:
                plt.plot(dimension)
            plt.savefig(f"{OUTPUT_DIR}/{config_name}/dataset.png")
            plt.clf()

            for dimension in train_dataset:
                plt.plot(dimension)
            plt.savefig(f"{OUTPUT_DIR}/{config_name}/train_dataset.png")
            plt.clf()

            # Handle the split (per-subsystem) data
            for idx, (data, train, lab) in enumerate(zip(splitted_data, splitted_train, splitted_labels)):
                columns = list(range(0, data.shape[0]))
                df_test_s = pd.DataFrame(data=data.T, index=indices, columns=columns)
                df_test_s["is_anomaly"] = lab
                df_train_s = pd.DataFrame(data=train.T, index=t_indices, columns=columns)
                df_train_s["is_anomaly"] = np.zeros(train.shape[1])

                df_test_s.to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_{idx}.csv", index_label="Timestamp")
                df_train_s.to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_{idx}_train.csv", index_label="Timestamp")
                pd.DataFrame(data=lab).to_csv(f"{OUTPUT_DIR}/{config_name}/dataset_{idx}_labels.csv", index=False)

            with open(f"{OUTPUT_DIR}/{config_name}/dataset_clusters.txt", "w", encoding="utf-8") as f:
                clusters = np.zeros(dataset.shape[0]).astype(int)
                for idx, cluster in enumerate(subsystems):
                    for member in cluster:
                        clusters[member] = idx
                f.write(json.dumps(clusters.tolist()))


    # =================================================================================================================
    #                                                 Split data
    # =================================================================================================================

        if args.split and args.task in ["train", "all"]:
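            # BaseSplitter presumably groups the variables into subsystems with the chosen clustering
            # method (HDBSCAN by default) and writes one sub-dataset per subsystem; --merge controls
            # whether the correlation matrices are merged beforehand (see split.py).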
            splitter = BaseSplitter(f"{INPUT_DIR}/{config_name}")
            splitter.split_data(method=args.method, merge=args.merge)

    # =================================================================================================================
    #                                               Train algorithm
    # =================================================================================================================
        if args.task in ["train", "all"]:
            trainers = []
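
            # algorithm_params.json maps each algorithm name to its parameters; the "training" flag is
            # consumed below and the remaining keys are forwarded to BaseTrainer. Illustrative shape only
            # (keys other than "training" are hypothetical):
            #   {"kmeans": {"training": true, "n_clusters": 8}}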

            with open("algorithm_params.json", "r", encoding="utf-8") as f:
                algo_params = json.load(f)

            with open(f"{INPUT_DIR}/{config_name}/time.csv", "a", encoding="utf-8") as f:
                f.write("Algorithm,Dataset,Step,Duration\n")

            for algo in args.algorithms:
                params = algo_params.get(algo, {"training": True})
                # Default to training enabled when the config does not specify the flag
                train = params.pop("training", True)
                trainer = BaseTrainer(f"{INPUT_DIR}/{config_name}", algo, train, **params)
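                # start() runs the algorithm on the generated dataset; the -optim flag enables the
                # hyperparameter optimization described in algorithm_params.json.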
                trainer.start(optim=args.optim)
                trainers.append(trainer)

    # =================================================================================================================
    #                                           Compute and plot results
    # =================================================================================================================
    if args.task in ["results", "all"]:
        # ROCResults(INPUT_DIR, args.algorithms, args.config).compute(auto_split=args.split)
        results = ResultExtractor(INPUT_DIR).fetch_results().compute_results()
        results.to_csv(f"{INPUT_DIR}/results.csv")