From e4166e3ef9bd90f501fcf229f4d3099c1fa73067 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 5 Apr 2022 16:56:41 +0200
Subject: [PATCH] init draw contrastive - complicated due to name of label

---
 callbacks.py                                  | 26 +++++---
 .../DecisionTree/DecisionTreeComponent.py     | 17 ++++-
 pages/application/DecisionTree/utils/dtree.py |  4 ++
 pages/application/DecisionTree/utils/dtviz.py | 64 ++++++++++++++++++-
 pages/application/application.py              | 29 ++++++---
 utils.py                                      | 10 ++-
 6 files changed, 126 insertions(+), 24 deletions(-)

diff --git a/callbacks.py b/callbacks.py
index 5ef62d6..25a794b 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -46,10 +46,11 @@ def register_callbacks(page_home, page_course, page_application, app):
         Input('explanation_type', 'value'),
         Input('solver_sat', 'value'),
         Input('expl_choice', 'value'),
+        Input('cont_expl_choice', 'value'),
         prevent_initial_call=True
     )
     def update_ml_type(value_ml_model, pretrained_model_contents, pretrained_model_filename, model_info, model_info_filename, \
-                        instance_contents, instance_filename, enum, xtype, solver, expl_choice):
+                        instance_contents, instance_filename, enum, xtype, solver, expl_choice, cont_expl_choice):
         ctx = dash.callback_context
         if ctx.triggered:
             ihm_id = ctx.triggered[0]['prop_id'].split('.')[0]
@@ -116,27 +117,36 @@ def register_callbacks(page_home, page_course, page_application, app):
                 model_application.update_expl(expl_choice)
                 return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation             
 
+            # Choice of CxP to draw
+            elif ihm_id == 'cont_expl_choice' :
+                if model_application.ml_model is None or model_application.pretrained_model is None or len(model_application.instance)==0 or model_application.enum<=0 or len(model_application.xtype)==0:
+                    raise PreventUpdate
+                model_application.update_cont_expl(cont_expl_choice)
+                return pretrained_model_filename, model_info_filename, instance_filename, model_application.component.network, model_application.component.explanation             
 
     @app.callback(
         Output('explanation', 'hidden'),
-        Output('navigate_label', 'hidden'),    
-        Output('navigate_dropdown', 'hidden'),
+        Output('interaction_graph', 'hidden'),    
         Output('expl_choice', 'options'),
+        Output('cont_expl_choice', 'options'),
         Input('explanation', 'children'),
         Input('explanation_type', 'value'),
         prevent_initial_call=True
     )
     def layout_buttons_navigate_expls(explanation, explanation_type):
         if explanation is None or len(explanation_type)==0:
-            return True, True, True, {}
+            return True, True, {}, {}
         elif "AXp" not in explanation_type and "CXp" in explanation_type:
-            return False, True, True, {}
+            return False, True, {}, {}
         else : 
-            options = {}
+            options_expls = {}
+            options_cont_expls = {}
             model_application = page_application.model
             for i in range (len(model_application.list_expls)):
-                options[str(model_application.list_expls[i])] = model_application.list_expls[i]
-            return False, False, False, options
+                options_expls[str(model_application.list_expls[i])] = model_application.list_expls[i]
+            for i in range (len(model_application.list_cont_expls)):
+                options_cont_expls[str(model_application.list_cont_expls[i])] = model_application.list_cont_expls[i]
+            return False, False, options_expls, options_cont_expls
 
     @app.callback(
         Output('choice_info_div', 'hidden'),
diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py
index 8a56c32..f9d34e3 100644
--- a/pages/application/DecisionTree/DecisionTreeComponent.py
+++ b/pages/application/DecisionTree/DecisionTreeComponent.py
@@ -11,7 +11,8 @@ from pages.application.DecisionTree.utils.dtree import DecisionTree
 
 from pages.application.DecisionTree.utils.dtviz import (visualize,
                                                         visualize_expl,
-                                                        visualize_instance)
+                                                        visualize_instance,
+                                                        visualize_contrastive_expl)
 
 class DecisionTreeComponent():
 
