Skip to content
Snippets Groups Projects
Commit 24196341 authored by Caroline DE POURTALES
Browse files

needs conversion between pkl and dt

parent b43c3c05
No related branches found
No related tags found
1 merge request: !3 "Decision tree done"
...@@ -22,7 +22,7 @@ class DecisionTreeComponent(): ...@@ -22,7 +22,7 @@ class DecisionTreeComponent():
except: except:
print("You did not dump the model with the features names") print("You did not dump the model with the features names")
feature_names = [str(i) for i in range(tree.n_features_in_)] feature_names = [str(i) for i in range(tree.n_features_in_)]
self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=3, feature_names=feature_names) self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_)
#need a function that takes as input UploadedDecisionTree and gives DecisionTree #need a function that takes as input UploadedDecisionTree and gives DecisionTree
#self.dt = DecisionTree(from_dt=) #self.dt = DecisionTree(from_dt=)
......
...@@ -261,7 +261,7 @@ class DecisionTree(): ...@@ -261,7 +261,7 @@ class DecisionTree():
#decision path #decision path
decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term) decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term)
explanation_dic["Decision path of instance : "] = decision_path_str explanation_dic["Decision path of instance : "] = decision_path_str
explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth)
# computing the sets to hit # computing the sets to hit
to_hit = self.prepare_sets(inst, term) to_hit = self.prepare_sets(inst, term)
......
...@@ -16,6 +16,26 @@ import os ...@@ -16,6 +16,26 @@ import os
import pygraphviz import pygraphviz
import sys import sys
#
#==============================================================================
def create_legend(g):
    """
        Populate the legend subgraph of `g` with two sample edges.

        The legend is taken to be the last entry of `g.subgraphs()`, so the
        caller is expected to have added the legend subgraph beforehand.
        Four invisible nodes ("a".."d") serve as anchors for two edges:
        a dashed one labelled "instance" and a dashed blue one labelled
        "instance with explanation". Nothing is returned; the graph is
        modified in place.
    """

    legend_graph = g.subgraphs()[-1]

    # invisible anchor nodes: only the sample edges should be drawn
    for anchor in ("a", "b", "c", "d"):
        legend_graph.add_node(anchor, style="invis")

    # (source, target, attributes) for each sample edge; attributes are
    # applied in the listed order
    sample_edges = (
        ("a", "b", (("label", "instance"),
                    ("style", "dashed"))),
        ("c", "d", (("label", "instance with explanation"),
                    ("color", "blue"),
                    ("style", "dashed"))),
    )

    for source, target, attributes in sample_edges:
        legend_graph.add_edge(source, target)
        sample = legend_graph.get_edge(source, target)
        for attr_name, attr_value in attributes:
            sample.attr[attr_name] = attr_value
# #
#============================================================================== #==============================================================================
...@@ -26,19 +46,18 @@ def visualize(dt): ...@@ -26,19 +46,18 @@ def visualize(dt):
g = pygraphviz.AGraph(directed=True, strict=True) g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward' g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) g.add_node(n, label=dt.terms[n])
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
...@@ -57,7 +76,6 @@ def visualize(dt): ...@@ -57,7 +76,6 @@ def visualize(dt):
edge.attr['arrowsize'] = 0.8 edge.attr['arrowsize'] = 0.8
# saving file # saving file
g.in_edges
g.layout(prog='dot') g.layout(prog='dot')
return(g.to_string()) return(g.to_string())
...@@ -76,19 +94,18 @@ def visualize_instance(dt, instance): ...@@ -76,19 +94,18 @@ def visualize_instance(dt, instance):
g = pygraphviz.AGraph(directed=True, strict=True) g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward' g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) g.add_node(n, label=dt.terms[n])
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
...@@ -98,6 +115,7 @@ def visualize_instance(dt, instance): ...@@ -98,6 +115,7 @@ def visualize_instance(dt, instance):
for v in dt.nodes[n1].vals: for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v] n2 = dt.nodes[n1].vals[v]
n2_type = g.get_node(n2).attr['shape'] n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, n2) g.add_edge(n1, n2)
edge = g.get_edge(n1, n2) edge = g.get_edge(n1, n2)
if len(v) == 1: if len(v) == 1:
...@@ -105,13 +123,16 @@ def visualize_instance(dt, instance): ...@@ -105,13 +123,16 @@ def visualize_instance(dt, instance):
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in blue #instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['color'] = 'blue' edge.attr['style'] = 'dashed'
edge.attr['fontsize'] = 10 edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8 edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file # saving file
g.layout(prog='dot') g.layout(prog='dot')
return(g.to_string()) return(g.to_string())
...@@ -121,12 +142,6 @@ def visualize_expl(dt, instance, expl): ...@@ -121,12 +142,6 @@ def visualize_expl(dt, instance, expl):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
if '=' in instance[0]:
instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance]))
else:
instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)]))
#path that follows the instance - colored in blue #path that follows the instance - colored in blue
path, term, depth = dt.execute(instance) path, term, depth = dt.execute(instance)
edges_instance = [] edges_instance = []
...@@ -141,14 +156,14 @@ def visualize_expl(dt, instance, expl): ...@@ -141,14 +156,14 @@ def visualize_expl(dt, instance, expl):
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) g.add_node(n, label=dt.nodes[n].feat)
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) g.add_node(n, label=dt.terms[n])
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
...@@ -165,13 +180,18 @@ def visualize_expl(dt, instance, expl): ...@@ -165,13 +180,18 @@ def visualize_expl(dt, instance, expl):
else: else:
edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
#instance path in blue #instance path in dashed
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['style'] = 'dashed'
if edge.attr['label'] in expl:
edge.attr['color'] = 'blue' edge.attr['color'] = 'blue'
edge.attr['fontsize'] = 10 edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8 edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file # saving file
g.layout(prog='dot') g.layout(prog='dot')
return(g.to_string()) return(g.to_string())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment