Skip to content
Snippets Groups Projects
dtviz.py 6.05 KiB
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtviz.py
##
##  Created on: Jul 7, 2020
##      Author: Alexey Ignatiev
##      E-mail: alexey.ignatiev@monash.edu
##

#
#==============================================================================
from pages.application.DecisionTree.utils.dtree import DecisionTree
import getopt
import os
import pygraphviz
import sys

#
#==============================================================================
def create_legend(g):
    """
        Fill the legend subgraph of ``g`` with one sample edge per edge style.

        The subgraph named ``'legend'`` is used when it exists (the callers in
        this module create it with ``g.add_subgraph(name='legend')`` right
        before calling this function); otherwise the last entry of
        ``g.subgraphs()`` is used.

        Each legend entry is an edge between two invisible nodes, labelled
        with what that edge style means in the tree drawing:
          - dashed edge:       "instance"
          - dashed blue edge:  "instance with explanation"

        :param g: a ``pygraphviz.AGraph`` that already contains the legend
            subgraph; it is modified in place and nothing is returned.
    """
    # Look the legend up by name: g.subgraphs() is not documented to return
    # subgraphs in creation order, so "the last one" is only dependable
    # while the graph has a single subgraph.
    legend = g.get_subgraph("legend")
    if legend is None:
        legend = g.subgraphs()[-1]

    # (tail, head, edge attributes) -- attributes are applied in this order
    entries = [
        ("a", "b", {"label": "instance",
                    "style": "dashed"}),
        ("c", "d", {"label": "instance with explanation",
                    "color": "blue",
                    "style": "dashed"}),
    ]

    # endpoints are invisible so that only the sample edges are drawn
    for tail, head, _ in entries:
        legend.add_node(tail, style="invis")
        legend.add_node(head, style="invis")

    for tail, head, attrs in entries:
        legend.add_edge(tail, head)
        edge = legend.get_edge(tail, head)
        for key, value in attrs.items():
            edge.attr[key] = value


#
#==============================================================================
def visualize(dt):
    """
        Render a decision tree as a graphviz drawing.

        Decision (non-terminal) nodes are drawn as circles labelled with the
        name of the feature they test; terminal nodes are drawn as squares
        labelled with their entry in ``dt.terms``. Every transition is
        labelled with the ``dt.fvmap`` text of the feature value(s) leading
        along it, one per line when several values share a transition.

        :param dt: decision tree exposing ``nodes``, ``terms``,
            ``features_names`` and ``fvmap`` as used below.
        :return: the laid-out graph in DOT format (a string); nothing is
            written to disk.
    """
    graph = pygraphviz.AGraph(directed=True, strict=True)
    graph.edge_attr['dir'] = 'forward'
    graph.graph_attr['rankdir'] = 'TB'

    # decision nodes first, then terminals (a terminal loop entry overrides
    # any earlier attributes for the same id)
    for node_id in dt.nodes:
        feature = dt.nodes[node_id].feat
        graph.add_node(node_id, label=dt.features_names[feature],
                       shape='circle', fontsize=13)
    for leaf_id in dt.terms:
        graph.add_node(leaf_id, label=dt.terms[leaf_id],
                       shape='square', fontsize=13)

    # one edge per (value set -> child) transition of every decision node
    for src in dt.nodes:
        src_node = dt.nodes[src]
        for values in src_node.vals:
            dst = src_node.vals[values]
            literals = [dt.fvmap[(src_node.feat, val)] for val in values]
            # a single value keeps its fvmap entry as-is; several values are
            # stacked one per line
            label = literals[0] if len(literals) == 1 else '\n'.join(literals)
            graph.add_edge(src, dst, label=label, fontsize=10, arrowsize=0.8)

    # compute the layout and hand back the DOT source
    graph.layout(prog='dot')
    return graph.to_string()

#
#==============================================================================
def visualize_instance(dt, instance):
    """
        Visualize a DT with graphviz and highlight the path of an instance.

        The edges followed by ``instance`` (as reported by ``dt.execute``)
        are drawn dashed, and a legend subgraph is appended to the drawing.

        :param dt: decision tree exposing ``nodes``, ``terms``,
            ``features_names``, ``fvmap`` and ``execute`` as used below.
        :param instance: input passed unchanged to ``dt.execute``.
        :return: the laid-out graph in DOT format (a string); nothing is
            written to disk.
    """
    # path of the instance through the tree (drawn dashed); the third value
    # returned by execute() is not needed here
    path, term, _ = dt.execute(instance)

    # consecutive pairs along the path; a set gives O(1) membership tests
    path_edges = set(zip(path, path[1:]))
    # execute() reports the leaf only through its label `term`, so the final
    # hop is matched as "a terminal child of the last path node whose label
    # equals `term`" (if several such children exist, all are dashed)
    last = path[-1] if path else None

    g = pygraphviz.AGraph(directed=True, strict=True)
    g.edge_attr['dir'] = 'forward'
    g.graph_attr['rankdir'] = 'TB'

    # non-terminal nodes
    for n in dt.nodes:
        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
        node = g.get_node(n)
        node.attr['shape'] = 'circle'
        node.attr['fontsize'] = 13

    # terminal nodes
    for n in dt.terms:
        g.add_node(n, label=dt.terms[n])
        node = g.get_node(n)
        node.attr['shape'] = 'square'
        node.attr['fontsize'] = 13

    # transitions
    for n1 in dt.nodes:
        feat = dt.nodes[n1].feat
        for v in dt.nodes[n1].vals:
            n2 = dt.nodes[n1].vals[v]
            g.add_edge(n1, n2)
            edge = g.get_edge(n1, n2)

            # one fvmap entry per value; several values are stacked per line
            literals = [dt.fvmap[(feat, val)] for val in v]
            edge.attr['label'] = literals[0] if len(literals) == 1 else '\n'.join(literals)

            # instance path in dashed; labels are compared directly so that
            # non-string terminal labels do not raise
            on_path = (n1, n2) in path_edges or (
                n1 == last and n2 in dt.terms and dt.terms[n2] == term)
            if on_path:
                edge.attr['style'] = 'dashed'

            edge.attr['fontsize'] = 10
            edge.attr['arrowsize'] = 0.8

    g.add_subgraph(name='legend')
    create_legend(g)

    # compute the layout and return the DOT source
    g.layout(prog='dot')
    return g.to_string()

#==============================================================================
def visualize_expl(dt, instance, expl):
    """
        Visualize a DT with graphviz, highlighting an instance and an
        explanation.

        Edges followed by ``instance`` (according to ``dt.execute``) are
        drawn dashed, every edge whose label is contained in ``expl`` is
        drawn blue, and a legend subgraph describing these styles is
        appended.

        :param dt: decision tree exposing ``nodes``, ``terms``,
            ``features_names``, ``fvmap`` and ``execute`` as used below.
        :param instance: input passed unchanged to ``dt.execute``.
        :param expl: container tested with ``label in expl`` for every edge;
            presumably a collection of ``dt.fvmap`` strings -- TODO confirm.
        :return: the laid-out graph in DOT format (a string); nothing is
            written to disk.
    """
    # path of the instance through the tree (drawn dashed); the third value
    # returned by execute() is not needed here
    path, term, _ = dt.execute(instance)

    # consecutive pairs along the path; a set gives O(1) membership tests
    path_edges = set(zip(path, path[1:]))
    # execute() reports the leaf only through its label `term`, so the final
    # hop is matched as "a terminal child of the last path node whose label
    # equals `term`" (if several such children exist, all are dashed)
    last = path[-1] if path else None

    g = pygraphviz.AGraph(directed=True, strict=True)
    g.edge_attr['dir'] = 'forward'
    g.graph_attr['rankdir'] = 'TB'

    # non-terminal nodes
    for n in dt.nodes:
        g.add_node(n, label=dt.features_names[dt.nodes[n].feat])
        node = g.get_node(n)
        node.attr['shape'] = 'circle'
        node.attr['fontsize'] = 13

    # terminal nodes
    for n in dt.terms:
        g.add_node(n, label=dt.terms[n])
        node = g.get_node(n)
        node.attr['shape'] = 'square'
        node.attr['fontsize'] = 13

    # transitions
    for n1 in dt.nodes:
        feat = dt.nodes[n1].feat
        for v in dt.nodes[n1].vals:
            n2 = dt.nodes[n1].vals[v]
            g.add_edge(n1, n2)
            edge = g.get_edge(n1, n2)

            # one fvmap entry per value; several values are stacked per line
            literals = [dt.fvmap[(feat, val)] for val in v]
            edge.attr['label'] = literals[0] if len(literals) == 1 else '\n'.join(literals)

            # instance path in dashed; labels are compared directly so that
            # non-string terminal labels do not raise
            on_path = (n1, n2) in path_edges or (
                n1 == last and n2 in dt.terms and dt.terms[n2] == term)
            if on_path:
                edge.attr['style'] = 'dashed'

            # NOTE(review): this colours ANY edge whose label is in `expl`,
            # not only edges on the instance path, and a multi-valued edge
            # label (joined with '\n') only matches if `expl` holds that
            # exact joined string -- confirm both are intended.
            if edge.attr['label'] in expl:
                edge.attr['color'] = 'blue'

            edge.attr['fontsize'] = 10
            edge.attr['arrowsize'] = 0.8

    g.add_subgraph(name='legend')
    create_legend(g)

    # compute the layout and return the DOT source
    g.layout(prog='dot')
    return g.to_string()