From 24a0dd1f96ea24e5175c74ff5596fa0121368914 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Mon, 25 Apr 2022 14:25:51 +0200
Subject: [PATCH] update for pydot

---
 pages/application/DecisionTree/utils/dtviz.py | 289 ++++++++----------
 1 file changed, 134 insertions(+), 155 deletions(-)

diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py
index cf71d2b..02aa847 100755
--- a/pages/application/DecisionTree/utils/dtviz.py
+++ b/pages/application/DecisionTree/utils/dtviz.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-#-*- coding:utf-8 -*-
+# -*- coding:utf-8 -*-
 ##
 ## dtviz.py
 ##
@@ -9,255 +9,234 @@
 ##
 
 #
-#==============================================================================
+# ==============================================================================
 import getopt
-import pygraphviz
+import pydot
 import ast
 import re
+
+
 #
-#==============================================================================
-def create_legend(g):
-    legend = g.subgraphs()[-1]
-    legend.graph_attr.update(size="2,2")    
-    legend.add_node("a", style = "invis")
-    legend.add_node("b", style = "invis")
-    legend.add_node("c", style = "invis")
-    legend.add_node("d", style = "invis")
-    legend.add_node("e", style = "invis")
-    legend.add_node("f", style = "invis")
-
-    legend.add_edge("a","b")
-    edge = legend.get_edge("a","b")
-    edge.attr["label"] = "instance"
-    edge.attr["style"] = "dashed" 
-
-    legend.add_edge("c","d")  
-    edge = legend.get_edge("c","d")
-    edge.attr["label"] = "instance with explanation"
-    edge.attr["color"] = "blue"
-    edge.attr["style"] = "dashed"
-
-    legend.add_edge("e","f")  
-    edge = legend.get_edge("e","f")
-    edge.attr["label"] = "contrastive explanation"
-    edge.attr["color"] = "red"
+# ==============================================================================
+def create_legend(G):
+    legend = pydot.Cluster('legend', rankdir="TB")
+
+    # non-terminal nodes
+    for n in ["a", "b", "c", "d", "e", "f"]:
+        node = pydot.Node(n, label=n)
+        legend.add_node(node)
+
+    edge = pydot.Edge("a", "b")
+    edge.obj_dict['attributes']["label"] = "instance"
+    edge.obj_dict['attributes']["style"] = "dashed"
+    legend.add_edge(edge)
+
+    edge = pydot.Edge("e", "f")
+    edge.obj_dict['attributes']["label"] = "contrastive explanation"
+    edge.obj_dict['attributes']["color"] = "red"
+    edge.obj_dict['attributes']["style"] = "dashed"
+    legend.add_edge(edge)
+
+    edge = pydot.Edge("c", "d")
+    edge.obj_dict['attributes']["label"] = "instance with explanation"
+    edge.obj_dict['attributes']["color"] = "blue"
+    edge.obj_dict['attributes']["style"] = "dashed"
+    legend.add_edge(edge)
+
+    G.add_subgraph(legend)
+
+
 #
-#==============================================================================
+# ==============================================================================
 def visualize(dt):
     """
         Visualize a DT with graphviz.
     """
 
-    g = pygraphviz.AGraph(directed=True, strict=True)
-    g.edge_attr['dir'] = 'forward'
-    g.graph_attr['rankdir'] = 'TB'
+    G = pydot.Dot('tree_total', graph_type='graph')
+
+    g = pydot.Cluster('tree', graph_type='graph')
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
-        node = g.get_node(n)
-        node.attr['shape'] = 'circle'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
+        g.add_node(node)
 
     # terminal nodes
     for n in dt.terms:
