from dash import html
import dash_interactive_graphviz
from sklearn import tree
from pages.application.RandomForest.utils import xrf
from pages.application.RandomForest.utils.xrf.xforest import XRF, Dataset


class RandomForestComponent:
    """Dash component that plots one tree of a random forest and lists explanations."""

    def __init__(self, model, info=None, type_info=''):
        """Wrap *model* for explanation and render its first tree.

        Args:
            model: Forest handed unchanged to ``XRF``.
            info: Data description handed to ``Dataset``.
            type_info: Kind of ``info``; the ``XRF`` wrapper is only built
                when it contains ``'csv'`` and ``info`` is not ``None``.

        Raises:
            AttributeError: If the ``XRF`` wrapper could not be built, because
                the tree rendering below needs ``self.random_forest.cls``.
        """
        # Conversion model
        # NOTE(review): Dataset(None) may itself fail before the check below;
        # confirm against the Dataset implementation.
        self.data = Dataset(info)

        # Always define the attribute so later isinstance() checks never hit
        # a missing attribute on this component.
        self.random_forest = None
        if info is not None and 'csv' in type_info:
            # Encoding here so not in the explanation.
            self.random_forest = XRF(model, self.data.feature_names, self.data.target_name)

        # Index of the estimator currently displayed.
        self.tree_to_plot = 0
        # Div wrapping the interactive Graphviz rendering of that estimator.
        self.network = self._build_network()
        # Flat list of Dash html components filled by update_with_explicability.
        self.explanation = []

    def _build_network(self):
        """Return a Div holding the Graphviz view of estimator ``self.tree_to_plot``."""
        estimator = self.random_forest.cls.estimators()[self.tree_to_plot]
        dot_source = tree.export_graphviz(
            estimator,
            feature_names=self.data.feature_names,
            class_names=[str(cl) for cl in self.data.target_name],
            impurity=False, filled=False, rounded=True)
        graph = dash_interactive_graphviz.DashInteractiveGraphviz(
            dot_source=dot_source,
            style={"width": "50%",
                   "height": "80%",
                   "background-color": "transparent"})
        return html.Div([graph])

    def update_with_explicability(self, instances, enum_feats=None, xtype=None, solver=None):
        """Rebuild ``self.explanation`` for the given instances.

        Args:
            instances: Iterable of instances; each instance is a sequence of
                pairs whose second item is the value. The last value of an
                instance is rendered as its target.
            enum_feats: Unused; kept for interface compatibility.
            xtype: Unused; kept for interface compatibility.
            solver: Unused; kept for interface compatibility.

        Returns:
            A pair of empty lists.
        """
        # Keep only the value (second item) of every pair.
        rows = [[feature[1] for feature in instance] for instance in instances]
        self.explanation = []
        for row in rows:
            self.explanation.append(html.H4("Sample : "))
            # Call instance
            self.explanation.append(html.H5("Instance"))
            self.explanation.append(html.Hr())
            premise = [(self.data.feature_names[i], str(value))
                       for i, value in enumerate(row[:-1])]
            conclusion = (self.data.target_name, str(row[-1]))
            self.explanation.append(html.P(str(premise) + " THEN " + str(conclusion)))
            self.explanation.append(html.Hr())

            # Call explanation
            explanation_result = None
            if isinstance(self.random_forest, XRF):
                explanation_result = self.random_forest.explain(row)
            if explanation_result is None:
                # No explainer available: keep the instance header only
                # instead of failing on a None result.
                continue

            # Creating a clean and nice text component
            # NOTE(review): assumes explain() returns a mapping of section
            # title to text - confirm against XRF.explain.
            for title, text in explanation_result.items():
                self.explanation.append(html.H5(title))
                self.explanation.append(html.Hr())
                self.explanation.append(html.P(text))
                self.explanation.append(html.Hr())

        # Drop the wrapper's 'enc' and 'x' attributes (presumably state cached
        # by explain() - TODO confirm). They may be absent, e.g. when no
        # instance was explained, so only delete what exists.
        for cached in ("enc", "x"):
            if hasattr(self.random_forest, cached):
                delattr(self.random_forest, cached)

        return [], []

    def update_plotted_tree(self, tree_to_plot):
        """Select estimator *tree_to_plot* and re-render ``self.network``.

        Args:
            tree_to_plot: Index into ``self.random_forest.cls.estimators()``.
        """
        self.tree_to_plot = tree_to_plot
        self.network = self._build_network()