@@ -135,7 +136,7 @@ class DecisionTreeComponent():
         self.explanation.append(html.H4("Instance : \n"))
         self.explanation.append(html.P(str([str(instance[i]) for i in range (len(instance))])))
         for k in explanation.keys() :
-            if k != "List of path explanation(s)":
+            if k != "List of path explanation(s)" and k!= "List of path contrastive explanation(s)" :
                 if k in ["List of abductive explanation(s)","List of contrastive explanation(s)"] :
                     self.explanation.append(html.H4(k))
                     for expl in explanation[k] :
@@ -146,8 +147,9 @@ class DecisionTreeComponent():
                     self.explanation.append(html.P(k + explanation[k]))
             else :
                 list_explanations_path = explanation["List of path explanation(s)"]
+                list_contrastive_explanations_path = explanation["List of path contrastive explanation(s)"]
 
-        return list_explanations_path
+        return list_explanations_path, list_contrastive_explanations_path
 
     def draw_explanation(self, instance, expl) :
         instance = self.translate_instance(instance)
@@ -157,3 +159,12 @@ class DecisionTreeComponent():
                                 style = {"width": "50%",
                                         "height": "80%",
                                         "background-color": "transparent"})])
+
+    def draw_contrastive_explanation(self, instance, cont_expl) :
+        instance = self.translate_instance(instance)
+        dot_source = visualize_contrastive_expl(self.dt, instance, cont_expl)
+        self.network = html.Div([dash_interactive_graphviz.DashInteractiveGraphviz(
+                                dot_source=dot_source, 
+                                style = {"width": "50%",
+                                        "height": "80%",
+                                        "background-color": "transparent"})])
diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py
index 79088c2..c2775df 100644
--- a/pages/application/DecisionTree/utils/dtree.py
+++ b/pages/application/DecisionTree/utils/dtree.py
@@ -393,14 +393,18 @@ class DecisionTree():
                 done.append(target)
             return done
 
+        list_contrastive_expls = []
+
         to_hit = [set(s) for s in to_hit]
         to_hit.sort(key=lambda s: len(s))
         expls = list(reduce(process_set, to_hit, []))
         list_expls_str = []
         explanation = {}
         for expl in expls:
+            list_contrastive_expls.append([self.fvmap[(p[0],1-p[1])] for p in sorted(expl, key=lambda p: p[0])])
             list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))
 
+        explanation["List of path contrastive explanation(s)"] = list_contrastive_expls
         explanation["List of contrastive explanation(s)"] = list_expls_str
         explanation["Number of contrastive explanation(s) : "]=str(len(expls))
         explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls]))
diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py
index 21ea639..e859cc0 100755
--- a/pages/application/DecisionTree/utils/dtviz.py
+++ b/pages/application/DecisionTree/utils/dtviz.py
@@ -12,7 +12,8 @@
 #==============================================================================
 import getopt
 import pygraphviz
-
+import ast
+import re
 #
 #==============================================================================
 def create_legend(g):
@@ -194,3 +195,64 @@ def visualize_expl(dt, instance, expl):
     # saving file
     g.layout(prog='dot')
     return(g.to_string())