-        g.add_node(n, label=dt.terms[n])
-        node = g.get_node(n)
-        node.attr['shape'] = 'square'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.terms[n], shape="square")
+        g.add_node(node)
 
     # transitions
     for n1 in dt.nodes:
         for v in dt.nodes[n1].vals:
             n2 = dt.nodes[n1].vals[v]
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
+            edge = pydot.Edge(n1, n2)
             if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+                edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
             else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+                edge.obj_dict['attributes']['label'] = '{0}'.format(
+                    '\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+            edge.obj_dict['attributes']['fontsize'] = 10
+            edge.obj_dict['attributes']['arrowsize'] = 0.8
+
+            g.add_edge(edge)
+
+    G.add_subgraph(g)
+
+    return G.to_string()
 
-    # saving file
-    g.layout(prog='dot')
-    return(g.to_string())
 
 #
-#==============================================================================
+# ==============================================================================
 def visualize_instance(dt, instance):
     """
         Visualize a DT with graphviz and plot the running instance.
     """
-    #path that follows the instance - colored in blue
+    # path that follows the instance - colored in blue
     path, term, depth = dt.execute(instance)
     edges_instance = []
-    for i in range (len(path)-1) :
-        edges_instance.append((path[i], path[i+1]))
-    edges_instance.append((path[-1],"term:"+term))
+    for i in range(len(path) - 1):
+        edges_instance.append((path[i], path[i + 1]))
+    edges_instance.append((path[-1], "term:" + term))
+
+    G = pydot.Dot('tree_total', graph_type='graph')
 
-    g = pygraphviz.AGraph(directed=True, strict=True)
-    g.edge_attr['dir'] = 'forward'
-    g.graph_attr['rankdir'] = 'TB'
+    g = pydot.Cluster('tree', graph_type='graph')
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
-        node = g.get_node(n)
-        node.attr['shape'] = 'circle'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
+        g.add_node(node)
 
     # terminal nodes
     for n in dt.terms:
-        g.add_node(n, label=dt.terms[n])
-        node = g.get_node(n)
-        node.attr['shape'] = 'square'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.terms[n], shape="square")
+        g.add_node(node)
 
     # transitions
     for n1 in dt.nodes:
         for v in dt.nodes[n1].vals:
             n2 = dt.nodes[n1].vals[v]
-            n2_type = g.get_node(n2).attr['shape']
-
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
+            n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
+            edge = pydot.Edge(n1, n2)
             if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+                edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
             else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            
-            #instance path in dashed
-            if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
-                edge.attr['style'] = 'dashed'
+                edge.obj_dict['attributes']['label'] = '{0}'.format(
+                    '\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+            edge.obj_dict['attributes']['fontsize'] = 10
+            edge.obj_dict['attributes']['arrowsize'] = 0.8
 
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+            # instance path in dashed
+            if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
+                edge.obj_dict['attributes']['style'] = 'dashed'
 
-    g.add_subgraph(name='legend')
-    create_legend(g)
+            g.add_edge(edge)
 
-    # saving file
-    g.layout(prog='dot')
-    return(g.to_string())
+    create_legend(G)
+    G.add_subgraph(g)
 
-#==============================================================================
+    return G.to_string()
+
+
+# ==============================================================================
 def visualize_expl(dt, instance, expl):
     """
         Visualize a DT with graphviz and plot the running instance.
     """
-    #path that follows the instance - colored in blue
+    # path that follows the instance - colored in blue
     path, term, depth = dt.execute(instance)
     edges_instance = []
-    for i in range (len(path)-1) :
-        edges_instance.append((path[i], path[i+1]))
-    edges_instance.append((path[-1],"term:"+term))
-
-    g = pygraphviz.AGraph(directed=True, strict=True)
-    g.edge_attr['dir'] = 'forward'
+    for i in range(len(path) - 1):
+        edges_instance.append((path[i], path[i + 1]))
+    edges_instance.append((path[-1], "term:" + term))
 
-    g.graph_attr['rankdir'] = 'TB'
+    g = pydot.Dot('my_graph', graph_type='graph')
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
-        node = g.get_node(n)
-        node.attr['shape'] = 'circle'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
+        g.add_node(node)
 
     # terminal nodes
     for n in dt.terms:
-        g.add_node(n, label=dt.terms[n])
-        node = g.get_node(n)
-        node.attr['shape'] = 'square'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.terms[n], shape="square")
+        g.add_node(node)
 
     # transitions
     for n1 in dt.nodes:
         for v in dt.nodes[n1].vals:
             n2 = dt.nodes[n1].vals[v]
