Skip to content
Snippets Groups Projects
Commit 24a0dd1f authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update for pydot

parent 22f8fdd0
No related tags found
No related merge requests found
#!/usr/bin/env python #!/usr/bin/env python
#-*- coding:utf-8 -*- # -*- coding:utf-8 -*-
## ##
## dtviz.py ## dtviz.py
## ##
...@@ -9,255 +9,234 @@ ...@@ -9,255 +9,234 @@
## ##
# #
#============================================================================== # ==============================================================================
import getopt import getopt
import pygraphviz import pydot
import ast import ast
import re import re
# #
#============================================================================== # ==============================================================================
def create_legend(g): def create_legend(G):
legend = g.subgraphs()[-1] legend = pydot.Cluster('legend', rankdir="TB")
legend.graph_attr.update(size="2,2")
legend.add_node("a", style = "invis") # non-terminal nodes
legend.add_node("b", style = "invis") for n in ["a", "b", "c", "d", "e", "f"]:
legend.add_node("c", style = "invis") node = pydot.Node(n, label=n)
legend.add_node("d", style = "invis") legend.add_node(node)
legend.add_node("e", style = "invis")
legend.add_node("f", style = "invis") edge = pydot.Edge("a", "b")
edge.obj_dict['attributes']["label"] = "instance"
legend.add_edge("a","b") edge.obj_dict['attributes']["style"] = "dashed"
edge = legend.get_edge("a","b") legend.add_edge(edge)
edge.attr["label"] = "instance"
edge.attr["style"] = "dashed" edge = pydot.Edge("e", "f")
edge.obj_dict['attributes']["label"] = "contrastive explanation"
legend.add_edge("c","d") edge.obj_dict['attributes']["color"] = "red"
edge = legend.get_edge("c","d") edge.obj_dict['attributes']["style"] = "dashed"
edge.attr["label"] = "instance with explanation" legend.add_edge(edge)
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed" edge = pydot.Edge("c", "d")
edge.obj_dict['attributes']["label"] = "instance with explanation"
legend.add_edge("e","f") edge.obj_dict['attributes']["color"] = "blue"
edge = legend.get_edge("e","f") edge.obj_dict['attributes']["style"] = "dashed"
edge.attr["label"] = "contrastive explanation" legend.add_edge(edge)
edge.attr["color"] = "red"
G.add_subgraph(legend)
# #
#============================================================================== # ==============================================================================
def visualize(dt): def visualize(dt):
""" """
Visualize a DT with graphviz. Visualize a DT with graphviz.
""" """
g = pygraphviz.AGraph(directed=True, strict=True) G = pydot.Dot('tree_total', graph_type='graph')
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g = pydot.Cluster('tree', graph_type='graph')
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=dt.terms[n]) node = pydot.Node(n, label=dt.terms[n], shape="square")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# transitions # transitions
for n1 in dt.nodes: for n1 in dt.nodes:
for v in dt.nodes[n1].vals: for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v] n2 = dt.nodes[n1].vals[v]
g.add_edge(n1, n2) edge = pydot.Edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1: if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.obj_dict['attributes']['label'] = '{0}'.format(
edge.attr['fontsize'] = 10 '\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.attr['arrowsize'] = 0.8 edge.obj_dict['attributes']['fontsize'] = 10
edge.obj_dict['attributes']['arrowsize'] = 0.8
g.add_edge(edge)
G.add_subgraph(g)
return G.to_string()
# saving file
g.layout(prog='dot')
return(g.to_string())
# #
#============================================================================== # ==============================================================================
def visualize_instance(dt, instance): def visualize_instance(dt, instance):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
#path that follows the instance - colored in blue # path that follows the instance - colored in blue
path, term, depth = dt.execute(instance) path, term, depth = dt.execute(instance)
edges_instance = [] edges_instance = []
for i in range (len(path)-1) : for i in range(len(path) - 1):
edges_instance.append((path[i], path[i+1])) edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1],"term:"+term)) edges_instance.append((path[-1], "term:" + term))
G = pydot.Dot('tree_total', graph_type='graph')
g = pygraphviz.AGraph(directed=True, strict=True) g = pydot.Cluster('tree', graph_type='graph')
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=dt.terms[n]) node = pydot.Node(n, label=dt.terms[n], shape="square")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# transitions # transitions
for n1 in dt.nodes: for n1 in dt.nodes:
for v in dt.nodes[n1].vals: for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v] n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape'] n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
edge = pydot.Edge(n1, n2)
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1: if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed edge.obj_dict['attributes']['fontsize'] = 10
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): edge.obj_dict['attributes']['arrowsize'] = 0.8
edge.attr['style'] = 'dashed'
edge.attr['fontsize'] = 10 # instance path in dashed
edge.attr['arrowsize'] = 0.8 if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
edge.obj_dict['attributes']['style'] = 'dashed'
g.add_subgraph(name='legend') g.add_edge(edge)
create_legend(g)
# saving file create_legend(G)
g.layout(prog='dot') G.add_subgraph(g)
return(g.to_string())
#============================================================================== return G.to_string()
# ==============================================================================
def visualize_expl(dt, instance, expl): def visualize_expl(dt, instance, expl):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
#path that follows the instance - colored in blue # path that follows the instance - colored in blue
path, term, depth = dt.execute(instance) path, term, depth = dt.execute(instance)
edges_instance = [] edges_instance = []
for i in range (len(path)-1) : for i in range(len(path) - 1):
edges_instance.append((path[i], path[i+1])) edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1],"term:"+term)) edges_instance.append((path[-1], "term:" + term))
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g = pydot.Dot('my_graph', graph_type='graph')
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=dt.terms[n]) node = pydot.Node(n, label=dt.terms[n], shape="square")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# transitions # transitions
for n1 in dt.nodes: for n1 in dt.nodes:
for v in dt.nodes[n1].vals: for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v] n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape'] n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
g.add_edge(n1, n2) edge = pydot.Edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1: if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed edge.obj_dict['attributes']['fontsize'] = 10
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): edge.obj_dict['attributes']['arrowsize'] = 0.8
edge.attr['style'] = 'dashed'
for label in edge.attr['label'].split('\n'): # instance path in dashed
if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
edge.obj_dict['attributes']['style'] = 'dashed'
for label in edge.obj_dict['attributes']['label'].split('\n'):
if label in expl: if label in expl:
edge.attr['color'] = 'blue' edge.obj_dict['attributes']['color'] = 'blue'
edge.attr['fontsize'] = 10 g.add_edge(edge)
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend') return g.to_string()
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.to_string())
#============================================================================== # ==============================================================================
def visualize_contrastive_expl(dt, instance, cont_expl): def visualize_contrastive_expl(dt, instance, cont_expl):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
#path that follows the instance - colored in blue # path that follows the instance - colored in blue
path, term, depth = dt.execute(instance) path, term, depth = dt.execute(instance)
edges_instance = [] edges_instance = []
for i in range (len(path)-1) : for i in range(len(path) - 1):
edges_instance.append((path[i], path[i+1])) edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1],"term:"+term)) edges_instance.append((path[-1], "term:" + term))
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g = pydot.Dot('my_graph', graph_type='graph')
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat]) node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=dt.terms[n]) node = pydot.Node(n, label=dt.terms[n], shape="square")
node = g.get_node(n) g.add_node(node)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# transitions # transitions
for n1 in dt.nodes: for n1 in dt.nodes:
for v in dt.nodes[n1].vals: for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v] n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape'] n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
g.add_edge(n1, n2) edge = pydot.Edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1: if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed edge.obj_dict['attributes']['fontsize'] = 10
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): edge.obj_dict['attributes']['arrowsize'] = 0.8
edge.attr['style'] = 'dashed'
# instance path in dashed
for label in edge.attr['label'].split('\n'): if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
if label in cont_expl: edge.obj_dict['attributes']['style'] = 'dashed'
edge.attr['color'] = 'red'
edge.attr['fontsize'] = 10 for label in edge.obj_dict['attributes']['label'].split('\n'):
edge.attr['arrowsize'] = 0.8 if label in cont_expl:
edge.obj_dict['attributes']['color'] = 'red'
g.add_subgraph(name='legend') g.add_edge(edge)
create_legend(g)
# saving file return g.to_string()
g.layout(prog='dot')
return(g.to_string())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment