From e4166e3ef9bd90f501fcf229f4d3099c1fa73067 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Tue, 5 Apr 2022 16:56:41 +0200 Subject: [PATCH] init draw contrastive - complicated due to name of label --- callbacks.py | 26 +++++--- .../DecisionTree/DecisionTreeComponent.py | 17 ++++- pages/application/DecisionTree/utils/dtree.py | 4 ++ pages/application/DecisionTree/utils/dtviz.py | 64 ++++++++++++++++++- pages/application/application.py | 29 ++++++--- utils.py | 10 ++- 6 files changed, 126 insertions(+), 24 deletions(-) diff --git a/callbacks.py b/callbacks.py index 5ef62d6..25a794b 100644 --- a/callbacks.py +++ b/callbacks.py @@ -46,10 +46,11 @@ def register_callbacks(page_home, page_course, page_application, app): Input('explanation_type', 'value'), Input('solver_sat', 'value'), Input('expl_choice', 'value'), + Input('cont_expl_choice', 'value'), prevent_initial_call=True ) def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \ - instance_contents, instance_filename, enum, xtype, solver, expl_choice): + instance_contents, instance_filename, enum, xtype, solver, expl_choice, cont_expl_choice): ctx = dash.callback_context if ctx.triggered: ihm_id = ctx.triggered[0]['prop_id'].split('.')[0] @@ -116,27 +117,36 @@ def register_callbacks(page_home, page_course, page_application, app): model_application.update_expl(expl_choice) return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation + # Choice of CxP to draw + elif ihm_id == 'cont_expl_choice' : + if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0: + raise PreventUpdate + model_application.update_cont_expl(cont_expl_choice) + return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation @app.callback( Output('explanation', 'hidden'), - Output('navigate_label', 'hidden'), - Output('navigate_dropdown', 'hidden'), + Output('interaction_graph', 'hidden'), Output('expl_choice', 'options'), + Output('cont_expl_choice', 'options'), Input('explanation', 'children'), Input('explanation_type', 'value'), prevent_initial_call=True ) def layout_buttons_navigate_expls(explanation, explanation_type): if explanation is None or len(explanation_type)==0: - return True, True, True, {} + return True, True, {}, {} elif "AXp" not in explanation_type and "CXp" in explanation_type: - return False, True, True, {} + return False, True, {}, {} else : - options = {} + options_expls = {} + options_cont_expls = {} model_application = page_application.model for i in range (len(model_application.list_expls)): - options[str(model_application.list_expls[i])] = model_application.list_expls[i] - return False, False, False, options + options_expls[str(model_application.list_expls[i])] = model_application.list_expls[i] + for i in range (len(model_application.list_cont_expls)): + options_cont_expls[str(model_application.list_cont_expls[i])] = model_application.list_cont_expls[i] + return False, False, options_expls, options_cont_expls @app.callback( Output('choice_info_div', 'hidden'), diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 8a56c32..f9d34e3 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -11,7 +11,8 @@ from pages.application.DecisionTree.utils.dtree import DecisionTree from pages.application.DecisionTree.utils.dtviz import (visualize, visualize_expl, - visualize_instance) + visualize_instance, + visualize_contrastive_expl) class DecisionTreeComponent(): @@ -135,7 +136,7 @@ class DecisionTreeComponent(): self.explanation.append(html.H4("Instance : \n")) self.explanation.append(html.P(str([str(instance[i]) for i in range (len(instance))]))) for k in explanation.keys() : - if k != "List of path explanation(s)": + if k != "List of path explanation(s)" and k!= "List of path contrastive explanation(s)" : if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] : self.explanation.append(html.H4(k)) for expl in explanation[k] : @@ -146,8 +147,9 @@ class DecisionTreeComponent(): self.explanation.append(html.P(k + explanation[k])) else : list_explanations_path = explanation["List of path explanation(s)"] + list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"] - return list_explanations_path + return list_explanations_path, list_contrastive_explanations_path def draw_explanation(self, instance, expl) : instance = self.translate_instance(instance) @@ -157,3 +159,12 @@ class DecisionTreeComponent(): style = {"width": "50%", "height": "80%", "background-color": "transparent"})]) + + def draw_contrastive_explanation(self, instance, cont_expl) : + instance = self.translate_instance(instance) + dot_source = visualize_contrastive_expl(self.dt, instance, cont_expl) + self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz( + dot_source=dot_source, + style = {"width": "50%", + "height": "80%", + "background-color": "transparent"})]) diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index 79088c2..c2775df 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -393,14 +393,18 @@ class DecisionTree(): done.append(target) return done + list_contrastive_expls = [] + to_hit = [set(s) for s in to_hit] to_hit.sort(key=lambda s: len(s)) expls = list(reduce(process_set, to_hit, [])) list_expls_str = [] explanation = {} for expl in expls: + list_contrastive_expls.append([self.fvmap[(p[0],1-p[1])] for p in sorted(expl, key=lambda p: p[0])]) list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)) + explanation["List of path contrastive explanation(s)"] = list_contrastive_expls explanation["List of contrastive explanation(s)"] = list_expls_str explanation["Number of contrastive explanation(s) : "]=str(len(expls)) explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index 21ea639..e859cc0 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -12,7 +12,8 @@ #============================================================================== import getopt import pygraphviz - +import ast +import re # #============================================================================== def create_legend(g): @@ -194,3 +195,64 @@ def visualize_expl(dt, instance, expl): # saving file g.layout(prog='dot') return(g.to_string()) + +#============================================================================== +def visualize_contrastive_expl(dt, instance, cont_expl): + """ + Visualize a DT with graphviz and plot the running instance. + """ + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + + g = pygraphviz.AGraph(directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + + g.graph_attr['rankdir'] = 'TB' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=dt.terms[n]) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + # transitions + for n1 in dt.nodes: + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in dashed + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['style'] = 'dashed' + + for label in edge.attr['label'].split('\n'): + if label in cont_expl: + edge.attr['color'] = 'red' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + + g.add_subgraph(name='legend') + create_legend(g) + + # saving file + g.layout(prog='dot') + return(g.to_string()) diff --git a/pages/application/application.py b/pages/application/application.py index db54577..062cc14 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -31,7 +31,10 @@ class Model(): self.instance = '' self.list_expls = [] - self.expl_path = [] + self.list_cont_expls = [] + + self.expl='' + self.cont_expl='' self.component_class = '' self.component = '' @@ -56,24 +59,28 @@ class Model(): def update_instance(self, instance): self.instance = instance - self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) def update_enum(self, enum): self.enum = enum - self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) def update_xtype(self, xtype): self.xtype = xtype - self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) def update_solver(self, solver): self.solver = solver - self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) + self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver) def update_expl(self, expl): self.expl = expl self.component.draw_explanation(self.instance, expl) + def update_cont_expl(self, cont_expl): + self.expl = cont_expl + self.component.draw_contrastive_explanation(self.instance, cont_expl) + class View(): def __init__(self, model): @@ -178,12 +185,16 @@ class View(): ], className="sidebar")]) - self.expl_choice = html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "), - html.Div(id='navigate_dropdown', hidden=True, - children = [dcc.Dropdown(self.model.list_expls, + self.expl_choice = html.Div(id = "interaction_graph", hidden=True, + children=[html.H5("Navigate through the explanations and plot them on the tree : "), + html.Div(children = [dcc.Dropdown(self.model.list_expls, id='expl_choice', + className="dropdown")]), + html.H5("Navigate through the contrastive explanations and plot them on the tree : "), + html.Div(children = [dcc.Dropdown(self.model.list_cont_expls, + id='cont_expl_choice', className="dropdown")])]) - + self.layout = dbc.Row([ dbc.Col([self.sidebar], width=3, class_name="sidebar"), dbc.Col([dbc.Row(id = "graph", children=[]), diff --git a/utils.py b/utils.py index 42b30c9..d9996f5 100644 --- a/utils.py +++ b/utils.py @@ -27,9 +27,9 @@ def parse_contents_data(contents, filename): decoded = base64.b64decode(content_string) try: if '.csv' in filename: - data = decoded.decode('utf-8') + data = decoded.decode('utf-8').strip() if '.txt' in filename: - data = decoded.decode('utf-8') + data = decoded.decode('utf-8').strip() except Exception as e: print(e) return html.Div([ @@ -51,7 +51,11 @@ def parse_contents_instance(contents, filename): data = str(data).strip().split(',') data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data])) elif '.json' in filename: - data = decoded.decode('utf-8') + data = decoded.decode('utf-8').strip() + data = json.loads(data) + data = list(tuple(data.items())) + elif '.inst' in filename: + data = decoded.decode('utf-8').strip() data = json.loads(data) data = list(tuple(data.items())) except Exception as e: -- GitLab