Skip to content
Snippets Groups Projects
Commit da034c90 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

pickle done, cleaning, customizing to do

parent 8b31c676
Branches
No related tags found
No related merge requests found
......@@ -7,4 +7,5 @@ decision_tree_classifier_20170212.pkl
push_command
adult.pkl
adult_data_00000.inst
iris_00000.txt
\ No newline at end of file
iris_00000.txt
tests
\ No newline at end of file
......@@ -4,103 +4,52 @@ import json
import dash
import dash_bootstrap_components as dbc
import pandas as pd
from dash import Input, Output, State, dcc, html
from dash.exceptions import PreventUpdate
from dash import dcc, html
from pages.application.layout_application import Model, View
from utils import extract_data, parse_contents_instance, parse_contents_tree
from callbacks import register_callbacks
from pages.application.application import Application, Model, View
from utils import extract_data
'''
Loading data
'''
app = dash.Dash(external_stylesheets=[dbc.themes.LUX], suppress_callback_exceptions=True)
#################################################################################
############################# Layouts ###########################################
#################################################################################
models_data = open('data_retriever.json')
data = json.load(models_data)["data"]
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
'''
Construction of the layout
'''
#For home directory
page_home = dbc.Row([])
#For course directory
page_course = dbc.Row([])
#For the application
names_models, dict_components = extract_data(data)
model = Model(names_models, dict_components)
view = View(model)
tabs = dcc.Tabs([
dcc.Tab(label='Course on Explainable AI', children=[]),
view.tab,
])
model_application = Model(names_models, dict_components)
view_application = View(model_application)
page_application = Application(view_application)
app.layout = html.Div([
html.H1('FXToolKit'),
tabs])
'''
Callback for the app
'''
@app.callback(
Output('dataset_filename', 'children'),
Output('instance_filename', 'children'),
Output('graph', 'children'),
Output('explanation', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_dataset_choice', 'contents'),
Input('ml_instance_choice', 'contents'),
Input('number_explanations', 'value'),
Input('explanation_type', 'value'),
Input('solver_sat', 'value'),
State('ml_dataset_choice', 'filename'),
State('ml_instance_choice', 'filename'),
prevent_initial_call=True
)
def update_ml_type(value_ml_model, dataset_contents, instance_contents, enum, xtype, solver, dataset_filename, instance_filename):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
if ihm_id == 'ml_model_choice' :
model.update_ml_model(value_ml_model)
return "", "", "", ""
elif ihm_id == 'ml_dataset_choice':
if value_ml_model == None :
raise PreventUpdate
tree, typ = parse_contents_tree(dataset_contents, dataset_filename)
model.update_dataset(tree, typ)
return dataset_filename, "", model.component.network, ""
elif ihm_id == 'ml_instance_choice' :
if value_ml_model == None or dataset_contents == None or enum == None or xtype==None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model.update_instance(instance, enum, xtype)
return dataset_filename, instance_filename, model.component.network, model.component.explanation
elif ihm_id == 'number_explanations' :
if value_ml_model == None or dataset_contents == None or instance_contents == None or xtype==None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model.update_instance(instance, enum, xtype)
return dataset_filename, instance_filename, model.component.network, model.component.explanation
elif ihm_id == 'explanation_type' :
if value_ml_model == None or dataset_contents == None or instance_contents == None or enum == None :
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model.update_instance(instance, enum, xtype)
return dataset_filename, instance_filename, model.component.network, model.component.explanation
elif ihm_id == 'solver_sat' :
if value_ml_model == None or dataset_contents == None or instance_contents == None or enum == None or xtype == None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model.update_instance(instance, enum, xtype, solver=solver)
return dataset_filename, instance_filename, model.component.network, model.component.explanation
dcc.Location(id='url', refresh=False),
html.Nav(id='navbar-container',
children=[dbc.NavbarSimple(
children=[
dbc.NavItem(dbc.NavLink("Home", id="home-link", href="/")),
dbc.NavItem(dbc.NavLink("Course", id="course-link", href="/course")),
dbc.NavItem(dbc.NavLink("Application on explainable AI", id="application-link", href="/application")),
],
brand="FX ToolKit",
color="primary",
dark=True,)]),
html.Div(id='page-content')
])
#################################################################################
################################# Callback for the app ##########################
#################################################################################
register_callbacks(page_home, page_course, page_application, app)
'''
Launching app
'''
#################################################################################
################################# Launching app #################################
#################################################################################
if __name__ == '__main__':
app.run_server(debug=True)
/* NAVBAR */
.navbar-dark .navbar-brand {
color: #fff;
font-size: 30px;
}
/* SIDEBAR */
.sidebar {
padding: 2rem;
padding-top:0.5rem;
color: rgb(255, 255, 255);
font-weight: 300;
background-color: black;
}
.sidebar .tab.jsx-3468109796 {
color:rgb(255, 255, 255);
font-weight: 500;
background-color: #1a1c1d;
}
.sidebar .tab--selected.jsx-3468109796:hover {
background-color:gray;
}
.sidebar .upload {
width: 100%;
height: 50px;
line-height: 50px;
border-width: 1px;
border-style: dashed;
border-radius: 5px;
text-align: center;
margin: 10px
}
.sidebar .Select-control {
width: 100%;
height: 30px;
line-height: 30px;
border-width: 1px;
border-radius: 5px;
text-align: center;
margin: 10px;
color:rgb(255, 255, 255);
font-weight: 400;
background-color: black;
}
.sidebar .sidebar-dropdown{
width: 100%;
height: 30px;
line-height: 30px;
border-width: 1px;
border-radius: 5px;
text-align: center;
margin: 10px;
color:rgb(255, 255, 255);
font-weight: 400;
background-color: black;
}
.sidebar .has-value.Select--single > .Select-control .Select-value .Select-value-label, .has-value.is-pseudo-focused.Select--single > .sidebar .Select-control .Select-value .Select-value-label {
color:rgb(255, 255, 255);
}
.sidebar .Select-menu-outer{
width: 100%;
border-width: 1px;
border-radius: 5px;
text-align: center;
margin: 10px;
color:rgb(255, 255, 255);
font-weight: 400;
background-color: black;
}
/* EXPLANATION */
main#explanation {
width: 95%;
margin-bottom: 5rem;
margin-top: 5rem;
border-width: 4px;
border-style:double;
border-radius: 5px;
padding: 2rem;
border-radius: 10px;
}
/* GRAPH */
.column_graph {
margin-top: 5rem;
}
body {
font-family: sans-serif;
}
H4 {
font-size: 20px;
text-decoration-line:underline;
text-decoration-thickness:2px;
text-decoration-style:solid;
color: hsl(229, 58%, 19%)
}
H5 {
font-size: 16px;
color: hsl(228, 58%, 12%);
}
p {
font-size: 15px;
color: hsl(0, 0%, 0%)
}
\ No newline at end of file
import dash
import pandas as pd
from dash import Input, Output, State
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from utils import parse_contents_graph, parse_contents_instance
def register_callbacks(page_home, page_course, page_application, app):
page_list = ['home', 'course', 'application']
@app.callback(
Output('page-content', 'children'),
Input('url', 'pathname'))
def display_page(pathname):
if pathname == '/':
return page_home
if pathname == '/application':
return page_application.view.layout
if pathname == '/course':
return page_course
@app.callback(Output('home-link', 'active'),
Output('course-link', 'active'),
Output('application-link', 'active'),
Input('url', 'pathname'))
def navbar_state(pathname):
active_link = ([pathname == f'/{i}' for i in page_list])
return active_link[0], active_link[1], active_link[2]
@app.callback(
Output('pretrained_model_filename', 'children'),
Output('instance_filename', 'children'),
Output('graph', 'children'),
Output('explanation', 'children'),
Input('ml_model_choice', 'value'),
Input('ml_pretrained_model_choice', 'contents'),
State('ml_pretrained_model_choice', 'filename'),
Input('ml_instance_choice', 'contents'),
State('ml_instance_choice', 'filename'),
Input('number_explanations', 'value'),
Input('explanation_type', 'value'),
Input('solver_sat', 'value'),
Input('expl_choice', 'value'),
prevent_initial_call=True
)
def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice):
ctx = dash.callback_context
if ctx.triggered:
ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
model_application = page_application.model
if ihm_id == 'ml_model_choice' :
model_application.update_ml_model(value_ml_model)
return None, None, None, None
elif ihm_id == 'ml_pretrained_model_choice':
if value_ml_model is None :
raise PreventUpdate
tree, typ = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
model_application.update_pretrained_model(tree, typ)
return pretrained_model_filename, None, model_application.component.network, None
elif ihm_id == 'ml_instance_choice' :
if value_ml_model is None or pretrained_model_contents is None or enum is None or xtype is None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model_application.update_instance(instance, enum, xtype)
return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
elif ihm_id == 'number_explanations' :
if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or xtype is None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model_application.update_instance(instance, enum, xtype)
return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
elif ihm_id == 'explanation_type' :
if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None :
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model_application.update_instance(instance, enum, xtype)
return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
elif ihm_id == 'solver_sat' :
if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None or xtype is None:
raise PreventUpdate
instance = parse_contents_instance(instance_contents, instance_filename)
model_application.update_instance(instance, enum, xtype, solver=solver)
return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
elif ihm_id == 'expl_choice' :
if instance_contents is None :
raise PreventUpdate
model_application.update_expl(expl_choice)
return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation
@app.callback(
Output('explanation', 'hidden'),
Output('navigate_label', 'hidden'),
Output('navigate_dropdown', 'hidden'),
Output('expl_choice', 'options'),
Input('explanation', 'children'),
Input('explanation_type', 'value'),
prevent_initial_call=True
)
def layout_buttons_navigate_expls(explanation, explanation_type):
if explanation is None or "AXp" not in explanation_type:
return True, True, True, {}
else :
options = {}
model_application = page_application.model
for i in range (len(model_application.list_expls)):
options[str(model_application.list_expls[i])] = model_application.list_expls[i]
return False, False, False, options
from dash import dcc
from pages.application.DecisionTree.utils.dtviz import visualize, visualize_instance
from pages.application.DecisionTree.utils.dtree import DecisionTree
from os import path
import dash_bootstrap_components as dbc
import dash_interactive_graphviz
import numpy as np
from dash import dcc, html
from pages.application.DecisionTree.utils.dtree import DecisionTree
from pages.application.DecisionTree.utils.dtviz import (visualize,
visualize_expl,
visualize_instance)
import os.path
from os import path
import numpy as np
class DecisionTreeComponent():
def __init__(self, tree, typ_data):
if typ_data == "dt" :
self.dt = DecisionTree(from_dt = tree)
elif typ_data == "pkl" :
self.dt = DecisionTree(from_pickle = tree)
self.dt = DecisionTree(from_pickle = tree)
dot_source = visualize(self.dt)
self.network = dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source
)
self.explanation = dcc.Textarea(value = "", style = { "font_size" : "15px",
"width": "40rem",
"height": "40rem",
"margin-bottom": "5rem",
"background-color": "#f8f9fa",
})
self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%",
"height": "90%",
"background-color": "transparent"}))]
self.explanation = []
def update_with_explicability(self, instance, enum, xtype, solver) :
instance = str(instance).strip().split(',')
instance = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in instance]))
dot_source = visualize_instance(self.dt, instance)
self.network = dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source
)
self.explanation.value = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver)
self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source, style = {"width": "50%",
"height": "80%",
"background-color": "transparent"}
))]
self.explanation = []
list_explanations_path=[]
explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver)
#Creating a clean and nice text component
for k in explanation.keys() :
if k != "List of path explanation(s)":
if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] :
self.explanation.append(html.H4(k))
for expl in explanation[k] :
self.explanation.append(html.Hr())
self.explanation.append(html.P(expl))
self.explanation.append(html.Hr())
else :
self.explanation.append(html.P(k + explanation[k]))
else :
list_explanations_path = explanation["List of path explanation(s)"]
return list_explanations_path
def draw_explanation(self, instance, expl) :
dot_source = visualize_expl(self.dt, instance, expl)
self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(
dot_source=dot_source,
style = {"width": "50%",
"height": "80%",
"background-color": "transparent"}))]
......@@ -11,21 +11,26 @@
#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
import sklearn
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver
import sklearn
from torch import threshold
try: # for Python2
from cStringIO import StringIO
except ImportError: # for Python3
from io import StringIO
from sklearn.tree import _tree
import numpy as np
from dash import dcc, html
from sklearn.tree import _tree
#
#==============================================================================
......@@ -81,7 +86,7 @@ class DecisionTree():
def from_pickle_file(self, tree):
#help(_tree.Tree)
self.tree_ = tree.tree_
print(sklearn.tree.export_text(tree))
#print(sklearn.tree.export_text(tree))
try:
feature_names = tree.feature_names_in_
except:
......@@ -132,7 +137,7 @@ class DecisionTree():
Traverse the tree and extract explicit paths.
"""
if root in self.terms:
if root in self.terms.keys():
# store the path
term = self.terms[root]
self.paths[term].append(prefix)
......@@ -159,7 +164,7 @@ class DecisionTree():
sets = []
for t, paths in self.paths.items():
# ignoring the right class
if t == term:
if term in self.terms.keys() and self.terms[term] == t:
continue
# computing the sets to hit
......@@ -190,11 +195,15 @@ class DecisionTree():
inst_dic = {}
for i in range(len(inst)):
inst_dic[inst[i][0]] = np.float32(inst[i][1])
inst_orig = inst[:]
path, term = self.execute(inst_values)
explanation = str(inst_dic) + "\n \n"
decision_path_str = "c inst : IF : "
#contaiins all the elements for explanation
explanation_dic = {}
#instance plotting
explanation_dic["Instance : "] = str(inst_dic)
#decision path
decision_path_str = "IF : "
for node_id in path:
# continue to the next node if it is a leaf node
if term == node_id:
......@@ -207,43 +216,45 @@ class DecisionTree():
threshold=self.nodes[node_id].threshold)
decision_path_str += "THEN " + str(self.terms[term])
explanation += decision_path_str + "\n \n"
explanation +='c path len:'+ str(len(path))+ "\n \n \n"
explanation_dic["Decision path of instance : "] = decision_path_str
explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path))
# computing the sets to hit
to_hit = self.prepare_sets(inst_dic, term)
for type in xtype :
if type == "AXp":
explanation += "Abductive explanation : " + "\n \n"
explanation += self.enumerate_abductive(to_hit, enum, solver, htype, term)+ "\n \n"
explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term))
else :
explanation += "Contrastive explanation : "+ "\n \n"
explanation += self.enumerate_contrastive(to_hit, term)+ "\n \n"
explanation_dic.update(self.enumerate_contrastive(to_hit, term))
return explanation
return explanation_dic
def enumerate_abductive(self, to_hit, enum, solver, htype, term):
"""
Enumerate abductive explanations.
"""
explanation = ""
list_expls = []
list_expls_str = []
explanation = {}
with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
expls = []
for i, expl in enumerate(hitman.enumerate(), 1):
explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0],
list_expls.append([ p[0] + p[2] + p[3] for p in expl])
list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0],
value=p[1],
inequality=p[2],
threshold=p[3])
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n"
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term])))
expls.append(expl)
if i == enum:
break
explanation += 'c nof expls:' + str(i)+ "\n"
explanation += 'c min expl:'+ str( min([len(e) for e in expls]))+ "\n"
explanation += 'c max expl:'+ str( max([len(e) for e in expls]))+ "\n"
explanation += 'c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))+ "\n \n \n"
explanation["List of path explanation(s)"] = list_expls
explanation["List of abductive explanation(s)"] = list_expls_str
explanation["Number of abductive explanation(s) : "] = str(i)
explanation["Minimal abductive explanation : "] = str( min([len(e) for e in expls]))
explanation["Maximal abductive explanation : "] = str( max([len(e) for e in expls]))
explanation["Average abductive explanation : "] = '{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))
return explanation
......@@ -263,15 +274,17 @@ class DecisionTree():
to_hit = [set(s) for s in to_hit]
to_hit.sort(key=lambda s: len(s))
expls = list(reduce(process_set, to_hit, []))
explanation = ""
list_expls_str = []
explanation = {}
for expl in expls:
explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0],
list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0],
inequality="<=" if p[2]==">" else ">",
threshold=p[3])
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n"
explanation +='c nof expls:'+ str(len(expls))+ "\n"
explanation +='c min expl:'+ str( min([len(e) for e in expls]))+ "\n"
explanation +='c max expl:'+ str( max([len(e) for e in expls]))+ "\n"
explanation +='c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))+ "\n"
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term])))
explanation["List of contrastive explanation(s)"] = list_expls_str
explanation["Number of contrastive explanation(s) : "]=str(len(expls))
explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls]))
explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls]))
explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))
return explanation
\ No newline at end of file
return explanation
......@@ -8,22 +8,40 @@
## E-mail: alexey.ignatiev@monash.edu
##
#
#==============================================================================
from pages.application.DecisionTree.utils.dtree import DecisionTree
import pygraphviz
import numpy as np
import pandas as pd
import pygraphviz
#
#==============================================================================
def create_legend(g):
legend = g.subgraphs()[-1]
legend.add_node("a", style = "invis")
legend.add_node("b", style = "invis")
legend.add_node("c", style = "invis")
legend.add_node("d", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
edge.attr["label"] = "instance"
edge.attr["style"] = "dashed"
legend.add_edge("c","d")
edge = legend.get_edge("c","d")
edge.attr["label"] = "instance with explanation"
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
def visualize(dt):
"""
Visualize a DT with graphviz.
"""
g = pygraphviz.AGraph(directed=True, strict=True)
g = pygraphviz.AGraph(name='root', rankdir="TB")
g.is_directed()
g.is_strict()
#g = pygraphviz.AGraph(name = "main", directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes
for n in dt.nodes:
......@@ -56,6 +74,9 @@ def visualize(dt):
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.string())
......@@ -102,7 +123,7 @@ def visualize_instance(dt, instance):
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_left) in edges_instance):
edge.attr['color'] = 'blue'
edge.attr['style'] = 'dashed'
children_right = dt.nodes[n1].children_right
g.add_edge(n1, children_right)
......@@ -112,8 +133,74 @@ def visualize_instance(dt, instance):
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_right) in edges_instance):
edge.attr['color'] = 'blue'
edge.attr['style'] = 'dashed'
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.to_string())
#
#==============================================================================
def visualize_expl(dt, instance, expl):
"""
Visualize a DT with graphviz and plot the running instance.
"""
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=str(dt.nodes[n].feat))
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes
for n in dt.terms:
g.add_node(n, label=str(dt.terms[n]))
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
#path that follows the instance - colored in blue
instance = [np.float32(i[1]) for i in instance]
path, term_id_node = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
for n1 in dt.nodes:
threshold = dt.nodes[n1].threshold
children_left = dt.nodes[n1].children_left
g.add_edge(n1, children_left)
edge = g.get_edge(n1, children_left)
edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_left) in edges_instance):
edge.attr['style'] = 'dashed'
if edge.attr['label'] in expl :
edge.attr['color'] = 'blue'
children_right = dt.nodes[n1].children_right
g.add_edge(n1, children_right)
edge = g.get_edge(n1, children_right)
edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_right) in edges_instance):
edge.attr['style'] = 'dashed'
if edge.attr['label'] in expl :
edge.attr['color'] = 'blue'
g.add_subgraph(name='legend')
create_legend(g)
g.layout(prog='dot')
return(g.to_string())
from dash import dcc, html
import dash
import dash_bootstrap_components as dbc
from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent
SIDEBAR_STYLE = {
}
class Application():
def __init__(self, view):
self.view = view
self.model = view.model
CONTENT_STYLE = {
}
class Model():
......@@ -19,11 +18,14 @@ class Model():
self.ml_models = names_models
self.ml_model = ''
self.dataset = ''
self.pretrained_model = ''
self.typ_data = ''
self.instance = ''
self.list_expls = []
self.expl_path = []
self.component_class = ''
self.component = ''
......@@ -32,41 +34,38 @@ class Model():
self.component_class = self.dict_components[self.ml_model]
self.component_class = globals()[self.component_class]
def update_dataset(self, dataset_update, typ_data):
self.dataset = dataset_update
def update_pretrained_model(self, pretrained_model_update, typ_data):
self.pretrained_model = pretrained_model_update
self.typ_data = typ_data
self.component = self.component_class(self.dataset, self.typ_data)
self.component = self.component_class(self.pretrained_model, self.typ_data)
def update_instance(self, instance, enum, xtype, solver="g3"):
self.instance = instance
self.component.update_with_explicability(self.instance, enum, xtype, solver)
self.list_expls = self.component.update_with_explicability(self.instance, enum, xtype, solver)
def update_expl(self, expl):
self.expl = expl
self.component.draw_explanation(self.instance, expl)
class View():
def __init__(self, model):
self.model = model
self.ml_menu_models = dcc.Dropdown(self.model.ml_models, id='ml_model_choice')
self.ml_menu_models = dcc.Dropdown(self.model.ml_models,
id='ml_model_choice',
className="sidebar-dropdown")
self.dataset_upload = html.Div([
dcc.Upload(
id='ml_dataset_choice',
self.pretrained_model_upload = html.Div([
dcc.Upload(
id='ml_pretrained_model_choice',
children=html.Div([
'Drag and Drop or ',
html.A('Select File')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin': '10px'
}
className="upload"
),
html.Div(id='dataset_filename')])
html.Div(id='pretrained_model_filename')])
self.instance_upload = html.Div([
dcc.Upload(
......@@ -75,29 +74,20 @@ class View():
'Drag and Drop or ',
html.A('Select instance')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin': '10px'
}
className="upload"
),
html.Div(id='instance_filename')])
self.sidebar = dbc.Col([
dcc.Tabs(children=[
dcc.Tab(label='Basic Parameters', children = [
self.sidebar = dcc.Tabs(children=[
dcc.Tab(label='Basic Parameters', children = [
html.Br(),
html.Label("Choose the Machine Learning algorithm :"),
html.Br(),
self.ml_menu_models,
html.Hr(),
html.Label("Choose the dataset : "),
html.Label("Choose the pretrained model : "),
html.Br(),
self.dataset_upload,
self.pretrained_model_upload,
html.Hr(),
html.Label("Choose the instance to explain : "),
html.Br(),
......@@ -107,8 +97,10 @@ class View():
html.Br(),
dcc.Input(
id="number_explanations",
value=1,
type="number",
placeholder="How many explanations ?"),
placeholder="How many explanations ?",
className="sidebar-dropdown"),
html.Hr(),
html.Label("Choose the kind of explanation : "),
html.Br(),
......@@ -116,17 +108,24 @@ class View():
id="explanation_type",
options={'AXp' : "Abductive Explanation", 'CXp': "Contrastive explanation"},
value = ['AXp', 'CXp'],
inline=True)]),
className="sidebar-dropdown",
inline=True)], className="sidebar"),
dcc.Tab(label='Advanced Parameters', children = [
html.Hr(),
html.Label("Choose the SAT solver : "),
html.Br(),
dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat')
])
])],width=3)
self.layout = dbc.Row([self.sidebar,
dbc.Col(html.Div(id = "graph", children=" "), width=4),
dbc.Col(html.Div(id = "explanation", children=" "), width=3)])
self.tab = dcc.Tab(label='Application on Explainable AI', children=self.layout)
], className="sidebar")
])
self.expl_choice = dcc.Dropdown(self.model.list_expls,
id='expl_choice',
className="dropdown")
self.layout = dbc.Row([ dbc.Col([self.sidebar], width=3, class_name="sidebar"),
dbc.Col([dbc.Row(id = "graph", children=[]),
dbc.Row(html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "),
html.Div(self.expl_choice, id='navigate_dropdown', hidden=True)]))], width=5, class_name="column_graph"),
dbc.Col(html.Main(id = "explanation", children=[], hidden=True), width=4)])
\ No newline at end of file
import base64
import io
import pickle
import pandas as pd
import sklearn
import numpy as np
from dash import html
def parse_contents_tree(contents, filename):
def parse_contents_graph(contents, filename):
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
try:
......@@ -31,7 +31,8 @@ def parse_contents_instance(contents, filename):
data = decoded.decode('utf-8')
else :
data = decoded.decode('utf-8')
data = str(data).strip().split(',')
data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data]))
except Exception as e:
print(e)
return html.Div([
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment