Skip to content
Snippets Groups Projects
Commit 24a0dd1f authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

update for pydot

parent 22f8fdd0
No related tags found
No related merge requests found
#!/usr/bin/env python
#-*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
##
## dtviz.py
##
......@@ -9,255 +9,234 @@
##
#
#==============================================================================
# ==============================================================================
import getopt
import pygraphviz
import pydot
import ast
import re
#
#==============================================================================
def create_legend(g):
legend = g.subgraphs()[-1]
legend.graph_attr.update(size="2,2")
legend.add_node("a", style = "invis")
legend.add_node("b", style = "invis")
legend.add_node("c", style = "invis")
legend.add_node("d", style = "invis")
legend.add_node("e", style = "invis")
legend.add_node("f", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
edge.attr["label"] = "instance"
edge.attr["style"] = "dashed"
legend.add_edge("c","d")
edge = legend.get_edge("c","d")
edge.attr["label"] = "instance with explanation"
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
legend.add_edge("e","f")
edge = legend.get_edge("e","f")
edge.attr["label"] = "contrastive explanation"
edge.attr["color"] = "red"
# ==============================================================================
def create_legend(G):
legend = pydot.Cluster('legend', rankdir="TB")
# non-terminal nodes
for n in ["a", "b", "c", "d", "e", "f"]:
node = pydot.Node(n, label=n)
legend.add_node(node)
edge = pydot.Edge("a", "b")
edge.obj_dict['attributes']["label"] = "instance"
edge.obj_dict['attributes']["style"] = "dashed"
legend.add_edge(edge)
edge = pydot.Edge("e", "f")
edge.obj_dict['attributes']["label"] = "contrastive explanation"
edge.obj_dict['attributes']["color"] = "red"
edge.obj_dict['attributes']["style"] = "dashed"
legend.add_edge(edge)
edge = pydot.Edge("c", "d")
edge.obj_dict['attributes']["label"] = "instance with explanation"
edge.obj_dict['attributes']["color"] = "blue"
edge.obj_dict['attributes']["style"] = "dashed"
legend.add_edge(edge)
G.add_subgraph(legend)
#
#==============================================================================
# ==============================================================================
def visualize(dt):
"""
Visualize a DT with graphviz.
"""
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
G = pydot.Dot('tree_total', graph_type='graph')
g = pydot.Cluster('tree', graph_type='graph')
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
g.add_node(node)
# terminal nodes
for n in dt.terms:
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.terms[n], shape="square")
g.add_node(node)
# transitions
for n1 in dt.nodes:
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
edge = pydot.Edge(n1, n2)
if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.obj_dict['attributes']['fontsize'] = 10
edge.obj_dict['attributes']['arrowsize'] = 0.8
g.add_edge(edge)
G.add_subgraph(g)
return G.to_string()
# saving file
g.layout(prog='dot')
return(g.to_string())
#
#==============================================================================
# ==============================================================================
def visualize_instance(dt, instance):
"""
Visualize a DT with graphviz and plot the running instance.
"""
#path that follows the instance - colored in blue
# path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
for i in range(len(path) - 1):
edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1], "term:" + term))
G = pydot.Dot('tree_total', graph_type='graph')
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
g = pydot.Cluster('tree', graph_type='graph')
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
g.add_node(node)
# terminal nodes
for n in dt.terms:
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.terms[n], shape="square")
g.add_node(node)
# transitions
for n1 in dt.nodes:
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
edge = pydot.Edge(n1, n2)
if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.obj_dict['attributes']['fontsize'] = 10
edge.obj_dict['attributes']['arrowsize'] = 0.8
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
# instance path in dashed
if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
edge.obj_dict['attributes']['style'] = 'dashed'
g.add_subgraph(name='legend')
create_legend(g)
g.add_edge(edge)
# saving file
g.layout(prog='dot')
return(g.to_string())
create_legend(G)
G.add_subgraph(g)
#==============================================================================
return G.to_string()
# ==============================================================================
def visualize_expl(dt, instance, expl):
"""
Visualize a DT with graphviz and plot the running instance.
"""
#path that follows the instance - colored in blue
# path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
for i in range(len(path) - 1):
edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1], "term:" + term))
g.graph_attr['rankdir'] = 'TB'
g = pydot.Dot('my_graph', graph_type='graph')
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
g.add_node(node)
# terminal nodes
for n in dt.terms:
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.terms[n], shape="square")
g.add_node(node)
# transitions
for n1 in dt.nodes:
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
edge = pydot.Edge(n1, n2)
if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
for label in edge.attr['label'].split('\n'):
edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.obj_dict['attributes']['fontsize'] = 10
edge.obj_dict['attributes']['arrowsize'] = 0.8
# instance path in dashed
if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
edge.obj_dict['attributes']['style'] = 'dashed'
for label in edge.obj_dict['attributes']['label'].split('\n'):
if label in expl:
edge.attr['color'] = 'blue'
edge.obj_dict['attributes']['color'] = 'blue'
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_edge(edge)
g.add_subgraph(name='legend')
create_legend(g)
return g.to_string()
# saving file
g.layout(prog='dot')
return(g.to_string())
#==============================================================================
# ==============================================================================
def visualize_contrastive_expl(dt, instance, cont_expl):
"""
Visualize a DT with graphviz and plot the running instance.
"""
#path that follows the instance - colored in blue
# path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
for i in range(len(path) - 1):
edges_instance.append((path[i], path[i + 1]))
edges_instance.append((path[-1], "term:" + term))
g.graph_attr['rankdir'] = 'TB'
g = pydot.Dot('my_graph', graph_type='graph')
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.feature_names[dt.nodes[n].feat], shape="circle")
g.add_node(node)
# terminal nodes
for n in dt.terms:
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
node = pydot.Node(n, label=dt.terms[n], shape="square")
g.add_node(node)
# transitions
for n1 in dt.nodes:
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
n2_type = g.get_node(str(n2))[0].obj_dict['attributes']['shape']
edge = pydot.Edge(n1, n2)
if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.obj_dict['attributes']['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
for label in edge.attr['label'].split('\n'):
if label in cont_expl:
edge.attr['color'] = 'red'
edge.obj_dict['attributes']['label'] = '{0}'.format(
'\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.obj_dict['attributes']['fontsize'] = 10
edge.obj_dict['attributes']['arrowsize'] = 0.8
# instance path in dashed
if ((n1, n2) in edges_instance) or (n2_type == 'square' and (n1, "term:" + dt.terms[n2]) in edges_instance):
edge.obj_dict['attributes']['style'] = 'dashed'
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
for label in edge.obj_dict['attributes']['label'].split('\n'):
if label in cont_expl:
edge.obj_dict['attributes']['color'] = 'red'
g.add_subgraph(name='legend')
create_legend(g)
g.add_edge(edge)
# saving file
g.layout(prog='dot')
return(g.to_string())
return g.to_string()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment