Skip to content
Snippets Groups Projects
Commit 24196341 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

needs conversion between pkl and dt

parent b43c3c05
Branches
No related tags found
1 merge request!3Decision tree done
......@@ -22,7 +22,7 @@ class DecisionTreeComponent():
except:
print("You did not dump the model with the features names")
feature_names = [str(i) for i in range(tree.n_features_in_)]
self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=3, feature_names=feature_names)
self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
#need a function that takes as input UploadedDecisionTree and gives DecisionTree
#self.dt = DecisionTree(from_dt=)
......
......@@ -261,7 +261,7 @@ class DecisionTree():
#decision path
decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term)
explanation_dic["Decision path of instance : "] = decision_path_str
explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path))
explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth)
# computing the sets to hit
to_hit = self.prepare_sets(inst, term)
......
......@@ -16,6 +16,26 @@ import os
import pygraphviz
import sys
#
#==============================================================================
def create_legend(g):
legend = g.subgraphs()[-1]
legend.add_node("a", style = "invis")
legend.add_node("b", style = "invis")
legend.add_node("c", style = "invis")
legend.add_node("d", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
edge.attr["label"] = "instance"
edge.attr["style"] = "dashed"
legend.add_edge("c","d")
edge = legend.get_edge("c","d")
edge.attr["label"] = "instance with explanation"
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
#
#==============================================================================
......@@ -26,19 +46,18 @@ def visualize(dt):
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes
for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
......@@ -57,7 +76,6 @@ def visualize(dt):
edge.attr['arrowsize'] = 0.8
# saving file
g.in_edges
g.layout(prog='dot')
return(g.to_string())
......@@ -76,19 +94,18 @@ def visualize_instance(dt, instance):
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes
for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
......@@ -98,6 +115,7 @@ def visualize_instance(dt, instance):
for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2)
edge = g.get_edge(n1, n2)
if len(v) == 1:
......@@ -105,13 +123,16 @@ def visualize_instance(dt, instance):
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in blue
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['color'] = 'blue'
edge.attr['style'] = 'dashed'
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.to_string())
......@@ -121,12 +142,6 @@ def visualize_expl(dt, instance, expl):
"""
Visualize a DT with graphviz and plot the running instance.
"""
if '=' in instance[0]:
instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance]))
else:
instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)]))
#path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
......@@ -141,14 +156,14 @@ def visualize_expl(dt, instance, expl):
# non-terminal nodes
for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes
for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
g.add_node(n, label=dt.terms[n])
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
......@@ -165,13 +180,18 @@ def visualize_expl(dt, instance, expl):
else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in blue
#instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
if edge.attr['label'] in expl:
edge.attr['color'] = 'blue'
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot')
return(g.to_string())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment