From 8b31c67631aa8ce02d6b2b1b54bab5549f069f33 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Fri, 4 Mar 2022 11:07:14 +0100 Subject: [PATCH] integration of only pickle file working, need to work on the model when categorical variable --- .gitignore | 6 +- .../DecisionTree/DecisionTreeComponent.py | 2 + pages/application/DecisionTree/utils/dtree.py | 278 +++++------------- pages/application/DecisionTree/utils/dtviz.py | 73 ++--- utils.py | 5 +- 5 files changed, 126 insertions(+), 238 deletions(-) diff --git a/.gitignore b/.gitignore index 4a1ff58..ac7f262 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,8 @@ __pycache__ pages/application/DecisionTree/utils/__pycache__ pages/application/DecisionTree/__pycache__ pages/application/__pycache__ -decision_tree_classifier_20170212.pkl \ No newline at end of file +decision_tree_classifier_20170212.pkl +push_command +adult.pkl +adult_data_00000.inst +iris_00000.txt \ No newline at end of file diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 6ecb24c..93716c0 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -5,6 +5,7 @@ import dash_interactive_graphviz import os.path from os import path +import numpy as np class DecisionTreeComponent(): @@ -30,6 +31,7 @@ class DecisionTreeComponent(): def update_with_explicability(self, instance, enum, xtype, solver) : instance = str(instance).strip().split(',') + instance = list(map(lambda i: tuple([i[0], np.float32(i[1])]), [i.split('=') for i in instance])) dot_source = visualize_instance(self.dt, instance) self.network = dash_interactive_graphviz.DashInteractiveGraphviz( diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index c6c6bc1..2a8b75f 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -17,6 +17,8 @@ from pysat.card import * from pysat.examples.hitman import Hitman from pysat.formula import CNF, IDPool from pysat.solvers import Solver +import sklearn +from torch import threshold try: # for Python2 from cStringIO import StringIO @@ -32,7 +34,7 @@ class Node(): Node class. """ - def __init__(self, feat='', vals=[], threshold=None): + def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None): """ Constructor. """ @@ -40,8 +42,10 @@ class Node(): self.feat = feat if threshold is not None : self.threshold = threshold + self.children_left = 0 + self.children_right = 0 else : - self.vals = vals + self.vals = {} # @@ -51,12 +55,13 @@ class DecisionTree(): Simple decision tree class. """ - def __init__(self, from_dt=None, from_pickle=None, verbose=0): + def __init__(self, from_pickle=None, verbose=0): """ Constructor. """ self.verbose = verbose + self.typ="" self.nof_nodes = 0 self.nof_terms = 0 @@ -66,26 +71,17 @@ class DecisionTree(): self.paths = {} self.feats = [] self.feids = {} - self.fdoms = {} - self.fvmap = {} - # OHE mapping - OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp']) - self.ohmap = OHEMap(dir={}, opp={}) - - if from_dt: - self.from_dt(from_dt) - elif from_pickle: + if from_pickle: + self.typ="pkl" + self.tree_ = '' self.from_pickle_file(from_pickle) - for f in self.feats: - for v in self.fdoms[f]: - self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) - #problem de feature names et problem de vals dans node def from_pickle_file(self, tree): #help(_tree.Tree) - tree_ = tree.tree_ + self.tree_ = tree.tree_ + print(sklearn.tree.export_text(tree)) try: feature_names = tree.feature_names_in_ except: @@ -93,51 +89,37 @@ class DecisionTree(): feature_names = [str(i) for i in range(tree.n_features_in_)] class_names = tree.classes_ - - self.nodes = collections.defaultdict(lambda: Node(feat='', vals={})) + self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0))) self.terms={} - self.nof_nodes = tree_.node_count - self.nof_terms = 0 + self.nof_nodes = self.tree_.node_count self.root_node = 0 + self.feats = feature_names feature_name = [ feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" - for i in tree_.feature] + for i in self.tree_.feature] - def recurse(feats, fdoms, node): - if tree_.feature[node] != _tree.TREE_UNDEFINED: + def recurse(node): + if self.tree_.feature[node] != _tree.TREE_UNDEFINED: name = feature_name[node] - val = tree_.threshold[node] + val = self.tree_.threshold[node] #faire une boucle for des vals ? self.nodes[int(node)].feat = name - self.nodes[int(node)].vals[int(np.round(val,4))] = int(tree_.children_left[node]) - - self.nodes[int(node)].feat = name - self.nodes[int(node)].vals[int(4854)] = int(tree_.children_right[node]) + self.nodes[int(node)].threshold = np.round(val, 4) + self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) + self.nodes[int(node)].children_right = int(self.tree_.children_right[node]) - feats.add(name) - fdoms[name].add(int(np.round(val,4))) - feats, fdoms = recurse(feats, fdoms, tree_.children_left[node]) - fdoms[name].add(4854) - feats, fdoms = recurse(feats, fdoms, tree_.children_right[node]) + recurse(self.tree_.children_left[node]) + recurse(self.tree_.children_right[node]) else: - self.terms[node] = class_names[np.argmax(tree_.value[node])] - - return feats, fdoms + self.terms[node] = class_names[np.argmax(self.tree_.value[node])] - self.feats, self.fdoms = recurse(set([]), collections.defaultdict(lambda: set([])), self.root_node) - - for parent in self.nodes: - conns = collections.defaultdict(lambda: set([])) - for val, child in self.nodes[parent].vals.items(): - conns[child].add(val) - self.nodes[parent].vals = {frozenset(val): child for child, val in conns.items()} + recurse(self.root_node) self.feats = sorted(self.feats) self.feids = {f: i for i, f in enumerate(self.feats)} - self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms} self.nof_terms = len(self.terms) self.nof_nodes -= len(self.terms) self.nof_feats = len(self.feats) @@ -145,70 +127,6 @@ class DecisionTree(): self.paths = collections.defaultdict(lambda: []) self.extract_paths(root=self.root_node, prefix=[]) - def from_dt(self, data): - """ - Get the tree from a file pointer. - """ - - contents = StringIO(data) - - lines = contents.readlines() - - # filtering out comment lines (those that start with '#') - lines = list(filter(lambda l: not l.startswith('#'), lines)) - - # number of nodes - self.nof_nodes = int(lines[0].strip()) - - # root node - self.root_node = int(lines[1].strip()) - - # number of terminal nodes (classes) - self.nof_terms = len(lines[3][2:].strip().split()) - - # the ordered list of terminal nodes - self.terms = {} - for i in range(self.nof_terms): - nd, _, t = lines[i + 4].strip().split() - self.terms[int(nd)] = t #int(t) - - # finally, reading the nodes - self.nodes = collections.defaultdict(lambda: Node(feat='', vals={})) - self.feats = set([]) - self.fdoms = collections.defaultdict(lambda: set([])) - for line in lines[(4 + self.nof_terms):]: - # reading the tuple - nid, fid, fval, child = line.strip().split() - - # inserting it in the nodes list - self.nodes[int(nid)].feat = fid - self.nodes[int(nid)].vals[int(fval)] = int(child) - - # updating the list of features - self.feats.add(fid) - - # updaing feature domains - self.fdoms[fid].add(int(fval)) - - # adding complex node connections into consideration - for n1 in self.nodes: - conns = collections.defaultdict(lambda: set([])) - for v, n2 in self.nodes[n1].vals.items(): - conns[n2].add(v) - self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()} - - # simplifying the features and their domains - self.feats = sorted(self.feats) - self.feids = {f: i for i, f in enumerate(self.feats)} - self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms} - - # here we assume all features are present in the tree - # if not, this value will be rewritten by self.parse_mapping() - self.nof_feats = len(self.feats) - - self.paths = collections.defaultdict(lambda: []) - self.extract_paths(root=self.root_node, prefix=[]) - def extract_paths(self, root, prefix): """ Traverse the tree and extract explicit paths. @@ -220,63 +138,17 @@ class DecisionTree(): self.paths[term].append(prefix) else: # select next node - feat, vals = self.nodes[root].feat, self.nodes[root].vals - for val in vals: - self.extract_paths(vals[val], prefix + [tuple([feat, val])]) - - def execute(self, inst, pathlits=False): - """ - Run the tree and obtain the prediction given an input instance. - """ - - root = self.root_node - depth = 0 - path = [] - - # this array is needed if we focus on the path's literals only - visited = [False for f in inst] - - while not root in self.terms: - path.append(root) - feat, vals = self.nodes[root].feat, self.nodes[root].vals - visited[self.feids[feat]] = True - tval = inst[self.feids[feat]][1] - ############### - # assert(len(vals) == 2) - next_node = root - neq = None - for vs, dest in vals.items(): - if tval in vs: - next_node = dest - break - else: - for v in vs: - if '!=' in self.fvmap[(feat, v)]: - neq = dest - break - else: - next_node = neq - # if tval not in vals: - # # go to the False branch (!=) - # for i in vals: - # if "!=" in self.fvmap[(feat,i)]: - # next_node = vals[i] - # break - # else: - # next_node = vals[tval] - - assert (next_node != root) - ############### - root = next_node - depth += 1 - - if pathlits: - # filtering out non-visited literals - for i, v in enumerate(visited): - if not v: - inst[i] = None - - return path, self.terms[root], depth + feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right + self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) + self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) + + def execute(self, inst): + inst = np.array([inst]) + path = self.tree_.decision_path(inst) + term_id_node = self.tree_.apply(inst) + term_id_node = term_id_node[0] + path = path.indices[path.indptr[0] : path.indptr[0 + 1]] + return path, term_id_node def prepare_sets(self, inst, term): """ @@ -295,21 +167,16 @@ class DecisionTree(): to_hit = [] for item in path: # if the instance disagrees with the path on this item - if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]: - fv = inst[self.feids[item[0]]] - if fv[0] in self.ohmap.opp: - to_hit.append(tuple([self.ohmap.opp[fv[0]], None])) - else: - to_hit.append(fv) - - to_hit = sorted(set(to_hit)) - sets.append(tuple(to_hit)) - - if self.verbose: - if self.verbose > 1: - print('c trav. path: {0}'.format(path)) + if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : + if "<=" in item[1] : + fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) + else : + fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) + to_hit.append(fv) - print('c set to hit: {0}'.format(to_hit)) + if len(to_hit)>0 : + to_hit = sorted(set(to_hit)) + sets.append(tuple(to_hit)) # returning the set of sets with no duplicates return list(dict.fromkeys(sets)) @@ -319,25 +186,32 @@ class DecisionTree(): Compute a given number of explanations. """ - inst = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in inst])) + inst_values = [np.float32(i[1]) for i in inst] + inst_dic = {} + for i in range(len(inst)): + inst_dic[inst[i][0]] = np.float32(inst[i][1]) inst_orig = inst[:] - path, term, depth = self.execute(inst, pathlits) - - explanation = str(inst) + "\n \n" - #print('c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in inst_orig]), term)) - #print(term) - explanation += 'c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[ inst_orig[self.feids[self.nodes[n].feat]] ] for n in path]), term) + "\n" - explanation +='c path len:'+ str(depth)+ "\n \n \n" + path, term = self.execute(inst_values) + + explanation = str(inst_dic) + "\n \n" + decision_path_str = "c inst : IF : " + for node_id in path: + # continue to the next node if it is a leaf node + if term == node_id: + continue - if self.ohmap.dir: - f2v = {fv[0]: fv[1] for fv in inst} + decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format( + feature=self.nodes[node_id].feat, + value=inst_dic[self.nodes[node_id].feat], + inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" , + threshold=self.nodes[node_id].threshold) - # updating fvmap for printing ohe features - for fo, fis in self.ohmap.dir.items(): - self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')' + decision_path_str += "THEN " + str(self.terms[term]) + explanation += decision_path_str + "\n \n" + explanation +='c path len:'+ str(len(path))+ "\n \n \n" # computing the sets to hit - to_hit = self.prepare_sets(inst, term) + to_hit = self.prepare_sets(inst_dic, term) for type in xtype : if type == "AXp": @@ -354,11 +228,14 @@ class DecisionTree(): Enumerate abductive explanations. """ explanation = "" - with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman: + with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: expls = [] for i, expl in enumerate(hitman.enumerate(), 1): - explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term) + "\n" - + explanation += 'c expl: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], + value=p[1], + inequality=p[2], + threshold=p[3]) + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n" expls.append(expl) if i == enum: break @@ -388,9 +265,10 @@ class DecisionTree(): expls = list(reduce(process_set, to_hit, [])) explanation = "" for expl in expls: - explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)+ "\n" - - + explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], + inequality="<=" if p[2]==">" else ">", + threshold=p[3]) + for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))+ "\n" explanation +='c nof expls:'+ str(len(expls))+ "\n" explanation +='c min expl:'+ str( min([len(e) for e in expls]))+ "\n" explanation +='c max expl:'+ str( max([len(e) for e in expls]))+ "\n" diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index d0abf06..11ac6fb 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -12,7 +12,8 @@ #============================================================================== from pages.application.DecisionTree.utils.dtree import DecisionTree import pygraphviz - +import numpy as np +import pandas as pd # #============================================================================== def visualize(dt): @@ -38,18 +39,22 @@ def visualize(dt): node.attr['shape'] = 'square' node.attr['fontsize'] = 13 - # transitions for n1 in dt.nodes: - for v in dt.nodes[n1].vals: - n2 = dt.nodes[n1].vals[v] - g.add_edge(n1, n2) - edge = g.get_edge(n1, n2) - if len(v) == 1: - edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] - else: - edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 # saving file g.layout(prog='dot') @@ -61,8 +66,6 @@ def visualize_instance(dt, instance): """ Visualize a DT with graphviz and plot the running instance. """ - instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance])) - g = pygraphviz.AGraph(directed=True, strict=True) g.edge_attr['dir'] = 'forward' g.graph_attr['rankdir'] = 'TB' @@ -82,30 +85,34 @@ def visualize_instance(dt, instance): node.attr['fontsize'] = 13 #path that follows the instance - colored in blue - path, term, depth = dt.execute(instance) + instance = [np.float32(i[1]) for i in instance] + path, term_id_node = dt.execute(instance) edges_instance = [] for i in range (len(path)-1) : edges_instance.append((path[i], path[i+1])) - edges_instance.append((path[-1],"term:"+term)) - - # transitions + for n1 in dt.nodes: - for v in dt.nodes[n1].vals: - n2 = dt.nodes[n1].vals[v] - n2_type = g.get_node(n2).attr['shape'] - g.add_edge(n1, n2) - edge = g.get_edge(n1, n2) - if len(v) == 1: - edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])] - else: - edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)])) - - #instance path in blue - if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance): - edge.attr['color'] = 'blue' + threshold = dt.nodes[n1].threshold + + children_left = dt.nodes[n1].children_left + g.add_edge(n1, children_left) + edge = g.get_edge(n1, children_left) + edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_left) in edges_instance): + edge.attr['color'] = 'blue' - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 + children_right = dt.nodes[n1].children_right + g.add_edge(n1, children_right) + edge = g.get_edge(n1, children_right) + edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) + edge.attr['fontsize'] = 10 + edge.attr['arrowsize'] = 0.8 + #instance path in blue + if ((n1,children_right) in edges_instance): + edge.attr['color'] = 'blue' # saving file g.layout(prog='dot') diff --git a/utils.py b/utils.py index 4b194ca..27f0b5e 100644 --- a/utils.py +++ b/utils.py @@ -10,10 +10,7 @@ def parse_contents_tree(contents, filename): content_type, content_string = contents.split(',') decoded = base64.b64decode(content_string) try: - if '.dt' in filename: - data = decoded.decode('utf-8') - typ = 'dt' - elif '.pkl' in filename: + if '.pkl' in filename: data = pickle.load(io.BytesIO(decoded)) typ = 'pkl' except Exception as e: -- GitLab