Commit d715463d authored by Caroline DE POURTALES

decisiontree color

parent 8e2acebb
Showing with 236 additions and 622 deletions
app.py
 # Run this app with `python app.py` and
 # visit http://127.0.0.1:8050/ in your web browser.
-from pages.layout import create_layout
 import dash
 import json
-from dash import dcc
-from dash import html
+from dash import dcc, html, Input, Output
+import dash_bootstrap_components as dbc
+from pages.application.layout_application import Model, View
+
+'''
+Loading data
+'''
 models_data = open('data_retriever.json')
-data = json.load(models_data)
+data = json.load(models_data)["data"]
+
+app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
+
+'''
+Construction of the layout
+'''
+model = Model(data)
+view = View(model)
+tabs = dcc.Tabs([
+    dcc.Tab(label='Course on Explainable AI', children=[]),
+    view.tab,
+])
+app.layout = html.Div([
+    html.H1('FXToolKit'),
+    tabs])
+
+'''
+Callback for the app
+'''
+@app.callback(
+    Output('ml_datasets_choice', 'options'),
+    Output('ml_instances_choice', 'options'),
+    Output('graph', 'children'),
+    Output('explanation', 'children'),
+    Input('ml_model_choice', 'value'),
+    Input('ml_datasets_choice', 'value'),
+    Input('ml_instances_choice', 'value'),
+    prevent_initial_call=True
+)
+def update_ml_type(value_ml_model, value_dataset, value_instance):
+    ctx = dash.callback_context
+    if ctx.triggered:
+        dropdown_id = ctx.triggered[0]['prop_id'].split('.')[0]
+        if dropdown_id == 'ml_model_choice':
+            model.update_ml_model(value_ml_model)
+            return model.datasets, [], "", ""
+        elif dropdown_id == 'ml_datasets_choice':
+            model.update_dataset(value_dataset)
+            view.update_dataset()
+            return model.datasets, model.instances, view.component.network, ""
+        elif dropdown_id == 'ml_instances_choice':
+            model.update_instance(value_instance)
+            view.update_instance()
+            return model.datasets, model.instances, view.component.network, view.component.explanation
-app = dash.Dash(__name__)
+
+'''
+Launching app
+'''
 if __name__ == '__main__':
-    app.layout = create_layout(data)
     app.run_server(debug=True)
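
For readers less familiar with Dash, the single callback above fans out over three dropdowns by inspecting dash.callback_context; a brief illustration (the dataset name 'iris' is a placeholder, not part of the commit):

    # When the dataset dropdown fires, Dash reports roughly:
    #   dash.callback_context.triggered == [{'prop_id': 'ml_datasets_choice.value', 'value': 'iris'}]
    # so splitting 'prop_id' on '.' recovers the id of the component that triggered the callback.
    dropdown_id = ctx.triggered[0]['prop_id'].split('.')[0]   # -> 'ml_datasets_choice'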
data_retriever.json
@@ -5,8 +5,15 @@
         "ml_type" : "DecisionTree",
         "trained_models" : "pages/application/DecisionTree/trained_models",
         "instances" : "pages/application/DecisionTree/instances/",
-        "explicability_algorithm" : "pages/application/DecisionTree/decision_tree_explicability",
+        "map" : "pages/application/DecisionTree/map/",
         "component" : "DecisionTreeComponent"
+    },
+    {
+        "ml_type" : "NaiveBayes",
+        "trained_models" : "pages/application/NaiveBayes/trained_models",
+        "instances" : "pages/application/NaiveBayes/instances/",
+        "map" : "pages/application/NaiveBayes/map/",
+        "component" : "NaiveBayesComponent"
     }
 ]
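
Read together with app.py's json.load(models_data)["data"] and extract_data below, the post-commit configuration file plausibly looks like the sketch that follows; the top-level "data" key is an inference, everything else is copied from the hunk above:

    {
        "data": [
            {
                "ml_type": "DecisionTree",
                "trained_models": "pages/application/DecisionTree/trained_models",
                "instances": "pages/application/DecisionTree/instances/",
                "map": "pages/application/DecisionTree/map/",
                "component": "DecisionTreeComponent"
            },
            {
                "ml_type": "NaiveBayes",
                "trained_models": "pages/application/NaiveBayes/trained_models",
                "instances": "pages/application/NaiveBayes/instances/",
                "map": "pages/application/NaiveBayes/map/",
                "component": "NaiveBayesComponent"
            }
        ]
    }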
File deleted
pages/application/DecisionTree/DecisionTreeComponent.py
-from dash import dcc, html, Input, Output, callback
-from pages.application.DecisionTree.utils.dtviz import visualize
+from dash import dcc, html, Input, Output
+from pages.application.DecisionTree.utils.dtviz import visualize, visualize_instance
 from pages.application.DecisionTree.utils.dtree import DecisionTree
+import dash_interactive_graphviz

 class DecisionTreeComponent():

-    def __init__(self):
-        self.network = html.Img()
-
-    def update(self, dataset):
-        print(dataset)
-        dt = DecisionTree(from_file="pages/application/DecisionTree/trained_models/" + dataset + '/' + dataset + ".dt")
-        visualize(dt, "svg", "pages/application/component.svg")
-        self.network = html.Img(src="pages/application/component.svg")
-
-    def update_with_explicability(self, instance):
-        self.network = self.network
+    def __init__(self, dataset):
+        map = "pages/application/DecisionTree/map/" + dataset + "/" + dataset + ".map"
+        self.dt = DecisionTree(from_file="pages/application/DecisionTree/trained_models/" + dataset + '/' + dataset + ".dt", mapfile=map)
+        dot_source = visualize(self.dt)
+        self.network = dash_interactive_graphviz.DashInteractiveGraphviz(
+            dot_source=dot_source
+        )
+        self.explanation = dcc.Textarea(value="", style={"font_size": "15px",
+                                                         "width": "40rem",
+                                                         "height": "40rem",
+                                                         "margin-bottom": "5rem",
+                                                         "background-color": "#f8f9fa",
+                                                         })
+
+    def update_with_explicability(self, dataset, instance):
+        instance = open("pages/application/DecisionTree/instances/" + dataset + "/" + instance, "r")
+        instance = str(instance.read()).strip().split(',')
+        dot_source = visualize_instance(self.dt, instance)
+        self.network = dash_interactive_graphviz.DashInteractiveGraphviz(
+            dot_source=dot_source
+        )
+        self.explanation.value = self.dt.explain(instance)
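
A minimal sketch of how the rewritten component is driven (the 'iris' dataset and 'instance_0.txt' file names are hypothetical; in the app they come from the dropdowns):

    component = DecisionTreeComponent("iris")                       # loads iris.dt and iris.map
    component.update_with_explicability("iris", "instance_0.txt")   # reads the instance file, recolors the tree
    print(component.explanation.value)                              # text produced by DecisionTree.explain
    # component.network is a DashInteractiveGraphviz view of the tree with the instance path in blue.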
File deleted
File deleted
File deleted
File deleted
pages/application/DecisionTree/utils/dtree.py
@@ -351,8 +351,7 @@ class DecisionTree():
         # returning the set of sets with no duplicates
         return list(dict.fromkeys(sets))

-    def explain(self, inst, enum=1, pathlits=False, solver='g3', xtype='abd',
-            htype='sorted'):
+    def explain(self, inst, enum=5, pathlits=False, solver='g3', htype='sorted'):
         """
             Compute a given number of explanations.
         """
@@ -364,14 +363,16 @@ class DecisionTree():
         else:
             # input expected by Yacine - 'value1,value2,...'
             inst = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(inst)]))
+        print(inst)

         inst_orig = inst[:]
         path, term, depth = self.execute(inst, pathlits)
+        print(path)

+        explanation = str(inst) + "\n \n"
         #print('c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in inst_orig]), term))
         #print(term)
-        print('c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[ inst_orig[self.feids[self.nodes[n].feat]] ] for n in path]), term))
-        print('c path len:', depth)
+        explanation += 'c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[ inst_orig[self.feids[self.nodes[n].feat]] ] for n in path]), term) + "\n"
+        explanation += 'c path len:' + str(depth) + "\n \n \n"

         if self.ohmap.dir:
             f2v = {fv[0]: fv[1] for fv in inst}
@@ -383,30 +384,34 @@ class DecisionTree():
         # computing the sets to hit
         to_hit = self.prepare_sets(inst, term)

-        if xtype == 'abd':
-            self.enumerate_abductive(to_hit, enum, solver, htype, term)
-        else:
-            self.enumerate_contrastive(to_hit, term)
+        explanation += "Abductive explanation : " + "\n \n"
+        explanation += self.enumerate_abductive(to_hit, enum, solver, htype, term)
+        explanation += "Contrastive explanation : " + "\n \n"
+        explanation += self.enumerate_contrastive(to_hit, term)
+
+        return explanation

     def enumerate_abductive(self, to_hit, enum, solver, htype, term):
         """
             Enumerate abductive explanations.
         """
+        explanation = ""
         with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman:
             expls = []
             for i, expl in enumerate(hitman.enumerate(), 1):
-                print('c expl: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term))
+                explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term) + "\n"
                 expls.append(expl)
                 if i == enum:
                     break
-            print('c nof expls:', i)
-            print('c min expl:', min([len(e) for e in expls]))
-            print('c max expl:', max([len(e) for e in expls]))
-            print('c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)))
+            explanation += 'c nof expls:' + str(i) + "\n"
+            explanation += 'c min expl:' + str(min([len(e) for e in expls])) + "\n"
+            explanation += 'c max expl:' + str(max([len(e) for e in expls])) + "\n"
+            explanation += 'c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) + "\n \n \n"
+        return explanation

     def enumerate_contrastive(self, to_hit, term):
         """
@@ -424,14 +429,17 @@ class DecisionTree():
         to_hit = [set(s) for s in to_hit]
         to_hit.sort(key=lambda s: len(s))
         expls = list(reduce(process_set, to_hit, []))

+        explanation = ""
         for expl in expls:
-            print('c expl: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))
-
-        print('c nof expls:', len(expls))
-        print('c min expl:', min([len(e) for e in expls]))
-        print('c max expl:', max([len(e) for e in expls]))
-        print('c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)))
+            explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term) + "\n"
+        explanation += 'c nof expls:' + str(len(expls)) + "\n"
+        explanation += 'c min expl:' + str(min([len(e) for e in expls])) + "\n"
+        explanation += 'c max expl:' + str(max([len(e) for e in expls])) + "\n"
+        explanation += 'c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) + "\n"
+
+        return explanation

     def execute_path(self, path):
         """
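
Since explain now returns a string instead of printing, a rough usage sketch follows (the file paths and instance values are placeholders; the constructor arguments mirror DecisionTreeComponent above):

    dt = DecisionTree(from_file="pages/application/DecisionTree/trained_models/iris/iris.dt",
                      mapfile="pages/application/DecisionTree/map/iris/iris.map")
    instance = "1,0,2,1".split(',')   # same 'value1,value2,...' format the component reads from file
    text = dt.explain(instance)       # instance path, then abductive and contrastive explanations
    print(text)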
pages/application/DecisionTree/utils/dtviz.py
@@ -19,7 +19,7 @@ import sys
 #
 #==============================================================================
-def visualize(dt, fmt, output):
+def visualize(dt):
     """
         Visualize a DT with graphviz.
     """
@@ -57,6 +57,67 @@
             edge.attr['arrowsize'] = 0.8

     # saving file
+    g.in_edges
     g.layout(prog='dot')
-    g.draw(path=output, format=fmt)
+    return(g.to_string())
+
+#
+#==============================================================================
+def visualize_instance(dt, instance):
+    """
+        Visualize a DT with graphviz and plot the running instance.
+    """
+    if '=' in instance[0]:
+        instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance]))
+    else:
+        instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)]))
+
+    # path that follows the instance - colored in blue
+    path, term, depth = dt.execute(instance)
+    edges_instance = []
+    for i in range(len(path) - 1):
+        edges_instance.append((path[i], path[i+1]))
+    edges_instance.append((path[-1], "term:" + term))
+
+    g = pygraphviz.AGraph(directed=True, strict=True)
+    g.edge_attr['dir'] = 'forward'
+    g.graph_attr['rankdir'] = 'TB'
+
+    # non-terminal nodes
+    for n in dt.nodes:
+        g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
+        node = g.get_node(n)
+        node.attr['shape'] = 'circle'
+        node.attr['fontsize'] = 13
+
+    # terminal nodes
+    for n in dt.terms:
+        g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
+        node = g.get_node(n)
+        node.attr['shape'] = 'square'
+        node.attr['fontsize'] = 13
+
+    # transitions
+    for n1 in dt.nodes:
+        for v in dt.nodes[n1].vals:
+            n2 = dt.nodes[n1].vals[v]
+            n2_type = g.get_node(n2).attr['shape']
+            g.add_edge(n1, n2)
+            edge = g.get_edge(n1, n2)
+            if len(v) == 1:
+                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+            else:
+                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+
+            # instance path in blue
+            if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
+                edge.attr['color'] = 'blue'
+
+            edge.attr['fontsize'] = 10
+            edge.attr['arrowsize'] = 0.8
+
+    # saving file
+    g.layout(prog='dot')
+    return(g.to_string())
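
The DOT string returned by visualize/visualize_instance is what DecisionTreeComponent hands to the interactive viewer; roughly (dt and instance are assumed to be an already-loaded DecisionTree and a parsed value list):

    import dash_interactive_graphviz

    dot_source = visualize_instance(dt, instance)   # DOT text with the instance path colored blue
    viewer = dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source)
    # 'viewer' can be placed directly in the Dash layout, e.g. as the child of the 'graph' Div.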
File deleted
File deleted
This diff is collapsed.
pages/application/layout_application.py
-from dash import dcc, html, Input, Output, callback
+from dash import dcc, html, Input, Output
 import dash
 import dash_bootstrap_components as dbc
-import visdcc
 import pandas as pd
 import numpy as np
 from os import listdir
@@ -9,68 +8,23 @@ from os.path import isfile, join
 import importlib

 from pages.application.utils_data import extract_data
+from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent

-def create_tab_application(data):
-    model = Model(data)
-    view = View(model)
-
-    @callback(
-        Output('ml_datasets_choice', 'options'),
-        Input('ml_model_choice', 'value')
-    )
-    def update_ml_type(value):
-        model.update_ml_model(value)
-        view.update_ml_model()
-        return model.datasets
-
-    @callback(
-        Output('ml_instances_choice', 'options'),
-        Output('model_layout', 'children'),
-        Input('ml_datasets_choice', 'value'),
-        Input('ml_instances_choice', 'value')
-    )
-    def update_dataset(value_dataset, value_instance):
-        if model.ml_model != '':
-            ctx = dash.callback_context
-            if ctx.triggered:
-                dropdown = button_id = ctx.triggered[0]['prop_id'].split('.')[0]
-                if dropdown == 'ml_datasets_choice':
-                    model.update_dataset(value_dataset)
-                    model_layout = view.update_dataset()
-                    return model.instances, model_layout
-                else:
-                    model.update_instance(value_instance)
-                    model_layout = view.update_instance()
-                    return model.instances, model_layout
-        else:
-            return [], html.Div(children=[])
-
-    return view.create_tab()

 SIDEBAR_STYLE = {
-    "width": "20rem",
-    "margin-bottom": "5rem",
-    "background-color": "#f8f9fa",
 }

 CONTENT_STYLE = {
-    "padding": "2rem 1rem",
 }

 class Model():

     def __init__(self, data):
-        names_models, dict_data_models, dict_data_instances, dict_components, dict_expl = extract_data(data)
+        names_models, dict_data_models, dict_data_instances, dict_components = extract_data(data)
         self.dict_data_models = dict_data_models
         self.dict_data_instances = dict_data_instances
         self.dict_components = dict_components
-        self.dict_expl = dict_expl

         self.ml_models = names_models
         self.ml_model = ''
@@ -83,8 +37,6 @@ class Model():
         self.component_class = ''

-        self.explicability_algorithm = ''
-
     def update_ml_model(self, ml_model_update):
         self.ml_model = ml_model_update
@@ -92,8 +44,6 @@ class Model():
         self.component_class = self.dict_components[self.ml_model]

-        self.explicability_algorithm = self.dict_expl[self.ml_model]
-
     def update_dataset(self, dataset_update):
         self.dataset = dataset_update
@@ -111,41 +61,28 @@ class View():
         self.datasets_menu_models = dcc.Dropdown(self.model.datasets, id='ml_datasets_choice')
         self.instances_menu = dcc.Dropdown(self.model.instances, id='ml_instances_choice')

-        self.sidebar = html.Div([
-            html.Label("Choose the Machine Learning algorithm :"),
-            html.Br(),
-            self.ml_menu_models,
-            html.Hr(),
-            html.Label("Choose the dataset : "),
-            html.Br(),
-            self.datasets_menu_models,
-            html.Hr(),
-            html.Label("Choose the instance to explain : "),
-            html.Br(),
-            self.instances_menu],
-            style=SIDEBAR_STYLE)
-
         self.component = ''

-        self.model_layout = html.Div(
-            id="model_layout",
-            style=CONTENT_STYLE)
-
-    def create_tab(self):
-        return dcc.Tab(label='Application on Explainable AI', children=[
-            html.Div(children=[
-                self.sidebar,
-                self.model_layout
-            ])
-        ])
-
-    def update_ml_model(self):
-        cls = getattr(importlib.import_module("pages.application.DecisionTree." + self.model.component_class), self.model.component_class)
-        self.component = cls()
+        self.sidebar = dbc.Col([
+            html.Label("Choose the Machine Learning algorithm :"),
+            html.Br(),
+            self.ml_menu_models,
+            html.Hr(),
+            html.Label("Choose the dataset : "),
+            html.Br(),
+            self.datasets_menu_models,
+            html.Hr(),
+            html.Label("Choose the instance to explain : "),
+            html.Br(),
+            self.instances_menu], width=2)
+
+        self.layout = dbc.Row([self.sidebar,
+                               dbc.Col(html.Div(id="graph", children=" "), width=5),
+                               dbc.Col(html.Div(id="explanation", children=" "), width=3)])
+
+        self.tab = dcc.Tab(label='Application on Explainable AI', children=self.layout)

     def update_dataset(self):
-        self.component.update(self.model.dataset)
-        return html.Div(children=[self.component.network])
+        class_ = globals()[self.model.component_class]
+        self.component = class_(self.model.dataset)

     def update_instance(self):
-        self.component.update_with_explicability(self.model.instance)
-        return html.Div(children=[self.component.network])
+        self.component.update_with_explicability(self.model.dataset, self.model.instance)
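
To keep the wiring readable, here is a rough trace (not part of the commit) of one user interaction with the Model/View pair above, matching the app.py callback; the dataset and instance names are placeholders:

    model = Model(data)                      # dictionaries built by extract_data
    view = View(model)                       # sidebar plus the 'graph' and 'explanation' columns

    model.update_ml_model('DecisionTree')    # 'ml_model_choice' dropdown
    model.update_dataset('iris')             # 'ml_datasets_choice' dropdown (placeholder name)
    view.update_dataset()                    # instantiates DecisionTreeComponent('iris')
    model.update_instance('instance_0.txt')  # 'ml_instances_choice' dropdown
    view.update_instance()                   # highlights the instance path and fills the explanation Textarea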
pages/application/utils_data.py
@@ -9,21 +9,19 @@ from os.path import isfile, join
 import importlib

 def extract_data(data):
-    names_models = [data['data'][i]['ml_type'] for i in range (len(data))]
+    names_models = [data[i]['ml_type'] for i in range (len(data))]
     dict_data_models = {}
     dict_data_instances = {}
     dict_components = {}
-    dict_expl = {}
     for i in range (len(data)) :
-        ml_type = data['data'][i]['ml_type']
-        dict_components[ml_type] = data['data'][i]['component']
-        dict_expl[ml_type] = data['data'][i]['explicability_algorithm']
-        dict_data_models[ml_type] = [f for f in listdir(data['data'][i]['trained_models'])]
-        dict_data_instances = {}
+        ml_type = data[i]['ml_type']
+        dict_components[ml_type] = data[i]['component']
+        dict_data_models[ml_type] = [f for f in listdir(data[i]['trained_models'])]
+        dict_dataset_instances = {}
         for j in range (len(dict_data_models[ml_type])) :
             dataset = dict_data_models[ml_type][j]
-            dict_data_instances[dataset] = [f for f in listdir(data['data'][i]['instances']+dataset) if isfile(join(data['data'][i]['instances']+dataset, f))]
-        dict_data_instances[ml_type] = dict_data_instances
-    return names_models, dict_data_models, dict_data_instances, dict_components, dict_expl
+            dict_dataset_instances[dataset] = [f for f in listdir(data[i]['instances']+dataset) if isfile(join(data[i]['instances']+dataset, f))]
+        dict_data_instances[ml_type] = dict_dataset_instances
+    return names_models, dict_data_models, dict_data_instances, dict_components
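
For reference, a sketch (not in the commit) of the shapes extract_data now returns for the configuration above; folder and file names are illustrative:

    names_models, dict_data_models, dict_data_instances, dict_components = extract_data(data)
    # names_models        -> ['DecisionTree', 'NaiveBayes']
    # dict_data_models    -> {'DecisionTree': ['iris', ...], 'NaiveBayes': [...]}   (trained_models sub-folders)
    # dict_data_instances -> {'DecisionTree': {'iris': ['instance_0.txt', ...], ...}, 'NaiveBayes': {...}}
    # dict_components     -> {'DecisionTree': 'DecisionTreeComponent', 'NaiveBayes': 'NaiveBayesComponent'}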
pages/layout.py (deleted)
-from dash import dcc
-from dash import html
-import visdcc
-import pandas as pd
-import numpy as np
-
-from pages.application.layout_application import create_tab_application
-
-def create_tabs(data):
-    tabs = dcc.Tabs([
-        dcc.Tab(label='Course on Explainable AI', children=[]),
-        create_tab_application(data),
-    ])
-    return tabs
-
-def create_layout(data):
-    return html.Div([
-        html.H1('FXToolKit'),
-        create_tabs(data)])