-
Caroline DE POURTALES authoredCaroline DE POURTALES authored
dtviz.py 6.05 KiB
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtviz.py
##
## Created on: Jul 7, 2020
## Author: Alexey Ignatiev
## E-mail: alexey.ignatiev@monash.edu
##
#
#==============================================================================
from pages.application.DecisionTree.utils.dtree import DecisionTree
import getopt
import os
import pygraphviz
import sys
#
#==============================================================================
def create_legend(g):
    """
    Fill the legend subgraph of ``g`` with two sample edges.

    The subgraph named ``"legend"`` is used when it exists.  Otherwise the
    last subgraph of ``g`` is used (the historical behaviour), and when ``g``
    has no subgraph at all a new one named ``"legend"`` is created instead of
    failing with an ``IndexError``.

    Four invisible nodes ``a``-``d`` are added, joined by:
      * ``a -> b``: dashed, labelled "instance";
      * ``c -> d``: dashed and blue, labelled "instance with explanation".

    Parameters:
        g: graph object providing ``get_subgraph()``, ``subgraphs()`` and
           ``add_subgraph()`` (a ``pygraphviz.AGraph`` in this module).

    Returns:
        None; ``g`` is modified in place.
    """
    # look the legend up by name rather than relying on subgraph ordering
    legend = g.get_subgraph("legend")
    if legend is None:
        existing = g.subgraphs()
        legend = existing[-1] if existing else g.add_subgraph(name="legend")

    # invisible anchors: only the sample edges between them get drawn
    for name in ("a", "b", "c", "d"):
        legend.add_node(name, style="invis")

    samples = (
        ("a", "b", (("label", "instance"),
                    ("style", "dashed"))),
        ("c", "d", (("label", "instance with explanation"),
                    ("color", "blue"),
                    ("style", "dashed"))),
    )
    for tail, head, attrs in samples:
        legend.add_edge(tail, head)
        edge = legend.get_edge(tail, head)
        for key, value in attrs:
            edge.attr[key] = value
#
#==============================================================================
def visualize(dt):
    """
    Build a Graphviz rendering of a decision tree.

    Every key of ``dt.nodes`` becomes a circle labelled with the name of the
    feature it tests, every key of ``dt.terms`` becomes a square labelled
    with its ``dt.terms`` entry, and each branch becomes an edge labelled
    with the ``dt.fvmap`` text of the value(s) it covers (one per line).

    Returns the graph, laid out with ``dot``, as the string produced by
    ``AGraph.to_string()``.
    """
    graph = pygraphviz.AGraph(directed=True, strict=True)
    graph.edge_attr['dir'] = 'forward'
    graph.graph_attr['rankdir'] = 'TB'

    # decision nodes: circles carrying the tested feature's name
    for node_id in dt.nodes:
        feature_name = dt.features_names[dt.nodes[node_id].feat]
        graph.add_node(node_id, label=feature_name, shape='circle', fontsize=13)

    # terminal nodes: squares carrying their dt.terms entry
    for leaf_id in dt.terms:
        graph.add_node(leaf_id, label=dt.terms[leaf_id], shape='square', fontsize=13)

    # branches: one edge per value set leaving a decision node
    for src in dt.nodes:
        feat = dt.nodes[src].feat
        for values in dt.nodes[src].vals:
            dst = dt.nodes[src].vals[values]
            if len(values) == 1:
                text = dt.fvmap[(feat, next(iter(values)))]
            else:
                # several values share the branch: stack their texts
                text = '\n'.join(dt.fvmap[(feat, val)] for val in tuple(values))
            graph.add_edge(src, dst, label=text, fontsize=10, arrowsize=0.8)

    # compute node positions, then hand the graph back as text
    graph.layout(prog='dot')
    return graph.to_string()