+
+#==============================================================================
+def visualize_contrastive_expl(dt, instance, cont_expl):
+    """
+        Visualize a DT with graphviz and plot the running instance.
+    """
+    #path that follows the instance - colored in blue
+    path, term, depth = dt.execute(instance)
+    edges_instance = []
+    for i in range (len(path)-1) :
+        edges_instance.append((path[i], path[i+1]))
+    edges_instance.append((path[-1],"term:"+term))
+
+    g = pygraphviz.AGraph(directed=True, strict=True)
+    g.edge_attr['dir'] = 'forward'
+
+    g.graph_attr['rankdir'] = 'TB'
+
+    # non-terminal nodes
+    for n in dt.nodes:
+        g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
+        node = g.get_node(n)
+        node.attr['shape'] = 'circle'
+        node.attr['fontsize'] = 13
+
+    # terminal nodes
+    for n in dt.terms:
+        g.add_node(n, label=dt.terms[n])
+        node = g.get_node(n)
+        node.attr['shape'] = 'square'
+        node.attr['fontsize'] = 13
+
+    # transitions
+    for n1 in dt.nodes:
+        for v in dt.nodes[n1].vals:
+            n2 = dt.nodes[n1].vals[v]
+            n2_type = g.get_node(n2).attr['shape']
+            g.add_edge(n1, n2)
+            edge = g.get_edge(n1, n2)
+            if len(v) == 1:
+                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+            else:
+                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+            
+            #instance path in dashed
+            if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
+                edge.attr['style'] = 'dashed'
+                    
+            for label in edge.attr['label'].split('\n'):
+                if label in cont_expl:
+                    edge.attr['color'] = 'red'
+
+            edge.attr['fontsize'] = 10
+            edge.attr['arrowsize'] = 0.8
+
+    g.add_subgraph(name='legend')
+    create_legend(g)
+
+    # saving file
+    g.layout(prog='dot')
+    return(g.to_string())
diff --git a/pages/application/application.py b/pages/application/application.py
index db54577..062cc14 100644
--- a/pages/application/application.py
+++ b/pages/application/application.py
@@ -31,7 +31,10 @@ class Model():
         self.instance = ''
 
         self.list_expls = []
-        self.expl_path = []
+        self.list_cont_expls = []
+
+        self.expl=''
+        self.cont_expl=''
 
         self.component_class = ''
         self.component = ''
@@ -56,24 +59,28 @@ class Model():
 
     def update_instance(self, instance):
         self.instance = instance
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
         
     def update_enum(self, enum):
         self.enum = enum
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
         
     def update_xtype(self, xtype):
         self.xtype = xtype
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
 
     def update_solver(self, solver):
         self.solver = solver
-        self.list_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
+        self.list_expls, self.list_cont_expls = self.component.update_with_explicability(self.instance, self.enum, self.xtype, self.solver)    
               
     def update_expl(self, expl):
         self.expl = expl
         self.component.draw_explanation(self.instance, expl)
 
+    def update_cont_expl(self, cont_expl):
+        self.expl = cont_expl
+        self.component.draw_contrastive_explanation(self.instance, cont_expl)
+
 class View():
 
     def __init__(self, model):
@@ -178,12 +185,16 @@ class View():
                                 ], className="sidebar")])
                             
  
-        self.expl_choice = html.Div([html.H5(id = "navigate_label", hidden=True, children="Navigate through the explanations and plot them on the tree : "),
-                                    html.Div(id='navigate_dropdown', hidden=True,
-                                    children = [dcc.Dropdown(self.model.list_expls,
+        self.expl_choice = html.Div(id = "interaction_graph", hidden=True, 
+                           children=[html.H5("Navigate through the explanations and plot them on the tree : "),
+                                    html.Div(children = [dcc.Dropdown(self.model.list_expls,
                                         id='expl_choice',  
+                                        className="dropdown")]),
+                                    html.H5("Navigate through the contrastive explanations and plot them on the tree : "),
+                                    html.Div(children = [dcc.Dropdown(self.model.list_cont_expls,
+                                        id='cont_expl_choice',  
                                         className="dropdown")])])
-                                            
+
         self.layout = dbc.Row([ dbc.Col([self.sidebar], 
                                         width=3, class_name="sidebar"), 
                                 dbc.Col([dbc.Row(id = "graph", children=[]),
diff --git a/utils.py b/utils.py
index 42b30c9..d9996f5 100644
--- a/utils.py
+++ b/utils.py
@@ -27,9 +27,9 @@ def parse_contents_data(contents, filename):
     decoded = base64.b64decode(content_string)
     try:        
         if '.csv' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
         if '.txt' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
     except Exception as e:
         print(e)
         return html.Div([
@@ -51,7 +51,11 @@ def parse_contents_instance(contents, filename):
             data = str(data).strip().split(',')
             data = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in data]))       
         elif '.json' in filename:
-            data = decoded.decode('utf-8')
+            data = decoded.decode('utf-8').strip()
+            data = json.loads(data)
+            data = list(tuple(data.items()))
+        elif '.inst' in filename:
+            data = decoded.decode('utf-8').strip()
             data = json.loads(data)
             data = list(tuple(data.items()))
     except Exception as e:
-- 
GitLab