diff --git a/callbacks.py b/callbacks.py
index 53779188ee9bd9b4de03049a0058529ba89d4787..327dbd32a4a69c5c859e84c8306a052c7adab4ca 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -57,8 +57,8 @@ def register_callbacks(page_home, page_course, page_application, app):
         elif ihm_id == 'ml_pretrained_model_choice':
             if value_ml_model is None :
                 raise PreventUpdate
-            tree, typ = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
-            model_application.update_pretrained_model(tree, typ)
+            tree = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
+            model_application.update_pretrained_model(tree, pretrained_model_filename)
             return pretrained_model_filename, None, model_application.component.network, None

         elif ihm_id == 'ml_instance_choice' :
diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py
index 662884f4a755556a110e97b99456adaf860b100e..2347cb40f18a599809f1985f4afcfdd1c0556dd0 100644
--- a/pages/application/DecisionTree/DecisionTreeComponent.py
+++ b/pages/application/DecisionTree/DecisionTreeComponent.py
@@ -1,10 +1,13 @@
 from os import path
+import base64

 import dash_bootstrap_components as dbc
 import dash_interactive_graphviz
 import numpy as np
 from dash import dcc, html
+from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree
 from pages.application.DecisionTree.utils.dtree import DecisionTree
+
 from pages.application.DecisionTree.utils.dtviz import (visualize,
                                                         visualize_expl,
                                                         visualize_instance)
@@ -12,9 +15,19 @@ from pages.application.DecisionTree.utils.dtviz import (visualize,

 class DecisionTreeComponent():

-    def __init__(self, tree, typ_data):
+    def __init__(self, tree, filename_tree):
+
+        try:
+            feature_names = tree.feature_names_in_
+        except AttributeError:
+            print("You did not dump the model with the feature names")
+            feature_names = [str(i) for i in range(tree.n_features_in_)]
+        self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=3, feature_names=feature_names)

-        self.dt = DecisionTree(from_pickle = tree)
+        #need a function that takes as input UploadedDecisionTree and gives DecisionTree
+        #self.dt = DecisionTree(from_dt=)
+        dt = open("pages/application/DecisionTree/meteo.dt", "r").read()
+        self.dt = DecisionTree(from_dt=dt)

         dot_source = visualize(self.dt)
         self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source,
                                 style = {"width": "60%",
diff --git a/pages/application/DecisionTree/cancer.dt b/pages/application/DecisionTree/cancer.dt
new file mode 100644
index 0000000000000000000000000000000000000000..6b7796a65779c1c226e17f8791fc6a1aadd81501
--- /dev/null
+++ b/pages/application/DecisionTree/cancer.dt
@@ -0,0 +1,35 @@
+21
+1
+I 1 2 4 6 8 11 12 14 17 19
+T 3 5 7 9 10 13 15 16 18 20 21
+3 T 0
+5 T 0
+7 T 0
+9 T 0
+10 T 1
+13 T 0
+15 T 0
+16 T 1
+18 T 1
+20 T 0
+21 T 1
+1 f5 17 2
+1 f5 16 11
+2 f4 17 3
+2 f4 16 4
+4 f2 17 5
+4 f2 16 6
+6 f2 19 7
+6 f2 18 8
+8 f0 17 9
+8 f0 16 10
+11 f4 17 12
+11 f4 16 17
+12 f1 17 13
+12 f1 16 14
+14 f7 17 15
+14 f7 16 16
+17 f5 15 18
+17 f5 14 19
+19 f2 17 20
+19 f2 16 21
diff --git a/pages/application/DecisionTree/iris.dt b/pages/application/DecisionTree/iris.dt
new file mode 100644
index 0000000000000000000000000000000000000000..77d79f3ffaaa4b715c6a8f0d3f1514a4a346d3db
--- /dev/null
+++ b/pages/application/DecisionTree/iris.dt
@@ -0,0 +1,35 @@
+21
+1
+I 1 3 5 7 9 11 13 15 17 19
+T 2 4 6 8 10 12 14 16 18 20 21
+2 T Setosa
+4 T Setosa
+6 T Setosa
+8 T Setosa
+10 T Setosa
+12 T Versicolor
+14 T Versicolor
+16 T Versicolor
+18 T Versicolor
+20 T Versicolor
+21 T Virginica
+1 f3 23 2
+1 f3 22 3
+3 f2 31 4
+3 f2 30 5
+5 f3 41 6
+5 f3 40 7
+7 f2 21 8
+7 f2 20 9
+9 f3 15 10
+9 f3 14 11
+11 f3 25 12
+11 f3 24 13
+13 f3 7 14
+13 f3 6 15
+15 f3 1 16
+15 f3 0 17
+17 f3 5 18
+17 f3 4 19
+19 f2 73 20
+19 f2 72 21
\ No newline at end of file
diff --git a/pages/application/DecisionTree/meteo.dt b/pages/application/DecisionTree/meteo.dt
new file mode 100644
index 0000000000000000000000000000000000000000..3892d46e85daf5116bbe2741413f204bed6908a7
--- /dev/null
+++ b/pages/application/DecisionTree/meteo.dt
@@ -0,0 +1,8 @@
+3
+1
+I 1
+T 2 3
+2 T -
+3 T +
+1 f0 5 2
+1 f0 4 3
diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py
index 6bd118dfd8cb6ca577338d8a8189cb00cea33a68..c002840616a5133570b6b0d9945e128f59d06d11 100644
--- a/pages/application/DecisionTree/utils/dtree.py
+++ b/pages/application/DecisionTree/utils/dtree.py
@@ -11,47 +11,32 @@
 #
 #==============================================================================
 from __future__ import print_function
-
 import collections
 from functools import reduce
-
-import sklearn
 from pysat.card import *
 from pysat.examples.hitman import Hitman
 from pysat.formula import CNF, IDPool
 from pysat.solvers import Solver
-from torch import threshold

 try: # for Python2
     from cStringIO import StringIO
 except ImportError: # for Python3
     from io import StringIO
-
-import numpy as np
-from dash import dcc, html
 from sklearn.tree import _tree
+import numpy as np

-
-
 #
-#==============================================================================
 class Node():
     """
        Node class.
     """

-    def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None):
+    def __init__(self, feat='', vals=[]):
        """
            Constructor.
        """

        self.feat = feat
-        if threshold is not None :
-            self.threshold = threshold
-            self.children_left = 0
-            self.children_right = 0
-        else :
-            self.vals = {}
-
+        self.vals = vals

 #
 #==============================================================================
@@ -60,13 +45,12 @@ class DecisionTree():
        Simple decision tree class.
     """

-    def __init__(self, from_pickle=None, verbose=0):
+    def __init__(self, from_dt=None, verbose=0):
        """
            Constructor.
        """

        self.verbose = verbose
-        self.typ=""

        self.nof_nodes = 0
        self.nof_terms = 0
@@ -76,57 +60,79 @@ class DecisionTree():
        self.paths = {}
        self.feats = []
        self.feids = {}
+        self.fdoms = {}
+        self.fvmap = {}

-        if from_pickle:
-            self.typ="pkl"
-            self.tree_ = ''
-            self.from_pickle_file(from_pickle)
-
-    #problem de feature names et problem de vals dans node
-    def from_pickle_file(self, tree):
-        #help(_tree.Tree)
-        self.tree_ = tree.tree_
-        #print(sklearn.tree.export_text(tree))
-        try:
-            feature_names = tree.feature_names_in_
-        except:
-            print("You did not dump the model with the features names")
-            feature_names = [str(i) for i in range(tree.n_features_in_)]
-
-        class_names = tree.classes_
-        self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0)))
-        self.terms={}
-        self.nof_nodes = self.tree_.node_count
-        self.root_node = 0
-        self.feats = feature_names
-
-        feature_name = [
-            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
-            for i in self.tree_.feature]
-
-        def recurse(node):
-            if self.tree_.feature[node] != _tree.TREE_UNDEFINED:
-                name = feature_name[node]
-                val = self.tree_.threshold[node]
-
-                #faire une boucle for des vals ?
- self.nodes[int(node)].feat = name - self.nodes[int(node)].threshold = np.round(val, 4) - self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) - self.nodes[int(node)].children_right = int(self.tree_.children_right[node]) - - recurse(self.tree_.children_left[node]) - recurse(self.tree_.children_right[node]) + # OHE mapping + OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp']) + self.ohmap = OHEMap(dir={}, opp={}) - else: - self.terms[node] = class_names[np.argmax(self.tree_.value[node])] - - recurse(self.root_node) + if from_dt: + self.from_dt(from_dt) + + for f in self.feats: + for v in self.fdoms[f]: + self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) + + def from_dt(self, data): + """ + Get the tree from a file pointer. + """ + + contents = StringIO(data) + + lines = contents.readlines() + + # filtering out comment lines (those that start with '#') + lines = list(filter(lambda l: not l.startswith('#'), lines)) + + # number of nodes + self.nof_nodes = int(lines[0].strip()) + + # root node + self.root_node = int(lines[1].strip()) + + # number of terminal nodes (classes) + self.nof_terms = len(lines[3][2:].strip().split()) + # the ordered list of terminal nodes + self.terms = {} + for i in range(self.nof_terms): + nd, _, t = lines[i + 4].strip().split() + self.terms[int(nd)] = t #int(t) + + # finally, reading the nodes + self.nodes = collections.defaultdict(lambda: Node(feat='', vals={})) + self.feats = set([]) + self.fdoms = collections.defaultdict(lambda: set([])) + for line in lines[(4 + self.nof_terms):]: + # reading the tuple + nid, fid, fval, child = line.strip().split() + + # inserting it in the nodes list + self.nodes[int(nid)].feat = fid + self.nodes[int(nid)].vals[int(fval)] = int(child) + + # updating the list of features + self.feats.add(fid) + + # updaing feature domains + self.fdoms[fid].add(int(fval)) + + # adding complex node connections into consideration + for n1 in self.nodes: + conns = collections.defaultdict(lambda: set([])) + for v, n2 in self.nodes[n1].vals.items(): + conns[n2].add(v) + self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()} + + # simplifying the features and their domains self.feats = sorted(self.feats) self.feids = {f: i for i, f in enumerate(self.feats)} - self.nof_terms = len(self.terms) - self.nof_nodes -= len(self.terms) + self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms} + + # here we assume all features are present in the tree + # if not, this value will be rewritten by self.parse_mapping() self.nof_feats = len(self.feats) self.paths = collections.defaultdict(lambda: []) @@ -137,23 +143,69 @@ class DecisionTree(): Traverse the tree and extract explicit paths. 
""" - if root in self.terms.keys(): + if root in self.terms: # store the path term = self.terms[root] self.paths[term].append(prefix) else: # select next node - feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right - self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) - self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) - - def execute(self, inst): - inst = np.array([inst]) - path = self.tree_.decision_path(inst) - term_id_node = self.tree_.apply(inst) - term_id_node = term_id_node[0] - path = path.indices[path.indptr[0] : path.indptr[0 + 1]] - return path, term_id_node + feat, vals = self.nodes[root].feat, self.nodes[root].vals + for val in vals: + self.extract_paths(vals[val], prefix + [tuple([feat, val])]) + + def execute(self, inst, pathlits=False): + """ + Run the tree and obtain the prediction given an input instance. + """ + + root = self.root_node + depth = 0 + path = [] + + # this array is needed if we focus on the path's literals only + visited = [False for f in inst] + + while not root in self.terms: + path.append(root) + feat, vals = self.nodes[root].feat, self.nodes[root].vals + visited[self.feids[feat]] = True + tval = inst[self.feids[feat]][1] + ############### + # assert(len(vals) == 2) + next_node = root + neq = None + for vs, dest in vals.items(): + if tval in vs: + next_node = dest + break + else: + for v in vs: + if '!=' in self.fvmap[(feat, v)]: + neq = dest + break + else: + next_node = neq + # if tval not in vals: + # # go to the False branch (!=) + # for i in vals: + # if "!=" in self.fvmap[(feat,i)]: + # next_node = vals[i] + # break + # else: + # next_node = vals[tval] + + assert (next_node != root) + ############### + root = next_node + depth += 1 + + if pathlits: + # filtering out non-visited literals + for i, v in enumerate(visited): + if not v: + inst[i] = None + + return path, self.terms[root], depth def prepare_sets(self, inst, term): """ @@ -164,7 +216,7 @@ class DecisionTree(): sets = [] for t, paths in self.paths.items(): # ignoring the right class - if term in self.terms.keys() and self.terms[term] == t: + if t == term: continue # computing the sets to hit @@ -172,16 +224,21 @@ class DecisionTree(): to_hit = [] for item in path: # if the instance disagrees with the path on this item - if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : - if "<=" in item[1] : - fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) - else : - fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) - to_hit.append(fv) + if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]: + fv = inst[self.feids[item[0]]] + if fv[0] in self.ohmap.opp: + to_hit.append(tuple([self.ohmap.opp[fv[0]], None])) + else: + to_hit.append(fv) + + to_hit = sorted(set(to_hit)) + sets.append(tuple(to_hit)) + + if self.verbose: + if self.verbose > 1: + print('c trav. path: {0}'.format(path)) - if len(to_hit)>0 : - to_hit = sorted(set(to_hit)) - sets.append(tuple(to_hit)) + print('c set to hit: {0}'.format(to_hit)) # returning the set of sets with no duplicates return list(dict.fromkeys(sets)) @@ -191,11 +248,10 @@ class DecisionTree(): Compute a given number of explanations. 
""" - inst_values = [np.float32(i[1]) for i in inst] inst_dic = {} for i in range(len(inst)): inst_dic[inst[i][0]] = np.float32(inst[i][1]) - path, term = self.execute(inst_values) + path, term, depth = self.execute(inst) #contaiins all the elements for explanation explanation_dic = {} @@ -203,24 +259,12 @@ class DecisionTree(): explanation_dic["Instance : "] = str(inst_dic) #decision path - decision_path_str = "IF : " - for node_id in path: - # continue to the next node if it is a leaf node - if term == node_id: - continue - - decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format( - feature=self.nodes[node_id].feat, - value=inst_dic[self.nodes[node_id].feat], - inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" , - threshold=self.nodes[node_id].threshold) - - decision_path_str += "THEN " + str(self.terms[term]) + decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term) explanation_dic["Decision path of instance : "] = decision_path_str explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) # computing the sets to hit - to_hit = self.prepare_sets(inst_dic, term) + to_hit = self.prepare_sets(inst, term) for type in xtype : if type == "AXp": @@ -240,12 +284,9 @@ class DecisionTree(): with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: expls = [] for i, expl in enumerate(hitman.enumerate(), 1): - list_expls.append([ p[0] + p[2] + p[3] for p in expl]) - list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], - value=p[1], - inequality=p[2], - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + list_expls.append([ str(p[0]) + "=" + str(p[1]) for p in expl]) + list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term)) + expls.append(expl) if i == enum: break @@ -277,10 +318,8 @@ class DecisionTree(): list_expls_str = [] explanation = {} for expl in expls: - list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], - inequality="<=" if p[2]==">" else ">", - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)) + explanation["List of contrastive explanation(s)"] = list_expls_str explanation["Number of contrastive explanation(s) : "]=str(len(expls)) explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index aacf7b646dd1384157116c138a125a069670c772..f5a490856ce572b28849a09a90c15ad67853b7d6 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -8,78 +8,58 @@ ## E-mail: alexey.ignatiev@monash.edu ## -import numpy as np -import pygraphviz # #============================================================================== -def create_legend(g): - legend = g.subgraphs()[-1] - legend.add_node("a", style = "invis") - legend.add_node("b", style = "invis") - legend.add_node("c", style = "invis") - 
legend.add_node("d", style = "invis") - - legend.add_edge("a","b") - edge = legend.get_edge("a","b") - edge.attr["label"] = "instance" - edge.attr["style"] = "dashed" - - legend.add_edge("c","d") - edge = legend.get_edge("c","d") - edge.attr["label"] = "instance with explanation" - edge.attr["color"] = "blue" - edge.attr["style"] = "dashed" +from pages.application.DecisionTree.utils.dtree import DecisionTree +import getopt +import os +import pygraphviz +import sys +# +#============================================================================== def visualize(dt): """ Visualize a DT with graphviz. """ - g = pygraphviz.AGraph(name='root', rankdir="TB") - g.is_directed() - g.is_strict() - - #g = pygraphviz.AGraph(name = "main", directed=True, strict=True) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' + # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - g.add_subgraph(name='legend') - create_legend(g) + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 # saving file + g.in_edges g.layout(prog='dot') - return(g.string()) + return(g.to_string()) # #============================================================================== @@ -87,120 +67,111 @@ def visualize_instance(dt, instance): """ Visualize a DT with graphviz and plot the running instance. 
""" + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' - - g.add_subgraph(name='legend') - create_legend(g) + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in blue + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['color'] = 'blue' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 # saving file g.layout(prog='dot') return(g.to_string()) -# + #============================================================================== def visualize_expl(dt, instance, expl): """ Visualize a DT with graphviz and plot the running instance. 
""" + if '=' in instance[0]: + instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance])) + + else: + instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)])) + + #path that follows the instance - colored in blue + path, term, depth = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + edges_instance.append((path[-1],"term:"+term)) + g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) + g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n)) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 # terminal nodes for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) + g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n)) node = g.get_node(n) node.attr['shape'] = 'square' node.attr['fontsize'] = 13 - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - + # transitions for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' - - g.add_subgraph(name='legend') - create_legend(g) + for v in dt.nodes[n1].vals: + n2 = dt.nodes[n1].vals[v] + n2_type = g.get_node(n2).attr['shape'] + g.add_edge(n1, n2) + edge = g.get_edge(n1, n2) + if len(v) == 1: + edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] + else: + edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) + + #instance path in blue + if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): + edge.attr['color'] = 'blue' + + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + # saving file g.layout(prog='dot') return(g.to_string()) diff --git a/pages/application/DecisionTree/utils/save/dtree save.py b/pages/application/DecisionTree/utils/save/dtree save.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd118dfd8cb6ca577338d8a8189cb00cea33a68 --- /dev/null +++ b/pages/application/DecisionTree/utils/save/dtree save.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- +## +## dtree.py +## +## Created on: Jul 6, 2020 +## Author: Alexey Ignatiev +## E-mail: alexey.ignatiev@monash.edu +## + +# +#============================================================================== +from __future__ import print_function + +import collections +from 
functools import reduce + +import sklearn +from pysat.card import * +from pysat.examples.hitman import Hitman +from pysat.formula import CNF, IDPool +from pysat.solvers import Solver +from torch import threshold + +try: # for Python2 + from cStringIO import StringIO +except ImportError: # for Python3 + from io import StringIO + +import numpy as np +from dash import dcc, html +from sklearn.tree import _tree + + +# +#============================================================================== +class Node(): + """ + Node class. + """ + + def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None): + """ + Constructor. + """ + + self.feat = feat + if threshold is not None : + self.threshold = threshold + self.children_left = 0 + self.children_right = 0 + else : + self.vals = {} + + +# +#============================================================================== +class DecisionTree(): + """ + Simple decision tree class. + """ + + def __init__(self, from_pickle=None, verbose=0): + """ + Constructor. + """ + + self.verbose = verbose + self.typ="" + + self.nof_nodes = 0 + self.nof_terms = 0 + self.root_node = None + self.terms = [] + self.nodes = {} + self.paths = {} + self.feats = [] + self.feids = {} + + if from_pickle: + self.typ="pkl" + self.tree_ = '' + self.from_pickle_file(from_pickle) + + #problem de feature names et problem de vals dans node + def from_pickle_file(self, tree): + #help(_tree.Tree) + self.tree_ = tree.tree_ + #print(sklearn.tree.export_text(tree)) + try: + feature_names = tree.feature_names_in_ + except: + print("You did not dump the model with the features names") + feature_names = [str(i) for i in range(tree.n_features_in_)] + + class_names = tree.classes_ + self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0))) + self.terms={} + self.nof_nodes = self.tree_.node_count + self.root_node = 0 + self.feats = feature_names + + feature_name = [ + feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" + for i in self.tree_.feature] + + def recurse(node): + if self.tree_.feature[node] != _tree.TREE_UNDEFINED: + name = feature_name[node] + val = self.tree_.threshold[node] + + #faire une boucle for des vals ? + self.nodes[int(node)].feat = name + self.nodes[int(node)].threshold = np.round(val, 4) + self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) + self.nodes[int(node)].children_right = int(self.tree_.children_right[node]) + + recurse(self.tree_.children_left[node]) + recurse(self.tree_.children_right[node]) + + else: + self.terms[node] = class_names[np.argmax(self.tree_.value[node])] + + recurse(self.root_node) + + self.feats = sorted(self.feats) + self.feids = {f: i for i, f in enumerate(self.feats)} + self.nof_terms = len(self.terms) + self.nof_nodes -= len(self.terms) + self.nof_feats = len(self.feats) + + self.paths = collections.defaultdict(lambda: []) + self.extract_paths(root=self.root_node, prefix=[]) + + def extract_paths(self, root, prefix): + """ + Traverse the tree and extract explicit paths. 
+ """ + + if root in self.terms.keys(): + # store the path + term = self.terms[root] + self.paths[term].append(prefix) + else: + # select next node + feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right + self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) + self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) + + def execute(self, inst): + inst = np.array([inst]) + path = self.tree_.decision_path(inst) + term_id_node = self.tree_.apply(inst) + term_id_node = term_id_node[0] + path = path.indices[path.indptr[0] : path.indptr[0 + 1]] + return path, term_id_node + + def prepare_sets(self, inst, term): + """ + Hitting set based encoding of the problem. + (currently not incremental -- should be fixed later) + """ + + sets = [] + for t, paths in self.paths.items(): + # ignoring the right class + if term in self.terms.keys() and self.terms[term] == t: + continue + + # computing the sets to hit + for path in paths: + to_hit = [] + for item in path: + # if the instance disagrees with the path on this item + if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : + if "<=" in item[1] : + fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) + else : + fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) + to_hit.append(fv) + + if len(to_hit)>0 : + to_hit = sorted(set(to_hit)) + sets.append(tuple(to_hit)) + + # returning the set of sets with no duplicates + return list(dict.fromkeys(sets)) + + def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'): + """ + Compute a given number of explanations. + """ + + inst_values = [np.float32(i[1]) for i in inst] + inst_dic = {} + for i in range(len(inst)): + inst_dic[inst[i][0]] = np.float32(inst[i][1]) + path, term = self.execute(inst_values) + + #contaiins all the elements for explanation + explanation_dic = {} + #instance plotting + explanation_dic["Instance : "] = str(inst_dic) + + #decision path + decision_path_str = "IF : " + for node_id in path: + # continue to the next node if it is a leaf node + if term == node_id: + continue + + decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format( + feature=self.nodes[node_id].feat, + value=inst_dic[self.nodes[node_id].feat], + inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" , + threshold=self.nodes[node_id].threshold) + + decision_path_str += "THEN " + str(self.terms[term]) + explanation_dic["Decision path of instance : "] = decision_path_str + explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) + + # computing the sets to hit + to_hit = self.prepare_sets(inst_dic, term) + + for type in xtype : + if type == "AXp": + explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term)) + else : + explanation_dic.update(self.enumerate_contrastive(to_hit, term)) + + return explanation_dic + + def enumerate_abductive(self, to_hit, enum, solver, htype, term): + """ + Enumerate abductive explanations. 
+ """ + list_expls = [] + list_expls_str = [] + explanation = {} + with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: + expls = [] + for i, expl in enumerate(hitman.enumerate(), 1): + list_expls.append([ p[0] + p[2] + p[3] for p in expl]) + list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], + value=p[1], + inequality=p[2], + threshold=p[3]) + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + expls.append(expl) + if i == enum: + break + explanation["List of path explanation(s)"] = list_expls + explanation["List of abductive explanation(s)"] = list_expls_str + explanation["Number of abductive explanation(s) : "] = str(i) + explanation["Minimal abductive explanation : "] = str( min([len(e) for e in expls])) + explanation["Maximal abductive explanation : "] = str( max([len(e) for e in expls])) + explanation["Average abductive explanation : "] = '{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) + + return explanation + + def enumerate_contrastive(self, to_hit, term): + """ + Enumerate contrastive explanations. + """ + + def process_set(done, target): + for s in done: + if s <= target: + break + else: + done.append(target) + return done + + to_hit = [set(s) for s in to_hit] + to_hit.sort(key=lambda s: len(s)) + expls = list(reduce(process_set, to_hit, [])) + list_expls_str = [] + explanation = {} + for expl in expls: + list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], + inequality="<=" if p[2]==">" else ">", + threshold=p[3]) + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) + explanation["List of contrastive explanation(s)"] = list_expls_str + explanation["Number of contrastive explanation(s) : "]=str(len(expls)) + explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) + explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls])) + explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) + + return explanation diff --git a/pages/application/DecisionTree/utils/save/dtviz save.py b/pages/application/DecisionTree/utils/save/dtviz save.py new file mode 100755 index 0000000000000000000000000000000000000000..aacf7b646dd1384157116c138a125a069670c772 --- /dev/null +++ b/pages/application/DecisionTree/utils/save/dtviz save.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- +## +## dtviz.py +## +## Created on: Jul 7, 2020 +## Author: Alexey Ignatiev +## E-mail: alexey.ignatiev@monash.edu +## + +import numpy as np +import pygraphviz +# +#============================================================================== +def create_legend(g): + legend = g.subgraphs()[-1] + legend.add_node("a", style = "invis") + legend.add_node("b", style = "invis") + legend.add_node("c", style = "invis") + legend.add_node("d", style = "invis") + + legend.add_edge("a","b") + edge = legend.get_edge("a","b") + edge.attr["label"] = "instance" + edge.attr["style"] = "dashed" + + legend.add_edge("c","d") + edge = legend.get_edge("c","d") + edge.attr["label"] = "instance with explanation" + edge.attr["color"] = "blue" + edge.attr["style"] = "dashed" + + +def visualize(dt): + """ + Visualize a DT with graphviz. 
+ """ + + g = pygraphviz.AGraph(name='root', rankdir="TB") + g.is_directed() + g.is_strict() + + #g = pygraphviz.AGraph(name = "main", directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=str(dt.nodes[n].feat)) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=str(dt.terms[n])) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + for n1 in dt.nodes: + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + + g.add_subgraph(name='legend') + create_legend(g) + + # saving file + g.layout(prog='dot') + return(g.string()) + +# +#============================================================================== +def visualize_instance(dt, instance): + """ + Visualize a DT with graphviz and plot the running instance. + """ + g = pygraphviz.AGraph(directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=str(dt.nodes[n].feat)) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=str(dt.terms[n])) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + #path that follows the instance - colored in blue + instance = [np.float32(i[1]) for i in instance] + path, term_id_node = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + + for n1 in dt.nodes: + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_left) in edges_instance): + edge.attr['style'] = 'dashed' + + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_right) in edges_instance): + edge.attr['style'] = 'dashed' + + g.add_subgraph(name='legend') + create_legend(g) + + # saving file + g.layout(prog='dot') + return(g.to_string()) +# +#============================================================================== +def visualize_expl(dt, instance, expl): + """ + Visualize a DT with graphviz and plot the running instance. 
+ """ + g = pygraphviz.AGraph(directed=True, strict=True) + g.edge_attr['dir'] = 'forward' + g.graph_attr['rankdir'] = 'TB' + + # non-terminal nodes + for n in dt.nodes: + g.add_node(n, label=str(dt.nodes[n].feat)) + node = g.get_node(n) + node.attr['shape'] = 'circle' + node.attr['fontsize'] = 13 + + # terminal nodes + for n in dt.terms: + g.add_node(n, label=str(dt.terms[n])) + node = g.get_node(n) + node.attr['shape'] = 'square' + node.attr['fontsize'] = 13 + + #path that follows the instance - colored in blue + instance = [np.float32(i[1]) for i in instance] + path, term_id_node = dt.execute(instance) + edges_instance = [] + for i in range (len(path)-1) : + edges_instance.append((path[i], path[i+1])) + + for n1 in dt.nodes: + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_left) in edges_instance): + edge.attr['style'] = 'dashed' + if edge.attr['label'] in expl : + edge.attr['color'] = 'blue' + + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_right) in edges_instance): + edge.attr['style'] = 'dashed' + if edge.attr['label'] in expl : + edge.attr['color'] = 'blue' + + g.add_subgraph(name='legend') + create_legend(g) + + g.layout(prog='dot') + return(g.to_string()) diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..86c9430436ddd5a1fa96b2b9d45b32dd31dea34c --- /dev/null +++ b/pages/application/DecisionTree/utils/upload_tree.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python +#-*- coding:utf-8 -*- +## +## tree.py (reuses parts of the code of SHAP) +## +## Created on: Dec 7, 2018 +## Author: Nina Narodytska +## E-mail: narodytska@vmware.com +## + +# +#============================================================================== +from anytree import Node, RenderTree,AsciiStyle +import json +import numpy as np +import math +import os + + +# +#============================================================================== +class xgnode(Node): + def __init__(self, id, parent = None): + Node.__init__(self, id, parent) + self.id = id # node value + self.name = None + self.left_node_id = -1 # left child + self.right_node_id = -1 # right child + + self.feature = -1 + self.threshold = None + self.values = -1 + #iai + self.split = None + + def __str__(self): + pref = ' ' * self.depth + if (len(self.children) == 0): + return (pref+ f"leaf:{self.id} {self.values}") + else: + if(self.name is None): + if (self.threshold is None): + return (pref+ f"({self.id}) f{self.feature}") + else: + return (pref+ f"({self.id}) f{self.feature} = {self.threshold}") + else: + if (self.threshold is None): + return (pref+ f"({self.id}) \"{self.name}\"") + else: + return (pref+ f"({self.id}) \"{self.name}\" = {self.threshold}") + +# +#============================================================================== + +def walk_tree(node): + if (len(node.children) == 0): + # leaf + print(node) + else: + print(node) + walk_tree(node.children[0]) + walk_tree(node.children[1]) + + +# 
+#============================================================================== +def scores_tree(node, sample): + if (len(node.children) == 0): + # leaf + return node.values + else: + feature_branch = node.feature + sample_value = sample[feature_branch] + assert(sample_value is not None) + if(sample_value < node.threshold): + return scores_tree(node.children[0], sample) + else: + return scores_tree(node.children[1], sample) + + + +# +#============================================================================== +def get_json_tree(model, tool, maxdepth=None, fname=None): + """ + returns the dtree in JSON format + """ + jt = None + if tool == "DL85": + jt = model.tree_ + elif tool == "IAI": + fname = os.path.splitext(os.path.basename(fname))[0] + dir_name = os.path.join("temp", f"{tool}{maxdepth}") + try: + os.stat(dir_name) + except: + os.makedirs(dir_name) + iai_json = os.path.join(dir_name, fname+'.json') + model.write_json(iai_json) + print(f'load JSON tree from {iai_json} ...') + with open(iai_json) as fp: + jt = json.load(fp) + elif tool == "ITI": + print(f'load JSON tree from {model.json_name} ...') + with open(model.json_name) as fp: + jt = json.load(fp) + #else: + # assert False, 'Unhandled model type: {0}'.format(self.tool) + + return jt + +# +#============================================================================== +class UploadedDecisionTree: + """ A decision tree. + This object provides a common interface to many different types of models. + """ + def __init__(self, model, tool, fname, maxdepth, feature_names=None, nb_classes = 0): + self.tool = tool + self.model = model + self.tree = None + self.depth = None + self.n_nodes = None + json_tree = get_json_tree(self.model, self.tool, maxdepth, fname) + self.tree, self.n_nodes, self.depth = self.build_tree(json_tree, feature_names) + + print("c #nodes:", self.n_nodes) + print("c depth:", self.depth) + + + def print_tree(self): + print("DT model:") + walk_tree(self.tree) + + + def dump(self, fvmap, filename=None, maxdepth=None, output='temp', feat_names=None): + """ + save the dtree and data map in .dt/.map file + """ + + def walk_tree(node, domains, internal, terminal): + """ + extract internal (non-term) & terminal nodes + """ + if (len(node.children) == 0): # leaf node + terminal.append((node.id, node.values)) + else: + assert (node.children[0].id == node.left_node_id) + assert (node.children[1].id == node.right_node_id) + + f = f"f{node.feature}" + + if self.tool == "DL85": + l,r = (1,0) + internal.append((node.id, f, l, node.children[0].id)) + internal.append((node.id, f, r, node.children[1].id)) + + elif self.tool == "ITI": + #l,r = (0,1) + if len(fvmap[f]) > 2: + n = 0 + for v in fvmap[f]: + if (fvmap[f][v][2] == node.threshold) and \ + (fvmap[f][v][1] == True): + l = v + n = n + 1 + if (fvmap[f][v][2] == node.threshold) and \ + (fvmap[f][v][1] == False): + r = v + n = n + 1 + + assert (n == 2) + + elif (fvmap[f][0][2] == node.threshold): + l,r = (0,1) + else: + assert (fvmap[f][1][2] == node.threshold) + l,r = (1,0) + + internal.append((node.id, f, l, node.children[0].id)) + internal.append((node.id, f, r, node.children[1].id)) + + elif self.tool == "IAI": + left, right = [], [] + for p in fvmap[f]: + if fvmap[f][p][1] == True: + assert (fvmap[f][p][2] in node.split) + if node.split[fvmap[f][p][2]]: + left.append(p) + else: + right.append(p) + + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + + elif self.tool == 'SKL': + 
left, right = [], [] + for j in domains[f]: #[(j, fvmap[f][j][2]) for j in fvmap[f] if(fvmap[f][j][1])]: + if np.float32(fvmap[f][j][2]) <= node.threshold: + left.append(j) + else: + right.append(j) + + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + + dom0, dom1 = dict(), dict() + dom0.update(domains) + dom1.update(domains) + dom0[f] = left + dom1[f] = right + + else: + assert False, 'Unhandled model type: {0}'.format(self.tool) + + + internal, terminal = walk_tree(node.children[0], dom0, internal, terminal) + internal, terminal = walk_tree(node.children[1], dom1, internal, terminal) + + return internal, terminal + + domains = {f:[j for j in fvmap[f] if((fvmap[f][j][1]))] for f in fvmap} + internal, terminal = walk_tree(self.tree, domains, [], []) + + if filename and maxdepth: + fname = os.path.splitext(os.path.basename(filename))[0] + dir_name = os.path.join(output, 'tree', fname) + dir_name = os.path.join(dir_name, f"{self.tool}{maxdepth}") + + if self.tool == 'ITI': + dir_name = os.path.join(dir_name, self.tool) + elif filename: + fname = os.path.splitext(os.path.basename(filename))[0] + dir_name = os.path.join(output, f'tree/{fname}/{self.tool}') + else: + fname = "tree" + dir_name = os.path.join(output) + + try: + os.stat(dir_name) + except: + os.makedirs(dir_name) + + fname = os.path.join(dir_name, fname+'.dt') + print("saving dtree to ", fname) + + with open(fname, 'w') as fp: + fp.write(f"{self.n_nodes}\n{self.tree.id}\n") + fp.write(f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n") + fp.write(f"T {' '.join([str(i) for i,_ in terminal ])}\n") + for i,c in terminal: + fp.write(f"{i} T {c}\n") + for i,f, j, n in internal: + fp.write(f"{i} {f} {j} {n}\n") + + if filename and maxdepth: + fname = os.path.splitext(os.path.basename(filename))[0] + dir_name = os.path.join(output, 'map', fname) + if self.tool == "ITI": + dir_name = os.path.join(dir_name, self.tool) + else: + dir_name = os.path.join(dir_name, f'{self.tool}{maxdepth}') + elif filename: + fname = os.path.splitext(os.path.basename(filename))[0] + dir_name = os.path.join(output, f'map/{fname}/{self.tool}') + else: + fname = "tree" + dir_name = os.path.join(output) + + try: + os.stat(dir_name) + except: + os.makedirs(dir_name) + + fname = os.path.join(dir_name, fname+'.map') + print("saving dtree map to ", fname) + + with open(fname, 'w') as fp: + fp.write("Categorical\n") + fp.write(f"{len(fvmap)}\n") + for f in fvmap: + for v in fvmap[f]: + if (fvmap[f][v][1] == True): + fp.write(f"{f} {v} ={fvmap[f][v][2]}\n") + if (fvmap[f][v][1] == False) and self.tool == "ITI": + fp.write(f"{f} {v} !={fvmap[f][v][2]}\n") + + if feat_names is not None: + if filename: + fname = os.path.splitext(os.path.basename(filename))[0] + fname = os.path.join(dir_name, fname+'.txt') + else: + fname = os.path.join(dir_name, 'map.txt') + + print("saving feature map to ", fname) + + with open(fname, 'w') as fp: + for i,fid in enumerate(feat_names): + f=f'f{i}' + fp.write(f'{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n') + # + print('Done') + # end dump fct + + def build_tree(self, json_tree=None, feature_names=None): + + def extract_data(json_node, idx, depth=0, root=None, feature_names=None): + """ + Incremental Tree Inducer / DL8.5 + """ + if (root is None): + node = xgnode(idx) + else: + node = xgnode(idx, parent = root) + + if "feat" in json_node: + + if self.tool == "ITI": #f0, f1, ...,fn + 
node.feature = json_node["feat"][1:] + else: + node.feature = json_node["feat"] #json DL8.5 + if (feature_names is not None): + node.name = feature_names[node.feature] + + if self.tool == "ITI": + node.threshold = json_node[json_node["feat"]] + + node.left_node_id = idx + 1 + _, idx, d1 = extract_data(json_node['left'], idx+1, depth+1, node, feature_names) + node.right_node_id = idx + 1 + _, idx, d2 = extract_data(json_node['right'], idx+1, depth+1, node, feature_names) + depth = max(d1, d2) + + elif "value" in json_node: + node.values = json_node["value"] + + return node, idx, depth + + + def extract_iai(lnr, json_tree, feature_names = None): + """ + Interpretable AI tree + """ + + json_tree = json_tree['tree_'] + nodes = [] + depth = 0 + for i, json_node in enumerate(json_tree["nodes"]): + if json_node["parent"] == -2: + node = xgnode(json_node["id"]) + else: + root = nodes[json_node["parent"] - 1] + node = xgnode(json_node["id"], parent = root) + + assert (json_node["parent"] > 0) + assert (root.id == json_node["parent"]) + + if json_node["split_type"] == "LEAF": + #node.values = target[json_node["fit"]["class"] - 1] + ##assert json_node["fit"]["probs"][node.values] == 1.0 + node.values = lnr.get_classification_label(node.id) + depth = max(depth, lnr.get_depth(node.id)) + + assert (json_node["lower_child"] == -2 and json_node["upper_child"] == -2) + + elif json_node["split_type"] == "MIXED": + #node.feature = json_node["split_mixed"]["categoric_split"]["feature"] - 1 + #node.left_node_id = json_node["lower_child"] + #node.right_node_id = json_node["upper_child"] + + node.feature = lnr.get_split_feature(node.id) + node.left_node_id = lnr.get_lower_child(node.id) + node.right_node_id = lnr.get_upper_child(node.id) + node.split = lnr.get_split_categories(node.id) + + + assert (json_node["split_mixed"]["categoric_split"]["feature"] > 0) + assert (json_node["lower_child"] > 0) + assert (json_node["upper_child"] > 0) + + else: + assert False, 'Split feature is not \"categoric_split\"' + + nodes.append(node) + + return nodes[0], json_tree["node_count"], depth + + + def extract_skl(tree_, classes_, feature_names=None): + """ + scikit-learn tree + """ + + def get_CART_tree(tree_): + n_nodes = tree_.node_count + children_left = tree_.children_left + children_right = tree_.children_right + #feature = tree_.feature + #threshold = tree_.threshold + #values = tree_.value + node_depth = np.zeros(shape=n_nodes, dtype=np.int64) + + is_leaf = np.zeros(shape=n_nodes, dtype=bool) + stack = [(0, -1)] # seed is the root node id and its parent depth + while len(stack) > 0: + node_id, parent_depth = stack.pop() + node_depth[node_id] = parent_depth + 1 + + # If we have a test node + if (children_left[node_id] != children_right[node_id]): + stack.append((children_left[node_id], parent_depth + 1)) + stack.append((children_right[node_id], parent_depth + 1)) + else: + is_leaf[node_id] = True + + return children_left, children_right, is_leaf, node_depth + + children_left, children_right, is_leaf, node_depth = get_CART_tree(tree_) + feature = tree_.feature + threshold = tree_.threshold + values = tree_.value + m = tree_.node_count + assert (m > 0), "Empty tree" + + def extract_data(idx, root = None, feature_names = None): + i = idx + assert (i < m), "Error index node" + if (root is None): + node = xgnode(i) + else: + node = xgnode(i, parent = root) + if is_leaf[i]: + node.values = classes_[np.argmax(values[i])] + else: + node.feature = feature[i] + if (feature_names): + node.name = feature_names[feature[i]] + 
node.threshold = threshold[i] + node.left_node_id = children_left[i] + node.right_node_id = children_right[i] + extract_data(node.left_node_id, node, feature_names) + extract_data(node.right_node_id, node, feature_names) + + return node + + root = extract_data(0, None, feature_names) + return root, tree_.node_count, tree_.max_depth + + + + root, node_count, maxdepth = None, None, None + + if(self.tool == 'SKL'): + root, node_count, maxdepth = extract_skl(self.model.tree_, self.model.classes_, feature_names) + + if json_tree: + if self.tool == "IAI": + root, node_count, maxdepth = extract_iai(self.model, json_tree, feature_names) + else: + root,_,maxdepth = extract_data(json_tree, 1, 0, None, feature_names) + node_count = json.dumps(json_tree).count('feat') + json.dumps(json_tree).count('value') + + return root, node_count, maxdepth + \ No newline at end of file diff --git a/pages/application/application.py b/pages/application/application.py index 16996a6aa067aa7bf5c30bd20b93793995d256c1..a27efaccd6cccf04b6237855391ef84cf1d5adec 100644 --- a/pages/application/application.py +++ b/pages/application/application.py @@ -19,7 +19,6 @@ class Model(): self.ml_model = '' self.pretrained_model = '' - self.typ_data = '' self.instance = '' @@ -34,10 +33,9 @@ class Model(): self.component_class = self.dict_components[self.ml_model] self.component_class = globals()[self.component_class] - def update_pretrained_model(self, pretrained_model_update, typ_data): + def update_pretrained_model(self, pretrained_model_update, filename_model): self.pretrained_model = pretrained_model_update - self.typ_data = typ_data - self.component = self.component_class(self.pretrained_model, self.typ_data) + self.component = self.component_class(self.pretrained_model, filename_model) def update_instance(self, instance, enum, xtype, solver="g3"): self.instance = instance diff --git a/requirements.txt b/requirements.txt index db4cd1fddde22135bee76ca884fda1e72ab6f2c2..aab2fcaf8910657e4d1cd96b58230fefca5b5dbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ scipy>=1.2.1 dash_bootstrap_components dash_interactive_graphviz python-sat[pblib,aiger] -pygraphviz \ No newline at end of file +pygraphviz +anytree \ No newline at end of file diff --git a/utils.py b/utils.py index a3afff7ac45293e228d2a82e510de33d05ec77b5..ed7d1b55c114e8f105d18f71ad01fed4e21caafa 100644 --- a/utils.py +++ b/utils.py @@ -12,14 +12,14 @@ def parse_contents_graph(contents, filename): try: if '.pkl' in filename: data = pickle.load(io.BytesIO(decoded)) - typ = 'pkl' except Exception as e: print(e) return html.Div([ 'There was an error processing this file.' ]) - return data, typ + return data + def parse_contents_instance(contents, filename): content_type, content_string = contents.split(',')
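
For reference, the .dt format parsed by the new DecisionTree.from_dt() (and written by UploadedDecisionTree.dump()) is a plain node list. The meteo.dt file added above decodes as follows; the trailing comments are editorial annotations and not part of the file (from_dt() only skips whole lines that start with '#'):

3            # total number of nodes
1            # id of the root node
I 1          # ids of the internal (decision) nodes
T 2 3        # ids of the terminal (class) nodes
2 T -        # terminal node 2 carries class "-"
3 T +        # terminal node 3 carries class "+"
1 f0 5 2     # at node 1, feature f0 with value 5 leads to node 2
1 f0 4 3     # at node 1, feature f0 with value 4 leads to node 3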
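
The "#need a function that takes as input UploadedDecisionTree and gives DecisionTree" comment left in DecisionTreeComponent.__init__ could be closed by round-tripping through the .dt format that UploadedDecisionTree.dump() already writes. Below is a minimal sketch, assuming the caller can supply the fvmap dictionary that dump() expects (deriving that value map from a scikit-learn tree with continuous thresholds is exactly the open point of the TODO); uploaded_to_dt is a hypothetical helper name, not part of this diff:

import os
import tempfile

def uploaded_to_dt(uploaded_dt, fvmap):
    """Serialize an UploadedDecisionTree to .dt text so DecisionTree(from_dt=...) can parse it."""
    out = tempfile.mkdtemp()
    # with no filename argument, dump() writes <out>/tree.dt (plus <out>/tree.map)
    uploaded_dt.dump(fvmap, output=out)
    with open(os.path.join(out, "tree.dt"), "r") as fp:
        return fp.read()

# in DecisionTreeComponent.__init__, instead of the hardcoded meteo.dt:
#   self.dt = DecisionTree(from_dt=uploaded_to_dt(self.uploaded_dt, fvmap))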
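
As a smoke test of the new pipeline, the meteo.dt file added above can be loaded and explained directly. The sketch below assumes it runs from the repository root; instances are lists of (feature, value) pairs, one per feature appearing in the tree, and the values must be among the discrete values encoded in the .dt file (here f0 only takes 4 or 5):

from pages.application.DecisionTree.utils.dtree import DecisionTree

with open("pages/application/DecisionTree/meteo.dt", "r") as fp:
    dt = DecisionTree(from_dt=fp.read())

inst = [("f0", 5)]                       # one (feature, value) pair per tree feature
path, term, depth = dt.execute(inst)     # follows f0=5 to terminal node 2, class "-"
explanation = dt.explain(inst, enum=1, xtype=["AXp"])
print(term, explanation["List of abductive explanation(s)"])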