# DecisionTreeComponent.py
import dash_interactive_graphviz
import numpy as np
from dash import html

from pages.application.DecisionTree.utils.data import Data
from pages.application.DecisionTree.utils.dtree import DecisionTree
from pages.application.DecisionTree.utils.dtviz import (visualize,
                                                        visualize_expl,
                                                        visualize_instance,
                                                        visualize_contrastive_expl)
from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree


class DecisionTreeComponent():
    """Hold the Dash view of a decision tree together with its explanations.

    Attributes set by ``__init__``:
        categorical: True when feature information (``.csv`` or ``.txt``)
            was supplied, False when the tree is handled through thresholds.
        uploaded_dt: the ``UploadedDecisionTree`` wrapper built from ``tree``.
        dt_format, map: first two values returned by ``uploaded_dt.dump`` /
            ``uploaded_dt.convert_dt``; they are forwarded to ``DecisionTree``.
        mapping_instance: per-feature lookup built by
            ``create_fvmap_inverse`` and used by ``translate_instance``.
        dt: the ``DecisionTree`` used for visualisation and explanation.
        network: ``html.Div`` wrapping the interactive Graphviz widget.
        explanation: list of Dash HTML components describing the last
            explanation (empty until ``update_with_explicability`` runs).
    """

    # Keys of the dict returned by ``self.dt.explain`` that carry paths to be
    # drawn rather than text to be displayed.
    _PATH_KEY = "List of path explanation(s)"
    _CONTRASTIVE_PATH_KEY = "List of path contrastive explanation(s)"
    # Keys whose value is a list of explanations, rendered one per paragraph.
    _LISTED_KEYS = ("List of abductive explanation(s)",
                    "List of contrastive explanation(s)")

    def __init__(self, tree, type_tree='SKL', info=None, type_info=''):
        """Convert ``tree`` and build the initial (un-annotated) tree view.

        Args:
            tree: fitted model exposing ``get_depth()`` and ``n_classes_``;
                when no feature information is given it must also expose
                ``feature_names_in_`` or ``n_features_in_``.
            type_tree: forwarded unchanged to ``UploadedDecisionTree``.
            info: optional feature description. It is handed to ``Data`` when
                ``type_info`` contains ``'.csv'`` and parsed as text lines
                when ``type_info`` contains ``'.txt'``.
            type_info: name/type of the ``info`` payload; only searched for
                the ``'.csv'`` and ``'.txt'`` substrings.
        """
        if info is not None and '.csv' in type_info:
            self.categorical = True
            data = Data(info)
            fvmap = data.mapping_features()
            feature_names = data.names[:-1]
        elif info is not None and '.txt' in type_info:
            self.categorical = True
            fvmap, feature_names = self._parse_txt_info(info)
        else:
            self.categorical = False
            fvmap = None
            try:
                feature_names = tree.feature_names_in_
            except AttributeError:
                # No recorded feature names: fall back to f0, f1, ...
                feature_names = [f'f{i}' for i in range(tree.n_features_in_)]

        self.uploaded_dt = UploadedDecisionTree(tree, type_tree, maxdepth=tree.get_depth(),
                                                feature_names=feature_names, nb_classes=tree.n_classes_)
        if self.categorical:
            converted = self.uploaded_dt.dump(fvmap, feat_names=feature_names)
        else:
            converted = self.uploaded_dt.convert_dt(feat_names=feature_names)
        self.dt_format, self.map, features_names_mapping = converted

        self.mapping_instance = self.create_fvmap_inverse(features_names_mapping)
        self.dt = DecisionTree(from_dt=self.dt_format, mapfile=self.map, feature_names=feature_names)
        self.network = self._graph_view(visualize(self.dt), "60%", "90%")
        self.explanation = []

    @staticmethod
    def _parse_txt_info(info):
        """Parse a ``.txt`` feature description into ``(fvmap, feature_names)``.

        Every non-blank line has the form ``name,TYPE,value1,value2,...`` with
        ``TYPE`` being ``Binary`` or ``Categorical``. The domain values are
        sorted (as strings) and numbered from 0; the n-th parsed line is
        stored under the key ``f<n>``.

        Raises:
            ValueError: if a line has fewer than two comma-separated fields.
            AssertionError: on a duplicated feature name or an unknown type.
        """
        fvmap = {}
        feature_names = []
        # splitlines() copes with '\r\n' endings; blank lines (typically a
        # trailing newline) are skipped instead of breaking the unpacking.
        lines = [line for line in info.splitlines() if line.strip()]
        for i, line in enumerate(lines):
            fid, kind, *domain = line.split(',')
            assert fid not in feature_names, f"duplicated feature name: {fid!r}"
            feature_names.append(fid)
            assert kind in ['Binary', 'Categorical'], f"unsupported feature type: {kind!r}"
            fvmap[f'f{i}'] = {j: (fid, True, v) for j, v in enumerate(sorted(domain))}
        return fvmap, feature_names

    @staticmethod
    def _graph_view(dot_source, width, height):
        """Wrap ``dot_source`` in an interactive Graphviz widget inside a Div."""
        return html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
            dot_source=dot_source,
            style={"width": width,
                   "height": height,
                   "background-color": "transparent"})])

    def create_fvmap_inverse(self, instance):
        """Build the lookup used to translate raw instance values.

        Args:
            instance: iterable of mapping strings (despite the parameter
                name). Each string is comma-separated; field 1 is
                ``name:id`` and the following fields are ``real:mapped``
                pairs.

        Returns:
            In categorical mode, ``{name: {np.float32(real): int(mapped)}}``.
            Otherwise ``{name: float(real)}`` taken from the first pair only,
            which ``translate_instance`` uses as a threshold.
        """
        mapping_instance = {}
        for feat in instance:
            fields = feat.split(',')
            # The id part is not needed; unpacking still enforces "name:id".
            name_feat, _ = fields[1].split(':')
            if self.categorical:
                feat_dic = {}
                for pair in fields[2:]:
                    real_value, mapped_value = pair.split(':')
                    feat_dic[np.float32(real_value)] = int(mapped_value)
                mapping_instance[name_feat] = feat_dic
            else:
                mapping_instance[name_feat] = float(fields[2].split(':')[0])
        return mapping_instance

    @staticmethod
    def _categorical_key(real_value):
        """Return ``real_value`` in the form used for categorical keys.

        Keys are stored as ``np.float32``; a Python float such as ``0.1``
        neither hashes nor compares equal to its float32 rounding, so the
        value is converted the same way before the lookup. Values that
        cannot be converted are returned unchanged.
        """
        try:
            return np.float32(real_value)
        except (TypeError, ValueError):
            return real_value

    def translate_instance(self, instance):
        """Translate ``(feature, raw value)`` pairs into the tree's encoding.

        Args:
            instance: iterable of ``(feature name, raw value)`` pairs.

        Returns:
            A list of ``(feature name, encoded value)`` pairs. In categorical
            mode the value is looked up in ``self.mapping_instance``. In
            threshold mode it is 0 when ``raw <= threshold`` and 1 otherwise;
            a feature without threshold, or a value that cannot be compared,
            is passed through unchanged.

        Raises:
            KeyError: in categorical mode, for an unknown feature or value.
        """
        translated = []
        if self.categorical:
            for feat, real_value in instance:
                key = self._categorical_key(real_value)
                translated.append((feat, self.mapping_instance[feat][key]))
            return translated

        for feat, real_value in instance:
            try:
                mapped = 0 if real_value <= self.mapping_instance[feat] else 1
            except (KeyError, TypeError, ValueError):
                # Best effort: unknown feature or non-comparable value.
                mapped = real_value
            translated.append((feat, mapped))
        return translated

    def update_with_explicability(self, instance, enum, xtype, solver):
        """Explain ``instance`` and refresh ``network`` and ``explanation``.

        Args:
            instance: iterable of ``(feature name, raw value)`` pairs.
            enum, xtype, solver: forwarded unchanged to ``self.dt.explain``.

        Returns:
            ``(path explanations, contrastive path explanations)`` read from
            the explanation dict; either one is an empty list when its key is
            absent.
        """
        instance_translated = self.translate_instance(instance)
        self.explanation = []
        explanation = self.dt.explain(instance_translated, enum=enum, xtype=xtype, solver=solver)

        self.network = self._graph_view(visualize_instance(self.dt, instance_translated), "50%", "80%")

        # Textual report: the instance first, then every non-path entry.
        self.explanation.append(html.H4("Instance : \n"))
        self.explanation.append(html.P(str([str(item) for item in instance])))
        for key, value in explanation.items():
            if key in (self._PATH_KEY, self._CONTRASTIVE_PATH_KEY):
                # Paths are returned to the caller, not displayed as text.
                continue
            if key in self._LISTED_KEYS:
                self.explanation.append(html.H4(key))
                for expl in value:
                    self.explanation.append(html.Hr())
                    self.explanation.append(html.P(expl))
                    self.explanation.append(html.Hr())
            else:
                # f-string keeps the output for str values and does not fail
                # when the value is not a string.
                self.explanation.append(html.P(f"{key}{value}"))

        # .get() so that one missing path key does not discard the other.
        list_explanations_path = explanation.get(self._PATH_KEY, [])
        list_contrastive_explanations_path = explanation.get(self._CONTRASTIVE_PATH_KEY, [])
        return list_explanations_path, list_contrastive_explanations_path

    def draw_explanation(self, instance, expl):
        """Redraw ``network`` with the explanation ``expl`` for ``instance``."""
        translated = self.translate_instance(instance)
        self.network = self._graph_view(visualize_expl(self.dt, translated, expl), "50%", "80%")

    def draw_contrastive_explanation(self, instance, cont_expl):
        """Redraw ``network`` with the contrastive explanation ``cont_expl``."""
        translated = self.translate_instance(instance)
        self.network = self._graph_view(
            visualize_contrastive_expl(self.dt, translated, cont_expl), "50%", "80%")