diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 2347cb40f18a599809f1985f4afcfdd1c0556dd0..9f681bff56465035278eb70d42de08914a5a43f4 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -22,7 +22,7 @@ class DecisionTreeComponent(): except: print("You did not dump the model with the features names") feature_names = [str(i) for i in range(tree.n_features_in_)] - self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=3, feature_names=feature_names) + self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) #need a function that takes as input UploadedDecisionTree and gives DecisionTree #self.dt = DecisionTree(from_dt=) diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index c002840616a5133570b6b0d9945e128f59d06d11..d3bd0582ed1da896f96aaa8626d482cae0418343 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -261,7 +261,7 @@ class DecisionTree(): #decision path decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term) explanation_dic["Decision path of instance : "] = decision_path_str - explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) + explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth) # computing the sets to hit to_hit = self.prepare_sets(inst, term) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index f5a490856ce572b28849a09a90c15ad67853b7d6..a54d879ea2944a2de0c58b8c4e2baa4b40eee3c6 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -16,6 +16,26 @@ import os import pygraphviz import sys +# +#============================================================================== +def create_legend(g): + legend = g.subgraphs()[-1] + legend.add_node("a", style = "invis") + legend.add_node("b", style = "invis") + legend.add_node("c", style = "invis") + legend.add_node("d", style = "invis") + + legend.add_edge("a","b") + edge = legend.get_edge("a","b") + edge.attr["label"] = "instance" + edge.attr["style"] = "dashed" + + legend.add_edge("c","d") + edge = legend.get_edge("c","d") + edge.attr["label"] = "instance with explanation" + edge.attr["color"] = "blue" + edge.attr["style"] = "dashed" + # #============================================================================== @@ -26,19 +46,18 @@ def visualize(dt): g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' - g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) + g.add_node(n, label=dt.nodes[n].feat) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 @@ -57,7 +76,6 @@ def visualize(dt): edge.attr['arrowsize'] = 0.8 # saving file - g.in_edges g.layout(prog='dot') return(g.to_string()) @@ -76,19 +94,18 @@ def visualize_instance(dt, instance): g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' - g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) + g.add_node(n, label=dt.nodes[n].feat) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 @@ -98,6 +115,7 @@ def visualize_instance(dt, instance): for v in dt.nodes[n1].vals: n2 = dt.nodes[n1].vals[v] n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) edge = g.get_edge(n1, n2) if len(v) == 1: @@ -105,13 +123,16 @@ def visualize_instance(dt, instance): else: edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) - #instance path in blue + #instance path in dashed if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): - edge.attr['color'] = 'blue' + edge.attr['style'] = 'dashed' edge.attr['fontsize'] = 10 edge.attr['arrowsize'] = 0.8 + g.add_subgraph(name='legend') + create_legend(g) + # saving file g.layout(prog='dot') return(g.to_string()) @@ -121,12 +142,6 @@ def visualize_expl(dt, instance, expl): """ Visualize a DT with graphviz and plot the running instance. """ - if '=' in instance[0]: - instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance])) - - else: - instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)])) - #path that follows the instance - colored in blue path, term, depth = dt.execute(instance) edges_instance = [] @@ -141,14 +156,14 @@ def visualize_expl(dt, instance, expl): # non-terminal nodes for n in dt.nodes: - g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) + g.add_node(n, label=dt.nodes[n].feat) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) + g.add_node(n, label=dt.terms[n]) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 @@ -165,13 +180,18 @@ def visualize_expl(dt, instance, expl): else: edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) - #instance path in blue + #instance path in dashed if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['style'] = 'dashed' + if edge.attr['label'] in expl: edge.attr['color'] = 'blue' edge.attr['fontsize'] = 10 edge.attr['arrowsize'] = 0.8 + g.add_subgraph(name='legend') + create_legend(g) + # saving file g.layout(prog='dot') return(g.to_string())