-            n2_type = g.get_node(n2).attr['shape']
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
+            n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
+            edge = pydot.Edge(n1, n2)
             if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+                edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
             else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            
-            #instance path in dashed
-            if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
-                edge.attr['style'] = 'dashed'
-            for label in edge.attr['label'].split('\n'):
+                edge.obj_dict['attributes']['label'] = '{0}'.format(
+                    '\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+            edge.obj_dict['attributes']['fontsize'] = 10
+            edge.obj_dict['attributes']['arrowsize'] = 0.8
+
+            # instance path in dashed
+            if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
+                edge.obj_dict['attributes']['style'] = 'dashed'
+
+            for label in edge.obj_dict['attributes']['label'].split('\n'):
                 if label in expl:
-                    edge.attr['color'] = 'blue'
+                    edge.obj_dict['attributes']['color'] = 'blue'
 
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+            g.add_edge(edge)
 
-    g.add_subgraph(name='legend')
-    create_legend(g)
+    return g.to_string()
 
-    # saving file
-    g.layout(prog='dot')
-    return(g.to_string())
 
-#==============================================================================
+# ==============================================================================
 def visualize_contrastive_expl(dt, instance, cont_expl):
     """
         Visualize a DT with graphviz and plot the running instance.
     """
-    #path that follows the instance - colored in blue
+    # path that follows the instance - colored in blue
     path, term, depth = dt.execute(instance)
     edges_instance = []
-    for i in range (len(path)-1) :
-        edges_instance.append((path[i], path[i+1]))
-    edges_instance.append((path[-1],"term:"+term))
-
-    g = pygraphviz.AGraph(directed=True, strict=True)
-    g.edge_attr['dir'] = 'forward'
+    for i in range(len(path) - 1):
+        edges_instance.append((path[i], path[i + 1]))
+    edges_instance.append((path[-1], "term:" + term))
 
-    g.graph_attr['rankdir'] = 'TB'
+    g = pydot.Dot('my_graph', graph_type='graph')
 
     # non-terminal nodes
     for n in dt.nodes:
-        g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
-        node = g.get_node(n)
-        node.attr['shape'] = 'circle'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
+        g.add_node(node)
 
     # terminal nodes
     for n in dt.terms:
-        g.add_node(n, label=dt.terms[n])
-        node = g.get_node(n)
-        node.attr['shape'] = 'square'
-        node.attr['fontsize'] = 13
+        node = pydot.Node(n, label=dt.terms[n], shape="square")
+        g.add_node(node)
 
     # transitions
     for n1 in dt.nodes:
         for v in dt.nodes[n1].vals:
             n2 = dt.nodes[n1].vals[v]
-            n2_type = g.get_node(n2).attr['shape']
-            g.add_edge(n1, n2)
-            edge = g.get_edge(n1, n2)
+            n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
+            edge = pydot.Edge(n1, n2)
             if len(v) == 1:
-                edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
+                edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
             else:
-                edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
-            
-            #instance path in dashed
-            if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): 
-                edge.attr['style'] = 'dashed'
-                    
-            for label in edge.attr['label'].split('\n'):
-                if label in cont_expl:
-                    edge.attr['color'] = 'red'
+                edge.obj_dict['attributes']['label'] = '{0}'.format(
+                    '\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
+            edge.obj_dict['attributes']['fontsize'] = 10
+            edge.obj_dict['attributes']['arrowsize'] = 0.8
+
+            # instance path in dashed
+            if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
+                edge.obj_dict['attributes']['style'] = 'dashed'
 
-            edge.attr['fontsize'] = 10
-            edge.attr['arrowsize'] = 0.8
+            for label in edge.obj_dict['attributes']['label'].split('\n'):
+                if label in cont_expl:
+                    edge.obj_dict['attributes']['color'] = 'red'
 
-    g.add_subgraph(name='legend')
-    create_legend(g)
+            g.add_edge(edge)
 
-    # saving file
-    g.layout(prog='dot')
-    return(g.to_string())
+    return g.to_string()
-- 
GitLab