-
Caroline DE POURTALES authored
DecisionTreeComponent.py 8.04 KiB
import dash_interactive_graphviz
import numpy as np
from dash import html
from pages.application.DecisionTree.utils.data import Data
from pages.application.DecisionTree.utils.dtree import DecisionTree
from pages.application.DecisionTree.utils.dtviz import (visualize,
visualize_expl,
visualize_instance,
visualize_contrastive_expl)
from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree
class DecisionTreeComponent:
    """Dash component displaying a decision tree and explanations of its predictions.

    Attributes set by ``__init__``:
        categorical: True when feature information (``.csv``/``.txt``) was supplied.
        uploaded_dt: the ``UploadedDecisionTree`` wrapping the model.
        dt_format, map: tree description and mapping file produced by ``uploaded_dt``.
        mapping_instance: lookup table used by ``translate_instance``.
        dt: the ``DecisionTree`` used for visualization and explanation.
        network: ``html.Div`` holding the interactive graphviz rendering of the tree.
        explanation: list of Dash components describing the last explanation computed.
    """

    def __init__(self, tree, type_tree='SKL', info=None, type_info=''):
        """Build the component from a fitted tree and optional feature information.

        :param tree: fitted tree model; must expose ``get_depth()``, ``n_classes_`` and
            either ``feature_names_in_`` or ``n_features_in_``.
        :param type_tree: tree type forwarded to ``UploadedDecisionTree``.
        :param info: optional feature information. When ``type_info`` contains ``.csv``
            it is passed to ``Data``; when it contains ``.txt`` it is a string parsed by
            ``_parse_txt_info``.
        :param type_info: name of the ``info`` source; only the substrings ``.csv`` and
            ``.txt`` are checked.
        """
        if info is not None and '.csv' in type_info:
            self.categorical = True
            data = Data(info)
            fvmap = data.mapping_features()
            feature_names = data.names[:-1]  # last column is excluded (presumably the target)
        elif info is not None and '.txt' in type_info:
            self.categorical = True
            fvmap, feature_names = self._parse_txt_info(info)
        else:
            self.categorical = False
            fvmap = None  # unused without feature information
            try:
                feature_names = tree.feature_names_in_
            except AttributeError:
                # Model fitted without named columns: fall back to f0, f1, ...
                feature_names = [f'f{i}' for i in range(tree.n_features_in_)]

        self.uploaded_dt = UploadedDecisionTree(tree, type_tree, maxdepth=tree.get_depth(),
                                                feature_names=feature_names, nb_classes=tree.n_classes_)
        if self.categorical:
            self.dt_format, self.map, features_names_mapping = self.uploaded_dt.dump(
                fvmap, feat_names=feature_names)
        else:
            self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(
                feat_names=feature_names)

        self.mapping_instance = self.create_fvmap_inverse(features_names_mapping)
        self.dt = DecisionTree(from_dt=self.dt_format, mapfile=self.map, feature_names=feature_names)
        self.network = self._make_network(visualize(self.dt), width="60%", height="90%")
        self.explanation = []

    @staticmethod
    def _make_network(dot_source, width="50%", height="80%"):
        """Wrap a graphviz DOT source in the Dash container used for ``self.network``."""
        return html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
            dot_source=dot_source,
            style={"width": width,
                   "height": height,
                   "background-color": "transparent"})])

    @staticmethod
    def _parse_txt_info(info):
        """Parse a ``.txt`` feature description into ``(fvmap, feature_names)``.

        Each non-blank line reads ``name,Type,value1,value2,...`` where ``Type`` is
        ``Binary`` or ``Categorical``. The i-th non-blank line produces
        ``fvmap['f<i>'] = {j: (name, True, value_j)}`` with the values sorted.

        :raises AssertionError: on a duplicate feature name or an unknown type.
        :raises ValueError: on a line with fewer than two comma-separated fields.
        """
        fvmap = {}
        feature_names = []
        # splitlines() also strips '\r' from '\r\n' endings; blank lines (such as a
        # trailing newline) are skipped instead of failing on unpacking.
        lines = [line for line in info.splitlines() if line.strip()]
        for i, line in enumerate(lines):
            fid, feat_type, *dom = line.split(',')
            assert fid not in feature_names, f"duplicate feature {fid!r}"
            feature_names.append(fid)
            assert feat_type in ('Binary', 'Categorical'), f"unknown feature type {feat_type!r}"
            fvmap[f'f{i}'] = {j: (fid, True, v) for j, v in enumerate(sorted(dom))}
        return fvmap, feature_names

    def create_fvmap_inverse(self, instance):
        """Build the lookup table used by ``translate_instance``.

        :param instance: the ``features_names_mapping`` returned by ``UploadedDecisionTree``;
            an iterable of comma-separated strings whose field 1 is ``name:id``.
        :return: in categorical mode, ``{name: {np.float32(real): int(mapped)}}`` built from
            the ``real:mapped`` fields 2 onward; otherwise ``{name: threshold}`` where
            threshold is the float before ``:`` in field 2.
        """
        def create_fvmap_inverse_with_info(features_names_mapping):
            mapping_instance = {}
            for feat in features_names_mapping:
                feature_description = feat.split(',')
                name_feat, _ = feature_description[1].split(':')
                feat_dic = {}
                for mapping in feature_description[2:]:
                    real_value, mapped_value = mapping.split(':')
                    # NOTE(review): float32 keys - float64 lookups of values that are not
                    # exactly representable in float32 (e.g. 0.1) will not match.
                    feat_dic[np.float32(real_value)] = int(mapped_value)
                mapping_instance[name_feat] = feat_dic
            return mapping_instance

        def create_fvmap_inverse_threshold(features_names_mapping):
            mapping_instance = {}
            for feat in features_names_mapping:
                feature_description = feat.split(',')
                name_feat, _ = feature_description[1].split(':')
                mapping_instance[name_feat] = float(feature_description[2].split(':')[0])
            return mapping_instance

        if self.categorical:
            return create_fvmap_inverse_with_info(instance)
        return create_fvmap_inverse_threshold(instance)

    def translate_instance(self, instance):
        """Encode ``(feature, value)`` pairs into the form passed to ``self.dt``.

        Categorical mode: each value is replaced by ``mapping_instance[feature][value]``.
        Threshold mode: each value becomes 0 if ``value <= threshold`` and 1 otherwise.
        A value is kept unchanged when its feature has no threshold or it cannot be
        compared to the threshold.
        """
        if self.categorical:
            return [(feat, self.mapping_instance[feat][real_value]) for feat, real_value in instance]

        instance_translated = []
        for feat, real_value in instance:
            try:
                encoded = 0 if real_value <= self.mapping_instance[feat] else 1
            except (KeyError, TypeError):
                # No threshold for this feature, or a non-numeric value: keep it as-is.
                encoded = real_value
            instance_translated.append((feat, encoded))
        return instance_translated

    def update_with_explicability(self, instance, enum, xtype, solver):
        """Explain ``instance``, redraw the tree for it and rebuild ``self.explanation``.

        :param instance: sequence of ``(feature_name, value)`` pairs.
        :param enum, xtype, solver: forwarded to ``DecisionTree.explain``.
        :return: ``(path_explanations, path_contrastive_explanations)``, i.e. the values
            stored under ``"List of path explanation(s)"`` and
            ``"List of path contrastive explanation(s)"``, or ``[]`` for a missing key.
        """
        path_key = "List of path explanation(s)"
        contrastive_path_key = "List of path contrastive explanation(s)"
        listed_keys = ("List of abductive explanation(s)", "List of contrastive explanation(s)")

        instance_translated = self.translate_instance(instance)
        self.explanation = []
        explanation = self.dt.explain(instance_translated, enum=enum, xtype=xtype, solver=solver)
        self.network = self._make_network(visualize_instance(self.dt, instance_translated))

        # Text summary: the raw instance first, then every non-path entry.
        self.explanation.append(html.H4("Instance : \n"))
        self.explanation.append(html.P(str([str(feat_value) for feat_value in instance])))
        for key, value in explanation.items():
            if key in (path_key, contrastive_path_key):
                continue  # returned to the caller instead of being displayed
            if key in listed_keys:
                self.explanation.append(html.H4(key))
                for expl in value:
                    self.explanation.append(html.Hr())
                    self.explanation.append(html.P(expl))
                self.explanation.append(html.Hr())
            else:
                self.explanation.append(html.P(key + value))

        return explanation.get(path_key, []), explanation.get(contrastive_path_key, [])

    def draw_explanation(self, instance, expl):
        """Redraw the tree highlighting explanation ``expl`` for ``instance``."""
        instance = self.translate_instance(instance)
        self.network = self._make_network(visualize_expl(self.dt, instance, expl))

    def draw_contrastive_explanation(self, instance, cont_expl):
        """Redraw the tree highlighting contrastive explanation ``cont_expl`` for ``instance``."""
        instance = self.translate_instance(instance)
        self.network = self._make_network(visualize_contrastive_expl(self.dt, instance, cont_expl))