Skip to content
Snippets Groups Projects

end Decision tree

Merged Caroline de Pourtalès requested to merge decision-tree-type-file into main
11 files
+ 1687
173
Compare changes
  • Side-by-side
  • Inline
Files
11
@@ -12,7 +12,8 @@
#==============================================================================
import getopt
import pygraphviz
import ast
import re
#
#==============================================================================
def create_legend(g):
@@ -22,6 +23,8 @@ def create_legend(g):
legend.add_node("b", style = "invis")
legend.add_node("c", style = "invis")
legend.add_node("d", style = "invis")
legend.add_node("e", style = "invis")
legend.add_node("f", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
@@ -34,7 +37,10 @@ def create_legend(g):
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
legend.add_edge("e","f")
edge = legend.get_edge("e","f")
edge.attr["label"] = "contrastive explanation"
edge.attr["color"] = "red"
#
#==============================================================================
def visualize(dt):
@@ -194,3 +200,64 @@ def visualize_expl(dt, instance, expl):
# saving file
g.layout(prog='dot')
return(g.to_string())
#==============================================================================
def visualize_contrastive_expl(dt, instance, cont_expl):
"""
Visualize a DT with graphviz and plot the running instance.
"""
#path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label=dt.feature_names[dt.nodes[n].feat])
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes
for n in dt.terms:
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# transitions
for n1 in dt.nodes:
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1:
edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
for label in edge.attr['label'].split('\n'):
if label in cont_expl:
edge.attr['color'] = 'red'
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.to_string())
Loading