diff --git a/.gitignore b/.gitignore index ac7f262d44c6cf8e2e6c604978156a382186ff8c..ed1537f039bd0b1e28e3d3fd73f3cae685e3e067 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ decision_tree_classifier_20170212.pkl push_command adult.pkl adult_data_00000.inst -iris_00000.txt \ No newline at end of file +iris_00000.txt +tests \ No newline at end of file diff --git a/app.py b/app.py index e605e68bf0922f4f6ca56198ba817b533d2ddd50..9fc4fab2db76b7424e8fe6a01c0b5c4cb1c6458e 100644 --- a/app.py +++ b/app.py @@ -4,103 +4,52 @@ import json import dash import dash_bootstrap_components as dbc -import pandas as pd -from dash import Input, Output, State, dcc, html -from dash.exceptions import PreventUpdate +from dash import dcc, html -from pages.application.layout_application import Model, View -from utils import extract_data, parse_contents_instance, parse_contents_tree +from callbacks import register_callbacks +from pages.application.application import Application, Model, View +from utils import extract_data -''' -Loading data -''' +app = dash.Dash(external_stylesheets=[dbc.themes.LUX], suppress_callback_exceptions=True) + +################################################################################# +############################# Layouts ########################################### +################################################################################# models_data = open('data_retriever.json') data = json.load(models_data)["data"] -app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) - - -''' -Construction of the layout -''' +#For home directory +page_home = dbc.Row([]) +#For course directory +page_course = dbc.Row([]) +#For the application names_models, dict_components = extract_data(data) -model = Model(names_models, dict_components) -view = View(model) -tabs = dcc.Tabs([ - dcc.Tab(label='Course on Explainable AI', children=[]), - view.tab, -]) +model_application = Model(names_models, dict_components) +view_application = View(model_application) +page_application = Application(view_application) app.layout = html.Div([ - html.H1('FXToolKit'), - tabs]) - - -''' -Callback for the app -''' -@app.callback( - Output('dataset_filename', 'children'), - Output('instance_filename', 'children'), - Output('graph', 'children'), - Output('explanation', 'children'), - Input('ml_model_choice', 'value'), - Input('ml_dataset_choice', 'contents'), - Input('ml_instance_choice', 'contents'), - Input('number_explanations', 'value'), - Input('explanation_type', 'value'), - Input('solver_sat', 'value'), - State('ml_dataset_choice', 'filename'), - State('ml_instance_choice', 'filename'), - prevent_initial_call=True -) -def update_ml_type(value_ml_model, dataset_contents, instance_contents, enum, xtype, solver, dataset_filename, instance_filename): - ctx = dash.callback_context - if ctx.triggered: - ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] - if ihm_id == 'ml_model_choice' : - model.update_ml_model(value_ml_model) - return "", "", "", "" - - elif ihm_id == 'ml_dataset_choice': - if value_ml_model == None : - raise PreventUpdate - tree, typ = parse_contents_tree(dataset_contents, dataset_filename) - model.update_dataset(tree, typ) - return dataset_filename, "", model.component.network, "" - - elif ihm_id == 'ml_instance_choice' : - if value_ml_model == None or dataset_contents == None or enum == None or xtype==None: - raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model.update_instance(instance, enum, xtype) - return dataset_filename, instance_filename, model.component.network, model.component.explanation - - elif ihm_id == 'number_explanations' : - if value_ml_model == None or dataset_contents == None or instance_contents == None or xtype==None: - raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model.update_instance(instance, enum, xtype) - return dataset_filename, instance_filename, model.component.network, model.component.explanation - - elif ihm_id == 'explanation_type' : - if value_ml_model == None or dataset_contents == None or instance_contents == None or enum == None : - raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model.update_instance(instance, enum, xtype) - return dataset_filename, instance_filename, model.component.network, model.component.explanation - - elif ihm_id == 'solver_sat' : - if value_ml_model == None or dataset_contents == None or instance_contents == None or enum == None or xtype == None: - raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) - model.update_instance(instance, enum, xtype, solver=solver) - return dataset_filename, instance_filename, model.component.network, model.component.explanation - + dcc.Location(id='url', refresh=False), + html.Nav(id='navbar-container', + children=[dbc.NavbarSimple( + children=[ + dbc.NavItem(dbc.NavLink("Home", id="home-link", href="/")), + dbc.NavItem(dbc.NavLink("Course", id="course-link", href="/course")), + dbc.NavItem(dbc.NavLink("Application on explainable AI", id="application-link", href="/application")), + ], + brand="FX ToolKit", + color="primary", + dark=True,)]), + html.Div(id='page-content') +]) +################################################################################# +################################# Callback for the app ########################## +################################################################################# +register_callbacks(page_home, page_course, page_application, app) -''' -Launching app -''' +################################################################################# +################################# Launching app ################################# +################################################################################# if __name__ == '__main__': app.run_server(debug=True) diff --git a/assets/header.css b/assets/header.css new file mode 100644 index 0000000000000000000000000000000000000000..63ccc73777c8d9462f49d5e3a64395233f7de102 --- /dev/null +++ b/assets/header.css @@ -0,0 +1,103 @@ +/* NAVBAR */ + +.navbar-dark .navbar-brand { + color: #fff; + font-size: 30px; + } + + + + +/* SIDEBAR */ + +.sidebar { + padding: 2rem; + padding-top:0.5rem; + color: rgb(255, 255, 255); + font-weight: 300; + background-color: black; +} + +.sidebar .tab.jsx-3468109796 { + color:rgb(255, 255, 255); + font-weight: 500; + background-color: #1a1c1d; +} + +.sidebar .tab--selected.jsx-3468109796:hover { + background-color:gray; + } + +.sidebar .upload { + width: 100%; + height: 50px; + line-height: 50px; + border-width: 1px; + border-style: dashed; + border-radius: 5px; + text-align: center; + margin: 10px +} + +.sidebar .Select-control { + width: 100%; + height: 30px; + line-height: 30px; + border-width: 1px; + border-radius: 5px; + text-align: center; + margin: 10px; + color:rgb(255, 255, 255); + font-weight: 400; + background-color: black; +} + +.sidebar .sidebar-dropdown{ + width: 100%; + height: 30px; + line-height: 30px; + border-width: 1px; + border-radius: 5px; + text-align: center; + margin: 10px; + color:rgb(255, 255, 255); + font-weight: 400; + background-color: black; +} + +.sidebar .has-value.Select--single > .Select-control .Select-value .Select-value-label, .has-value.is-pseudo-focused.Select--single > .sidebar .Select-control .Select-value .Select-value-label { + color:rgb(255, 255, 255); +} + +.sidebar .Select-menu-outer{ + width: 100%; + border-width: 1px; + border-radius: 5px; + text-align: center; + margin: 10px; + color:rgb(255, 255, 255); + font-weight: 400; + background-color: black; +} + +/* EXPLANATION */ + +main#explanation { + width: 95%; + margin-bottom: 5rem; + margin-top: 5rem; + border-width: 4px; + border-style:double; + border-radius: 5px; + padding: 2rem; + border-radius: 10px; +} + +/* GRAPH */ + +.column_graph { + margin-top: 5rem; +} + + + diff --git a/assets/typography.css b/assets/typography.css new file mode 100644 index 0000000000000000000000000000000000000000..be7111a86df23034aac7e48a4931610e19e37ce9 --- /dev/null +++ b/assets/typography.css @@ -0,0 +1,21 @@ +body { + font-family: sans-serif; +} + +H4 { + font-size: 20px; + text-decoration-line:underline; + text-decoration-thickness:2px; + text-decoration-style:solid; + color: hsl(229, 58%, 19%) +} + +H5 { + font-size: 16px; + color: hsl(228, 58%, 12%); +} + +p { + font-size: 15px; + color: hsl(0, 0%, 0%) +} \ No newline at end of file diff --git a/callbacks.py b/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..126e27e82da879906f65742006cc4685568371e2 --- /dev/null +++ b/callbacks.py @@ -0,0 +1,116 @@ +import dash +import pandas as pd +from dash import Input, Output, State +from dash.dependencies import Input, Output, State +from dash.exceptions import PreventUpdate + +from utils import parse_contents_graph, parse_contents_instance + + +def register_callbacks(page_home, page_course, page_application, app): + page_list = ['home', 'course', 'application'] + + @app.callback( + Output('page-content', 'children'), + Input('url', 'pathname')) + def display_page(pathname): + if pathname == '/': + return page_home + if pathname == '/application': + return page_application.view.layout + if pathname == '/course': + return page_course + + @app.callback(Output('home-link', 'active'), + Output('course-link', 'active'), + Output('application-link', 'active'), + Input('url', 'pathname')) + def navbar_state(pathname): + active_link = ([pathname == f'/{i}' for i in page_list]) + return active_link[0], active_link[1], active_link[2] + + @app.callback( + Output('pretrained_model_filename', 'children'), + Output('instance_filename', 'children'), + Output('graph', 'children'), + Output('explanation', 'children'), + Input('ml_model_choice', 'value'), + Input('ml_pretrained_model_choice', 'contents'), + State('ml_pretrained_model_choice', 'filename'), + Input('ml_instance_choice', 'contents'), + State('ml_instance_choice', 'filename'), + Input('number_explanations', 'value'), + Input('explanation_type', 'value'), + Input('solver_sat', 'value'), + Input('expl_choice', 'value'), + prevent_initial_call=True + ) + def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, instance_contents, instance_filename, enum, xtype, solver, expl_choice): + ctx = dash.callback_context + if ctx.triggered: + ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] + model_application = page_application.model + if ihm_id == 'ml_model_choice' : + model_application.update_ml_model(value_ml_model) + return None, None, None, None + + elif ihm_id == 'ml_pretrained_model_choice': + if value_ml_model is None : + raise PreventUpdate + tree, typ = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) + model_application.update_pretrained_model(tree, typ) + return pretrained_model_filename, None, model_application.component.network, None + + elif ihm_id == 'ml_instance_choice' : + if value_ml_model is None or pretrained_model_contents is None or enum is None or xtype is None: + raise PreventUpdate + instance = parse_contents_instance(instance_contents, instance_filename) + model_application.update_instance(instance, enum, xtype) + return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + + elif ihm_id == 'number_explanations' : + if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or xtype is None: + raise PreventUpdate + instance = parse_contents_instance(instance_contents, instance_filename) + model_application.update_instance(instance, enum, xtype) + return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + + elif ihm_id == 'explanation_type' : + if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None : + raise PreventUpdate + instance = parse_contents_instance(instance_contents, instance_filename) + model_application.update_instance(instance, enum, xtype) + return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + + elif ihm_id == 'solver_sat' : + if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None or xtype is None: + raise PreventUpdate + instance = parse_contents_instance(instance_contents, instance_filename) + model_application.update_instance(instance, enum, xtype, solver=solver) + return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + + elif ihm_id == 'expl_choice' : + if instance_contents is None : + raise PreventUpdate + model_application.update_expl(expl_choice) + return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation + + + @app.callback( + Output('explanation', 'hidden'), + Output('navigate_label', 'hidden'), + Output('navigate_dropdown', 'hidden'), + Output('expl_choice', 'options'), + Input('explanation', 'children'), + Input('explanation_type', 'value'), + prevent_initial_call=True + ) + def layout_buttons_navigate_expls(explanation, explanation_type): + if explanation is None or "AXp" not in explanation_type: + return True, True, True, {} + else : + options = {} + model_application = page_application.model + for i in range (len(model_application.list_expls)): + options[str(model_application.list_expls[i])] = model_application.list_expls[i] + return False, False, False, options diff --git a/pages/__init__.py b/pages/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 93716c0abca762b56602827cdf356fc3320eceba..662884f4a755556a110e97b99456adaf860b100e 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -1,41 +1,59 @@ -from dash import dcc -from pages.application.DecisionTree.utils.dtviz import visualize, visualize_instance -from pages.application.DecisionTree.utils.dtree import DecisionTree +from os import path + +import dash_bootstrap_components as dbc import dash_interactive_graphviz +import numpy as np +from dash import dcc, html +from pages.application.DecisionTree.utils.dtree import DecisionTree +from pages.application.DecisionTree.utils.dtviz import (visualize, + visualize_expl, + visualize_instance) -import os.path -from os import path -import numpy as np class DecisionTreeComponent(): def __init__(self, tree, typ_data): - if typ_data == "dt" : - self.dt = DecisionTree(from_dt = tree) - elif typ_data == "pkl" : - self.dt = DecisionTree(from_pickle = tree) + self.dt = DecisionTree(from_pickle = tree) dot_source = visualize(self.dt) - - self.network = dash_interactive_graphviz.DashInteractiveGraphviz( - dot_source=dot_source - ) - - self.explanation = dcc.Textarea(value = "", style = { "font_size" : "15px", - "width": "40rem", - "height": "40rem", - "margin-bottom": "5rem", - "background-color": "#f8f9fa", - }) + self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", + "height": "90%", + "background-color": "transparent"}))] + self.explanation = [] def update_with_explicability(self, instance, enum, xtype, solver) : - instance = str(instance).strip().split(',') - instance = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in instance])) - dot_source = visualize_instance(self.dt, instance) - self.network = dash_interactive_graphviz.DashInteractiveGraphviz( - dot_source=dot_source - ) - - self.explanation.value = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver) + self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( + dot_source=dot_source, style = {"width": "50%", + "height": "80%", + "background-color": "transparent"} + ))] + + self.explanation = [] + list_explanations_path=[] + explanation = self.dt.explain(instance, enum=enum, xtype = xtype, solver=solver) + + #Creating a clean and nice text component + for k in explanation.keys() : + if k != "List of path explanation(s)": + if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] : + self.explanation.append(html.H4(k)) + for expl in explanation[k] : + self.explanation.append(html.Hr()) + self.explanation.append(html.P(expl)) + self.explanation.append(html.Hr()) + else : + self.explanation.append(html.P(k + explanation[k])) + else : + list_explanations_path = explanation["List of path explanation(s)"] + + return list_explanations_path + + def draw_explanation(self, instance, expl) : + dot_source = visualize_expl(self.dt, instance, expl) + self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz( + dot_source=dot_source, + style = {"width": "50%", + "height": "80%", + "background-color": "transparent"}))] diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index 2a8b75f672ac927851084ab14b50cbb46f066f8e..6bd118dfd8cb6ca577338d8a8189cb00cea33a68 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -11,21 +11,26 @@ # #============================================================================== from __future__ import print_function + import collections from functools import reduce + +import sklearn from pysat.card import * from pysat.examples.hitman import Hitman from pysat.formula import CNF, IDPool from pysat.solvers import Solver -import sklearn from torch import threshold try: # for Python2 from cStringIO import StringIO except ImportError: # for Python3 from io import StringIO -from sklearn.tree import _tree + import numpy as np +from dash import dcc, html +from sklearn.tree import _tree + # #============================================================================== @@ -81,7 +86,7 @@ class DecisionTree(): def from_pickle_file(self, tree): #help(_tree.Tree) self.tree_ = tree.tree_ - print(sklearn.tree.export_text(tree)) + #print(sklearn.tree.export_text(tree)) try: feature_names = tree.feature_names_in_ except: @@ -132,7 +137,7 @@ class DecisionTree(): Traverse the tree and extract explicit paths. """ - if root in self.terms: + if root in self.terms.keys(): # store the path term = self.terms[root] self.paths[term].append(prefix) @@ -159,7 +164,7 @@ class DecisionTree(): sets = [] for t, paths in self.paths.items(): # ignoring the right class - if t == term: + if term in self.terms.keys() and self.terms[term] == t: continue # computing the sets to hit @@ -190,11 +195,15 @@ class DecisionTree(): inst_dic = {} for i in range(len(inst)): inst_dic[inst[i][0]] = np.float32(inst[i][1]) - inst_orig = inst[:] path, term = self.execute(inst_values) - explanation = str(inst_dic) + "\n \n" - decision_path_str = "c inst : IF : " + #contaiins all the elements for explanation + explanation_dic = {} + #instance plotting + explanation_dic["Instance : "] = str(inst_dic) + + #decision path + decision_path_str = "IF : " for node_id in path: # continue to the next node if it is a leaf node if term == node_id: @@ -207,43 +216,45 @@ class DecisionTree(): threshold=self.nodes[node_id].threshold) decision_path_str += "THEN " + str(self.terms[term]) - explanation += decision_path_str + "\n \n" - explanation +='c path len:'+ str(len(path))+ "\n \n \n" + explanation_dic["Decision path of instance : "] = decision_path_str + explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) # computing the sets to hit to_hit = self.prepare_sets(inst_dic, term) for type in xtype : if type == "AXp": - explanation += "Abductive explanation : " + "\n \n" - explanation += self.enumerate_abductive(to_hit, enum, solver, htype, term)+ "\n \n" + explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term)) else : - explanation += "Contrastive explanation : "+ "\n \n" - explanation += self.enumerate_contrastive(to_hit, term)+ "\n \n" + explanation_dic.update(self.enumerate_contrastive(to_hit, term)) - return explanation + return explanation_dic def enumerate_abductive(self, to_hit, enum, solver, htype, term): """ Enumerate abductive explanations. """ - explanation = "" + list_expls = [] + list_expls_str = [] + explanation = {} with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: expls = [] for i, expl in enumerate(hitman.enumerate(), 1): - explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], + list_expls.append([ p[0] + p[2] + p[3] for p in expl]) + list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], value=p[1], inequality=p[2], threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n" + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) expls.append(expl) if i == enum: break - - explanation += 'c nof expls:' + str(i)+ "\n" - explanation += 'c min expl:'+ str( min([len(e) for e in expls]))+ "\n" - explanation += 'c max expl:'+ str( max([len(e) for e in expls]))+ "\n" - explanation += 'c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))+ "\n \n \n" + explanation["List of path explanation(s)"] = list_expls + explanation["List of abductive explanation(s)"] = list_expls_str + explanation["Number of abductive explanation(s) : "] = str(i) + explanation["Minimal abductive explanation : "] = str( min([len(e) for e in expls])) + explanation["Maximal abductive explanation : "] = str( max([len(e) for e in expls])) + explanation["Average abductive explanation : "] = '{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) return explanation @@ -263,15 +274,17 @@ class DecisionTree(): to_hit = [set(s) for s in to_hit] to_hit.sort(key=lambda s: len(s)) expls = list(reduce(process_set, to_hit, [])) - explanation = "" + list_expls_str = [] + explanation = {} for expl in expls: - explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], + list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], inequality="<=" if p[2]==">" else ">", threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n" - explanation +='c nof expls:'+ str(len(expls))+ "\n" - explanation +='c min expl:'+ str( min([len(e) for e in expls]))+ "\n" - explanation +='c max expl:'+ str( max([len(e) for e in expls]))+ "\n" - explanation +='c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))+ "\n" + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + explanation["List of contrastive explanation(s)"] = list_expls_str + explanation["Number of contrastive explanation(s) : "]=str(len(expls)) + explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) + explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls])) + explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) - return explanation \ No newline at end of file + return explanation diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index 11ac6fb2be79bff4c93a1bbcff4b2147b4aa1095..aacf7b646dd1384157116c138a125a069670c772 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -8,22 +8,40 @@ ## E-mail: alexey.ignatiev@monash.edu ## -# -#============================================================================== -from pages.application.DecisionTree.utils.dtree import DecisionTree -import pygraphviz import numpy as np -import pandas as pd +import pygraphviz # #============================================================================== +def create_legend(g): + legend = g.subgraphs()[-1] + legend.add_node("a", style = "invis") + legend.add_node("b", style = "invis") + legend.add_node("c", style = "invis") + legend.add_node("d", style = "invis") + + legend.add_edge("a","b") + edge = legend.get_edge("a","b") + edge.attr["label"] = "instance" + edge.attr["style"] = "dashed" + + legend.add_edge("c","d") + edge = legend.get_edge("c","d") + edge.attr["label"] = "instance with explanation" + edge.attr["color"] = "blue" + edge.attr["style"] = "dashed" + + def visualize(dt): """ Visualize a DT with graphviz. """ - g = pygraphviz.AGraph(directed=True, strict=True) + g = pygraphviz.AGraph(name='root', rankdir="TB") + g.is_directed() + g.is_strict() + + #g = pygraphviz.AGraph(name = "main", directed=True, strict=True) g.edge_attr['dir'] = 'forward' - g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: @@ -56,6 +74,9 @@ def visualize(dt): edge.attr['fontsize'] = 10 edge.attr['arrowsize'] = 0.8 + g.add_subgraph(name='legend') + create_legend(g) + # saving file g.layout(prog='dot') return(g.string()) @@ -102,7 +123,7 @@ def visualize_instance(dt, instance): edge.attr['arrowsize'] = 0.8 #instance path in blue if ((n1,children_left) in edges_instance): - edge.attr['color'] = 'blue' + edge.attr['style'] = 'dashed' children_right = dt.nodes[n1].children_right g.add_edge(n1, children_right) @@ -112,8 +133,74 @@ def visualize_instance(dt, instance): edge.attr['arrowsize'] = 0.8 #instance path in blue if ((n1,children_right) in edges_instance): - edge.attr['color'] = 'blue' + edge.attr['style'] = 'dashed' + + g.add_subgraph(name='legend') + create_legend(g) # saving file g.layout(prog='dot') return(g.to_string()) +# +#============================================================================== +def visualize_expl(dt, instance, expl): + """ + Visualize a DT with graphviz and plot the running instance. + """ + g = pygraphviz.AGraph(directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=str(dt.nodes[n].feat)) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=str(dt.terms[n])) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + #path that follows the instance - colored in blue + instance = [np.float32(i[1]) for i in instance] + path, term_id_node = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + + for n1 in dt.nodes: + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_left) in edges_instance): + edge.attr['style'] = 'dashed' + if edge.attr['label'] in expl : + edge.attr['color'] = 'blue' + + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_right) in edges_instance): + edge.attr['style'] = 'dashed' + if edge.attr['label'] in expl : + edge.attr['color'] = 'blue' + + g.add_subgraph(name='legend') + create_legend(g) + + g.layout(prog='dot') + return(g.to_string()) diff --git a/pages/application/layout_application.py b/pages/application/application.py similarity index 57% rename from pages/application/layout_application.py rename to pages/application/application.py index 78c5c019436a0e3ac2a5e74ea59434b872b586f6..16996a6aa067aa7bf5c30bd20b93793995d256c1 100644 --- a/pages/application/layout_application.py +++ b/pages/application/application.py @@ -1,14 +1,13 @@ from dash import dcc, html -import dash import dash_bootstrap_components as dbc from pages.application.DecisionTree.DecisionTreeComponent import DecisionTreeComponent -SIDEBAR_STYLE = { -} +class Application(): + def __init__(self, view): + self.view = view + self.model = view.model -CONTENT_STYLE = { -} class Model(): @@ -19,11 +18,14 @@ class Model(): self.ml_models = names_models self.ml_model = '' - self.dataset = '' + self.pretrained_model = '' self.typ_data = '' self.instance = '' + self.list_expls = [] + self.expl_path = [] + self.component_class = '' self.component = '' @@ -32,41 +34,38 @@ class Model(): self.component_class = self.dict_components[self.ml_model] self.component_class = globals()[self.component_class] - def update_dataset(self, dataset_update, typ_data): - self.dataset = dataset_update + def update_pretrained_model(self, pretrained_model_update, typ_data): + self.pretrained_model = pretrained_model_update self.typ_data = typ_data - self.component = self.component_class(self.dataset, self.typ_data) + self.component = self.component_class(self.pretrained_model, self.typ_data) def update_instance(self, instance, enum, xtype, solver="g3"): self.instance = instance - self.component.update_with_explicability(self.instance, enum, xtype, solver) + self.list_expls = self.component.update_with_explicability(self.instance, enum, xtype, solver) + + def update_expl(self, expl): + self.expl = expl + self.component.draw_explanation(self.instance, expl) class View(): def __init__(self, model): self.model = model - self.ml_menu_models = dcc.Dropdown(self.model.ml_models, id='ml_model_choice') + self.ml_menu_models = dcc.Dropdown(self.model.ml_models, + id='ml_model_choice', + className="sidebar-dropdown") - self.dataset_upload = html.Div([ - dcc.Upload( - id='ml_dataset_choice', + self.pretrained_model_upload = html.Div([ + dcc.Upload( + id='ml_pretrained_model_choice', children=html.Div([ 'Drag and Drop or ', html.A('Select File') ]), - style={ - 'width': '100%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin': '10px' - } + className="upload" ), - html.Div(id='dataset_filename')]) + html.Div(id='pretrained_model_filename')]) self.instance_upload = html.Div([ dcc.Upload( @@ -75,29 +74,20 @@ class View(): 'Drag and Drop or ', html.A('Select instance') ]), - style={ - 'width': '100%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin': '10px' - } + className="upload" ), html.Div(id='instance_filename')]) - self.sidebar = dbc.Col([ - dcc.Tabs(children=[ - dcc.Tab(label='Basic Parameters', children = [ + self.sidebar = dcc.Tabs(children=[ + dcc.Tab(label='Basic Parameters', children = [ + html.Br(), html.Label("Choose the Machine Learning algorithm :"), html.Br(), self.ml_menu_models, html.Hr(), - html.Label("Choose the dataset : "), + html.Label("Choose the pretrained model : "), html.Br(), - self.dataset_upload, + self.pretrained_model_upload, html.Hr(), html.Label("Choose the instance to explain : "), html.Br(), @@ -107,8 +97,10 @@ class View(): html.Br(), dcc.Input( id="number_explanations", + value=1, type="number", - placeholder="How many explanations ?"), + placeholder="How many explanations ?", + className="sidebar-dropdown"), html.Hr(), html.Label("Choose the kind of explanation : "), html.Br(), @@ -116,17 +108,24 @@ class View(): id="explanation_type", options={'AXp' : "Abductive Explanation", 'CXp': "Contrastive explanation"}, value = ['AXp', 'CXp'], - inline=True)]), + className="sidebar-dropdown", + inline=True)], className="sidebar"), dcc.Tab(label='Advanced Parameters', children = [ html.Hr(), html.Label("Choose the SAT solver : "), html.Br(), dcc.Dropdown(['g3', 'g4', 'lgl', 'mcb', 'mcm', 'mpl', 'm22', 'mc', 'mgh'], 'g3', id='solver_sat') - ]) - ])],width=3) - - self.layout = dbc.Row([self.sidebar, - dbc.Col(html.Div(id = "graph", children=" "), width=4), - dbc.Col(html.Div(id = "explanation", children=" "), width=3)]) - - self.tab = dcc.Tab(label='Application on Explainable AI', children=self.layout) + ], className="sidebar") + ]) + + + + self.expl_choice = dcc.Dropdown(self.model.list_expls, + id='expl_choice', + className="dropdown") + + self.layout = dbc.Row([ dbc.Col([self.sidebar], width=3, class_name="sidebar"), + dbc.Col([dbc.Row(id = "graph", children=[]), + dbc.Row(html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "), + html.Div(self.expl_choice, id='navigate_dropdown', hidden=True)]))], width=5, class_name="column_graph"), + dbc.Col(html.Main(id = "explanation", children=[], hidden=True), width=4)]) \ No newline at end of file diff --git a/utils.py b/utils.py index 27f0b5ea4d3e8c657245b7e9e44aac709e6ea985..a3afff7ac45293e228d2a82e510de33d05ec77b5 100644 --- a/utils.py +++ b/utils.py @@ -1,12 +1,12 @@ import base64 import io import pickle -import pandas as pd -import sklearn + +import numpy as np from dash import html -def parse_contents_tree(contents, filename): +def parse_contents_graph(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) try: @@ -31,7 +31,8 @@ def parse_contents_instance(contents, filename): data = decoded.decode('utf-8') else : data = decoded.decode('utf-8') - + data = str(data).strip().split(',') + data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) except Exception as e: print(e) return html.Div([