#
#==============================================================================
def visualize_instance(dt, instance):
    """
    Visualize a DT with graphviz and mark the path followed by an instance.

    Decision nodes are drawn as circles labelled with their feature name,
    terminal nodes as squares labelled with their ``dt.terms`` entry, and
    branches as edges labelled from ``dt.fvmap``.  Edges traversed by
    ``dt.execute(instance)`` are drawn dashed, and a legend subgraph is
    appended to the graph.

    Parameters:
        dt: decision tree object exposing ``nodes``, ``terms``,
            ``features_names``, ``fvmap`` and ``execute()``.
        instance: passed unchanged to ``dt.execute()``.

    Returns:
        The graph, laid out with ``dot``, as returned by
        ``AGraph.to_string()``.
    """
    # nodes visited by the instance and the terminal label it ends on
    path, term, _ = dt.execute(instance)

    # Edges on the instance path (drawn dashed below).  The last hop is keyed
    # by the terminal label ("term:<label>") rather than by a terminal node id.
    edges_instance = set(zip(path, path[1:]))
    if path:
        # an empty path has no last hop; indexing path[-1] would raise
        edges_instance.add((path[-1], "term:" + term))

    g = pygraphviz.AGraph(directed=True, strict=True)
    g.edge_attr['dir'] = 'forward'
    g.graph_attr['rankdir'] = 'TB'

    # non-terminal nodes
    for n in dt.nodes:
        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
        node = g.get_node(n)
        node.attr['shape'] = 'circle'
        node.attr['fontsize'] = 13

    # terminal nodes
    for n in dt.terms:
        g.add_node(n, label=dt.terms[n])
        node = g.get_node(n)
        node.attr['shape'] = 'square'
        node.attr['fontsize'] = 13

    # transitions
    for n1 in dt.nodes:
        feat = dt.nodes[n1].feat
        for v in dt.nodes[n1].vals:
            n2 = dt.nodes[n1].vals[v]
            # 'square' identifies a terminal node (set in the loop above)
            n2_type = g.get_node(n2).attr['shape']

            g.add_edge(n1, n2)
            edge = g.get_edge(n1, n2)
            if len(v) == 1:
                edge.attr['label'] = dt.fvmap[(feat, tuple(v)[0])]
            else:
                # several values share the branch: one label line per value
                edge.attr['label'] = '\n'.join(
                    [dt.fvmap[(feat, val)] for val in tuple(v)])

            # instance path in dashed
            # NOTE(review): the terminal hop is matched by label, so sibling
            # terminals with the same label under the last path node would
            # all be marked.
            on_path = (n1, n2) in edges_instance or (
                n2_type == 'square'
                and (n1, "term:" + dt.terms[n2]) in edges_instance)
            if on_path:
                edge.attr['style'] = 'dashed'

            edge.attr['fontsize'] = 10
            edge.attr['arrowsize'] = 0.8

    g.add_subgraph(name='legend')
    create_legend(g)

    # compute the layout and return the graph as text
    g.layout(prog='dot')
    return g.to_string()
#==============================================================================
def visualize_expl(dt, instance, expl):
    """
    Visualize a DT with graphviz, marking an instance path and explanation.

    Decision nodes are drawn as circles labelled with their feature name,
    terminal nodes as squares labelled with their ``dt.terms`` entry, and
    branches as edges labelled from ``dt.fvmap``.  Edges traversed by
    ``dt.execute(instance)`` are drawn dashed; those of them whose label is
    found in ``expl`` are additionally coloured blue.  A legend subgraph is
    appended to the graph.

    Parameters:
        dt: decision tree object exposing ``nodes``, ``terms``,
            ``features_names``, ``fvmap`` and ``execute()``.
        instance: passed unchanged to ``dt.execute()``.
        expl: container tested with ``label in expl`` for each path edge
            (its exact type is not visible here -- TODO confirm with caller).

    Returns:
        The graph, laid out with ``dot``, as returned by
        ``AGraph.to_string()``.
    """
    # nodes visited by the instance and the terminal label it ends on
    path, term, _ = dt.execute(instance)

    # Edges on the instance path (drawn dashed below).  The last hop is keyed
    # by the terminal label ("term:<label>") rather than by a terminal node id.
    edges_instance = set(zip(path, path[1:]))
    if path:
        # an empty path has no last hop; indexing path[-1] would raise
        edges_instance.add((path[-1], "term:" + term))

    g = pygraphviz.AGraph(directed=True, strict=True)
    g.edge_attr['dir'] = 'forward'
    g.graph_attr['rankdir'] = 'TB'

    # non-terminal nodes
    for n in dt.nodes:
        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
        node = g.get_node(n)
        node.attr['shape'] = 'circle'
        node.attr['fontsize'] = 13

    # terminal nodes
    for n in dt.terms:
        g.add_node(n, label=dt.terms[n])
        node = g.get_node(n)
        node.attr['shape'] = 'square'
        node.attr['fontsize'] = 13

    # transitions
    for n1 in dt.nodes:
        feat = dt.nodes[n1].feat
        for v in dt.nodes[n1].vals:
            n2 = dt.nodes[n1].vals[v]
            # 'square' identifies a terminal node (set in the loop above)
            n2_type = g.get_node(n2).attr['shape']

            g.add_edge(n1, n2)
            edge = g.get_edge(n1, n2)
            if len(v) == 1:
                edge.attr['label'] = dt.fvmap[(feat, tuple(v)[0])]
            else:
                # several values share the branch: one label line per value
                edge.attr['label'] = '\n'.join(
                    [dt.fvmap[(feat, val)] for val in tuple(v)])

            # instance path in dashed
            # NOTE(review): the terminal hop is matched by label, so sibling
            # terminals with the same label under the last path node would
            # all be marked.
            on_path = (n1, n2) in edges_instance or (
                n2_type == 'square'
                and (n1, "term:" + dt.terms[n2]) in edges_instance)
            if on_path:
                edge.attr['style'] = 'dashed'
                # Only path edges are eligible for the explanation colour,
                # matching the legend entry (dashed + blue).
                # NOTE(review): a multi-value label is compared as one
                # newline-joined string -- confirm that suits expl's format.
                if edge.attr['label'] in expl:
                    edge.attr['color'] = 'blue'

            edge.attr['fontsize'] = 10
            edge.attr['arrowsize'] = 0.8

    g.add_subgraph(name='legend')
    create_legend(g)

    # compute the layout and return the graph as text
    g.layout(prog='dot')
    return g.to_string()