diff --git a/.gitignore b/.gitignore index ed1537f039bd0b1e28e3d3fd73f3cae685e3e067..7c8041bcc2451773e574b119d7a2584292980f25 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ push_command adult.pkl adult_data_00000.inst iris_00000.txt -tests \ No newline at end of file +tests +create_pkl.py \ No newline at end of file diff --git a/callbacks.py b/callbacks.py index 327dbd32a4a69c5c859e84c8306a052c7adab4ca..95101dfe34dd17385bfdcf85504df81b948239eb 100644 --- a/callbacks.py +++ b/callbacks.py @@ -55,37 +55,37 @@ def register_callbacks(page_home, page_course, page_application, app): return None, None, None, None elif ihm_id == 'ml_pretrained_model_choice': - if value_ml_model is None : + if model_application.ml_model is None : raise PreventUpdate tree = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) model_application.update_pretrained_model(tree, pretrained_model_filename) return pretrained_model_filename, None, model_application.component.network, None elif ihm_id == 'ml_instance_choice' : - if value_ml_model is None or pretrained_model_contents is None or enum is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None : raise PreventUpdate instance = parse_contents_instance(instance_contents, instance_filename) model_application.update_instance(instance, enum, xtype) return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'number_explanations' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) + instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype) return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'explanation_type' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None : + if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) + instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype) return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation elif ihm_id == 'solver_sat' : - if value_ml_model is None or pretrained_model_contents is None or instance_contents is None or enum is None or xtype is None: + if model_application.ml_model is None or model_application.pretrained_model is None or model_application.instance is None: raise PreventUpdate - instance = parse_contents_instance(instance_contents, instance_filename) + instance = parse_contents_instance(model_application.instance, instance_filename) model_application.update_instance(instance, enum, xtype, solver=solver) return pretrained_model_filename, instance_filename, model_application.component.network, model_application.component.explanation diff --git a/pages/application/DecisionTree/DecisionTreeComponent.py b/pages/application/DecisionTree/DecisionTreeComponent.py index 2612aef11c971b2cf4876134904064d2e97bafd1..8c795d7cc9ad61dbd7d3bbc50806183ec91a5764 100644 --- a/pages/application/DecisionTree/DecisionTreeComponent.py +++ b/pages/application/DecisionTree/DecisionTreeComponent.py @@ -12,7 +12,6 @@ from pages.application.DecisionTree.utils.dtviz import (visualize, visualize_expl, visualize_instance) - class DecisionTreeComponent(): def __init__(self, tree, filename_tree): @@ -21,11 +20,12 @@ class DecisionTreeComponent(): feature_names = tree.feature_names_in_ except: print("You did not dump the model with the features names") - feature_names = [str(i) for i in range(tree.n_features_in_)] + feature_names = [f'f{i}' for i in range(tree.n_features_in_)] self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=tree.get_depth(), feature_names=feature_names, nb_classes=tree.n_classes_) #need a function that takes as input UploadedDecisionTree and gives DecisionTree - self.dt = DecisionTree(from_dt=self.uploaded_dt.dump()) + self.dt_format, self.map, features_names_mapping = self.uploaded_dt.convert_dt(feat_names=feature_names) + self.dt = DecisionTree(from_dt=self.dt_format, mapfile = self.map, mapping_features=features_names_mapping) dot_source = visualize(self.dt) self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", diff --git a/pages/application/DecisionTree/cancer.dt b/pages/application/DecisionTree/cancer.dt deleted file mode 100644 index 6b7796a65779c1c226e17f8791fc6a1aadd81501..0000000000000000000000000000000000000000 --- a/pages/application/DecisionTree/cancer.dt +++ /dev/null @@ -1,35 +0,0 @@ -21 -1 -I 1 2 4 6 8 11 12 14 17 19 -T 3 5 7 9 10 13 15 16 18 20 21 -3 T 0 -5 T 0 -7 T 0 -9 T 0 -10 T 1 -13 T 0 -15 T 0 -16 T 1 -18 T 1 -20 T 0 -21 T 1 -1 f5 17 2 -1 f5 16 11 -2 f4 17 3 -2 f4 16 4 -4 f2 17 5 -4 f2 16 6 -6 f2 19 7 -6 f2 18 8 -8 f0 17 9 -8 f0 16 10 -11 f4 17 12 -11 f4 16 17 -12 f1 17 13 -12 f1 16 14 -14 f7 17 15 -14 f7 16 16 -17 f5 15 18 -17 f5 14 19 -19 f2 17 20 -19 f2 16 21 diff --git a/pages/application/DecisionTree/iris.dt b/pages/application/DecisionTree/iris.dt deleted file mode 100644 index 77d79f3ffaaa4b715c6a8f0d3f1514a4a346d3db..0000000000000000000000000000000000000000 --- a/pages/application/DecisionTree/iris.dt +++ /dev/null @@ -1,35 +0,0 @@ -21 -1 -I 1 3 5 7 9 11 13 15 17 19 -T 2 4 6 8 10 12 14 16 18 20 21 -2 T Setosa -4 T Setosa -6 T Setosa -8 T Setosa -10 T Setosa -12 T Versicolor -14 T Versicolor -16 T Versicolor -18 T Versicolor -20 T Versicolor -21 T Virginica -1 f3 23 2 -1 f3 22 3 -3 f2 31 4 -3 f2 30 5 -5 f3 41 6 -5 f3 40 7 -7 f2 21 8 -7 f2 20 9 -9 f3 15 10 -9 f3 14 11 -11 f3 25 12 -11 f3 24 13 -13 f3 7 14 -13 f3 6 15 -15 f3 1 16 -15 f3 0 17 -17 f3 5 18 -17 f3 4 19 -19 f2 73 20 -19 f2 72 21 \ No newline at end of file diff --git a/pages/application/DecisionTree/meteo.dt b/pages/application/DecisionTree/meteo.dt deleted file mode 100644 index 3892d46e85daf5116bbe2741413f204bed6908a7..0000000000000000000000000000000000000000 --- a/pages/application/DecisionTree/meteo.dt +++ /dev/null @@ -1,8 +0,0 @@ -3 -1 -I 1 -T 2 3 -2 T - -3 T + -1 f0 5 2 -1 f0 4 3 diff --git a/pages/application/DecisionTree/utils/dtree.py b/pages/application/DecisionTree/utils/dtree.py index 5d52a233c4300ecb230b3e896b0c53893a2d8d1c..837decf059df769b92f14e3bf7612b759064714a 100644 --- a/pages/application/DecisionTree/utils/dtree.py +++ b/pages/application/DecisionTree/utils/dtree.py @@ -45,7 +45,7 @@ class DecisionTree(): Simple decision tree class. """ - def __init__(self, from_dt=None, verbose=0): + def __init__(self, from_dt=None, mapfile=None, mapping_features=None, verbose=0): """ Constructor. """ @@ -63,6 +63,8 @@ class DecisionTree(): self.fdoms = {} self.fvmap = {} + self.features_names = mapping_features + # OHE mapping OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp']) self.ohmap = OHEMap(dir={}, opp={}) @@ -70,9 +72,13 @@ class DecisionTree(): if from_dt: self.from_dt(from_dt) - for f in self.feats: - for v in self.fdoms[f]: - self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) + if mapfile: + self.parse_mapping(mapfile) + else: # no mapping is given + for f in self.feats: + for v in self.fdoms[f]: + self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v) + def from_dt(self, data): """ @@ -138,6 +144,84 @@ class DecisionTree(): self.paths = collections.defaultdict(lambda: []) self.extract_paths(root=self.root_node, prefix=[]) + def parse_mapping(self, mapfile): + """ + Parse feature-value mapping from a file. + """ + + self.fvmap = {} + + lines = mapfile.split('\n') + + if lines[0].startswith('OHE'): + for i in range(int(lines[1])): + feats = lines[i + 2].strip().split(',') + orig, ohe = feats[0], tuple(feats[1:]) + self.ohmap.dir[orig] = tuple(ohe) + for f in ohe: + self.ohmap.opp[f] = orig + + lines = lines[(int(lines[1]) + 2):] + + elif lines[0].startswith('Categorical'): + # skipping the first comment line if necessary + lines = lines[1:] + + elif lines[0].startswith('Ordinal'): + # skipping the first comment line if necessary + lines = lines[1:] + + for line in lines[1:]: + feat, val, real = line.split() + self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(self.features_names[feat], real) + #if feat not in self.feids: + # self.feids[feat] = len(self.feids) + + #assert len(self.feids) == self.nof_feats + + def convert_to_multiedges(self): + """ + Convert ITI trees with '!=' edges to multi-edges. + """ + + # new feature domains + fdoms = collections.defaultdict(lambda: []) + + # tentative mapping relating negative and positive values + nemap = collections.defaultdict(lambda: collections.defaultdict(lambda: [None, None])) + + for fv, tval in self.fvmap.items(): + if '!=' in tval: + nemap[fv[0]][tval.split('=')[1]][0] = fv[1] + else: + fdoms[fv[0]].append(fv[1]) + nemap[fv[0]][tval.split('=')[1]][1] = fv[1] + + # a mapping from negative values to sets + fnmap = collections.defaultdict(lambda: {}) + for f in nemap: + for t, vals in nemap[f].items(): + if vals[0] != None: + fnmap[(f, frozenset({vals[0]}))] = frozenset(set(fdoms[f]).difference({vals[1]})) + + # updating node connections + for n in self.nodes: + vals = {} + for v in self.nodes[n].vals.keys(): + fn = (self.nodes[n].feat, v) + if fn in fnmap: + vals[fnmap[fn]] = self.nodes[n].vals[v] + else: + vals[v] = self.nodes[n].vals[v] + self.nodes[n].vals = vals + + # updating the domains + self.fdoms = fdoms + + # extracting the paths again + self.paths = collections.defaultdict(lambda: []) + self.extract_paths(root=self.root_node, prefix=[]) + def extract_paths(self, root, prefix): """ Traverse the tree and extract explicit paths. @@ -223,9 +307,9 @@ class DecisionTree(): for path in paths: to_hit = [] for item in path: + fv = inst[self.feids[item[0]]] # if the instance disagrees with the path on this item - if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]: - fv = inst[self.feids[item[0]]] + if fv and not fv[1] in item[1]: if fv[0] in self.ohmap.opp: to_hit.append(tuple([self.ohmap.opp[fv[0]], None])) else: @@ -251,7 +335,7 @@ class DecisionTree(): inst_dic = {} for i in range(len(inst)): inst_dic[inst[i][0]] = np.float32(inst[i][1]) - path, term, depth = self.execute(inst) + path, term, depth = self.execute(inst, pathlits) #contaiins all the elements for explanation explanation_dic = {} @@ -259,10 +343,18 @@ class DecisionTree(): explanation_dic["Instance : "] = str(inst_dic) #decision path - decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term) + decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([str(inst[self.feids[self.nodes[n].feat]]) for n in path]), term) explanation_dic["Decision path of instance : "] = decision_path_str explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth) + + if self.ohmap.dir: + f2v = {fv[0]: fv[1] for fv in inst} + + # updating fvmap for printing ohe features + for fo, fis in self.ohmap.dir.items(): + self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')' + # computing the sets to hit to_hit = self.prepare_sets(inst, term) for type in xtype : @@ -284,7 +376,7 @@ class DecisionTree(): expls = [] for i, expl in enumerate(hitman.enumerate(), 1): list_expls.append([ str(p[0]) + "=" + str(p[1]) for p in expl]) - list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term)) + list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([str(p) for p in sorted(expl, key=lambda p: p[0])]), term)) expls.append(expl) if i == enum: @@ -317,7 +409,7 @@ class DecisionTree(): list_expls_str = [] explanation = {} for expl in expls: - list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)) + list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(str(p)) for p in sorted(expl, key=lambda p: p[0])]), term)) explanation["List of contrastive explanation(s)"] = list_expls_str explanation["Number of contrastive explanation(s) : "]=str(len(expls)) diff --git a/pages/application/DecisionTree/utils/dtviz.py b/pages/application/DecisionTree/utils/dtviz.py index a54d879ea2944a2de0c58b8c4e2baa4b40eee3c6..681ccca8ad583df7b49358d7a7d1ccaed508ae97 100755 --- a/pages/application/DecisionTree/utils/dtviz.py +++ b/pages/application/DecisionTree/utils/dtviz.py @@ -50,7 +50,7 @@ def visualize(dt): # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=dt.nodes[n].feat) + g.add_node(n, label=dt.features_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 @@ -98,7 +98,7 @@ def visualize_instance(dt, instance): # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=dt.nodes[n].feat) + g.add_node(n, label=dt.features_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 @@ -156,7 +156,7 @@ def visualize_expl(dt, instance, expl): # non-terminal nodes for n in dt.nodes: - g.add_node(n, label=dt.nodes[n].feat) + g.add_node(n, label=dt.features_names[dt.nodes[n].feat]) node = g.get_node(n) node.attr['shape'] = 'circle' node.attr['fontsize'] = 13 diff --git a/pages/application/DecisionTree/utils/save/dtree save.py b/pages/application/DecisionTree/utils/save/dtree save.py deleted file mode 100644 index 6bd118dfd8cb6ca577338d8a8189cb00cea33a68..0000000000000000000000000000000000000000 --- a/pages/application/DecisionTree/utils/save/dtree save.py +++ /dev/null @@ -1,290 +0,0 @@ -#!/usr/bin/env python -#-*- coding:utf-8 -*- -## -## dtree.py -## -## Created on: Jul 6, 2020 -## Author: Alexey Ignatiev -## E-mail: alexey.ignatiev@monash.edu -## - -# -#============================================================================== -from __future__ import print_function - -import collections -from functools import reduce - -import sklearn -from pysat.card import * -from pysat.examples.hitman import Hitman -from pysat.formula import CNF, IDPool -from pysat.solvers import Solver -from torch import threshold - -try: # for Python2 - from cStringIO import StringIO -except ImportError: # for Python3 - from io import StringIO - -import numpy as np -from dash import dcc, html -from sklearn.tree import _tree - - -# -#============================================================================== -class Node(): - """ - Node class. - """ - - def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None): - """ - Constructor. - """ - - self.feat = feat - if threshold is not None : - self.threshold = threshold - self.children_left = 0 - self.children_right = 0 - else : - self.vals = {} - - -# -#============================================================================== -class DecisionTree(): - """ - Simple decision tree class. - """ - - def __init__(self, from_pickle=None, verbose=0): - """ - Constructor. - """ - - self.verbose = verbose - self.typ="" - - self.nof_nodes = 0 - self.nof_terms = 0 - self.root_node = None - self.terms = [] - self.nodes = {} - self.paths = {} - self.feats = [] - self.feids = {} - - if from_pickle: - self.typ="pkl" - self.tree_ = '' - self.from_pickle_file(from_pickle) - - #problem de feature names et problem de vals dans node - def from_pickle_file(self, tree): - #help(_tree.Tree) - self.tree_ = tree.tree_ - #print(sklearn.tree.export_text(tree)) - try: - feature_names = tree.feature_names_in_ - except: - print("You did not dump the model with the features names") - feature_names = [str(i) for i in range(tree.n_features_in_)] - - class_names = tree.classes_ - self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0))) - self.terms={} - self.nof_nodes = self.tree_.node_count - self.root_node = 0 - self.feats = feature_names - - feature_name = [ - feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!" - for i in self.tree_.feature] - - def recurse(node): - if self.tree_.feature[node] != _tree.TREE_UNDEFINED: - name = feature_name[node] - val = self.tree_.threshold[node] - - #faire une boucle for des vals ? - self.nodes[int(node)].feat = name - self.nodes[int(node)].threshold = np.round(val, 4) - self.nodes[int(node)].children_left = int(self.tree_.children_left[node]) - self.nodes[int(node)].children_right = int(self.tree_.children_right[node]) - - recurse(self.tree_.children_left[node]) - recurse(self.tree_.children_right[node]) - - else: - self.terms[node] = class_names[np.argmax(self.tree_.value[node])] - - recurse(self.root_node) - - self.feats = sorted(self.feats) - self.feids = {f: i for i, f in enumerate(self.feats)} - self.nof_terms = len(self.terms) - self.nof_nodes -= len(self.terms) - self.nof_feats = len(self.feats) - - self.paths = collections.defaultdict(lambda: []) - self.extract_paths(root=self.root_node, prefix=[]) - - def extract_paths(self, root, prefix): - """ - Traverse the tree and extract explicit paths. - """ - - if root in self.terms.keys(): - # store the path - term = self.terms[root] - self.paths[term].append(prefix) - else: - # select next node - feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right - self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) - self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) - - def execute(self, inst): - inst = np.array([inst]) - path = self.tree_.decision_path(inst) - term_id_node = self.tree_.apply(inst) - term_id_node = term_id_node[0] - path = path.indices[path.indptr[0] : path.indptr[0 + 1]] - return path, term_id_node - - def prepare_sets(self, inst, term): - """ - Hitting set based encoding of the problem. - (currently not incremental -- should be fixed later) - """ - - sets = [] - for t, paths in self.paths.items(): - # ignoring the right class - if term in self.terms.keys() and self.terms[term] == t: - continue - - # computing the sets to hit - for path in paths: - to_hit = [] - for item in path: - # if the instance disagrees with the path on this item - if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : - if "<=" in item[1] : - fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) - else : - fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) - to_hit.append(fv) - - if len(to_hit)>0 : - to_hit = sorted(set(to_hit)) - sets.append(tuple(to_hit)) - - # returning the set of sets with no duplicates - return list(dict.fromkeys(sets)) - - def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'): - """ - Compute a given number of explanations. - """ - - inst_values = [np.float32(i[1]) for i in inst] - inst_dic = {} - for i in range(len(inst)): - inst_dic[inst[i][0]] = np.float32(inst[i][1]) - path, term = self.execute(inst_values) - - #contaiins all the elements for explanation - explanation_dic = {} - #instance plotting - explanation_dic["Instance : "] = str(inst_dic) - - #decision path - decision_path_str = "IF : " - for node_id in path: - # continue to the next node if it is a leaf node - if term == node_id: - continue - - decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format( - feature=self.nodes[node_id].feat, - value=inst_dic[self.nodes[node_id].feat], - inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" , - threshold=self.nodes[node_id].threshold) - - decision_path_str += "THEN " + str(self.terms[term]) - explanation_dic["Decision path of instance : "] = decision_path_str - explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) - - # computing the sets to hit - to_hit = self.prepare_sets(inst_dic, term) - - for type in xtype : - if type == "AXp": - explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term)) - else : - explanation_dic.update(self.enumerate_contrastive(to_hit, term)) - - return explanation_dic - - def enumerate_abductive(self, to_hit, enum, solver, htype, term): - """ - Enumerate abductive explanations. - """ - list_expls = [] - list_expls_str = [] - explanation = {} - with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: - expls = [] - for i, expl in enumerate(hitman.enumerate(), 1): - list_expls.append([ p[0] + p[2] + p[3] for p in expl]) - list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], - value=p[1], - inequality=p[2], - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) - expls.append(expl) - if i == enum: - break - explanation["List of path explanation(s)"] = list_expls - explanation["List of abductive explanation(s)"] = list_expls_str - explanation["Number of abductive explanation(s) : "] = str(i) - explanation["Minimal abductive explanation : "] = str( min([len(e) for e in expls])) - explanation["Maximal abductive explanation : "] = str( max([len(e) for e in expls])) - explanation["Average abductive explanation : "] = '{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) - - return explanation - - def enumerate_contrastive(self, to_hit, term): - """ - Enumerate contrastive explanations. - """ - - def process_set(done, target): - for s in done: - if s <= target: - break - else: - done.append(target) - return done - - to_hit = [set(s) for s in to_hit] - to_hit.sort(key=lambda s: len(s)) - expls = list(reduce(process_set, to_hit, [])) - list_expls_str = [] - explanation = {} - for expl in expls: - list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], - inequality="<=" if p[2]==">" else ">", - threshold=p[3]) - for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term]))) - explanation["List of contrastive explanation(s)"] = list_expls_str - explanation["Number of contrastive explanation(s) : "]=str(len(expls)) - explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) - explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls])) - explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls)) - - return explanation diff --git a/pages/application/DecisionTree/utils/save/dtviz save.py b/pages/application/DecisionTree/utils/save/dtviz save.py deleted file mode 100755 index aacf7b646dd1384157116c138a125a069670c772..0000000000000000000000000000000000000000 --- a/pages/application/DecisionTree/utils/save/dtviz save.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python -#-*- coding:utf-8 -*- -## -## dtviz.py -## -## Created on: Jul 7, 2020 -## Author: Alexey Ignatiev -## E-mail: alexey.ignatiev@monash.edu -## - -import numpy as np -import pygraphviz -# -#============================================================================== -def create_legend(g): - legend = g.subgraphs()[-1] - legend.add_node("a", style = "invis") - legend.add_node("b", style = "invis") - legend.add_node("c", style = "invis") - legend.add_node("d", style = "invis") - - legend.add_edge("a","b") - edge = legend.get_edge("a","b") - edge.attr["label"] = "instance" - edge.attr["style"] = "dashed" - - legend.add_edge("c","d") - edge = legend.get_edge("c","d") - edge.attr["label"] = "instance with explanation" - edge.attr["color"] = "blue" - edge.attr["style"] = "dashed" - - -def visualize(dt): - """ - Visualize a DT with graphviz. - """ - - g = pygraphviz.AGraph(name='root', rankdir="TB") - g.is_directed() - g.is_strict() - - #g = pygraphviz.AGraph(name = "main", directed=True, strict=True) - g.edge_attr['dir'] = 'forward' - - # non-terminal nodes - for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) - node = g.get_node(n) - node.attr['shape'] = 'circle' - node.attr['fontsize'] = 13 - - # terminal nodes - for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) - node = g.get_node(n) - node.attr['shape'] = 'square' - node.attr['fontsize'] = 13 - - for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - - g.add_subgraph(name='legend') - create_legend(g) - - # saving file - g.layout(prog='dot') - return(g.string()) - -# -#============================================================================== -def visualize_instance(dt, instance): - """ - Visualize a DT with graphviz and plot the running instance. - """ - g = pygraphviz.AGraph(directed=True, strict=True) - g.edge_attr['dir'] = 'forward' - g.graph_attr['rankdir'] = 'TB' - - # non-terminal nodes - for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) - node = g.get_node(n) - node.attr['shape'] = 'circle' - node.attr['fontsize'] = 13 - - # terminal nodes - for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) - node = g.get_node(n) - node.attr['shape'] = 'square' - node.attr['fontsize'] = 13 - - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - - for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' - - g.add_subgraph(name='legend') - create_legend(g) - - # saving file - g.layout(prog='dot') - return(g.to_string()) -# -#============================================================================== -def visualize_expl(dt, instance, expl): - """ - Visualize a DT with graphviz and plot the running instance. - """ - g = pygraphviz.AGraph(directed=True, strict=True) - g.edge_attr['dir'] = 'forward' - g.graph_attr['rankdir'] = 'TB' - - # non-terminal nodes - for n in dt.nodes: - g.add_node(n, label=str(dt.nodes[n].feat)) - node = g.get_node(n) - node.attr['shape'] = 'circle' - node.attr['fontsize'] = 13 - - # terminal nodes - for n in dt.terms: - g.add_node(n, label=str(dt.terms[n])) - node = g.get_node(n) - node.attr['shape'] = 'square' - node.attr['fontsize'] = 13 - - #path that follows the instance - colored in blue - instance = [np.float32(i[1]) for i in instance] - path, term_id_node = dt.execute(instance) - edges_instance = [] - for i in range (len(path)-1) : - edges_instance.append((path[i], path[i+1])) - - for n1 in dt.nodes: - threshold = dt.nodes[n1].threshold - - children_left = dt.nodes[n1].children_left - g.add_edge(n1, children_left) - edge = g.get_edge(n1, children_left) - edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_left) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' - - children_right = dt.nodes[n1].children_right - g.add_edge(n1, children_right) - edge = g.get_edge(n1, children_right) - edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) - edge.attr['fontsize'] = 10 - edge.attr['arrowsize'] = 0.8 - #instance path in blue - if ((n1,children_right) in edges_instance): - edge.attr['style'] = 'dashed' - if edge.attr['label'] in expl : - edge.attr['color'] = 'blue' - - g.add_subgraph(name='legend') - create_legend(g) - - g.layout(prog='dot') - return(g.to_string()) diff --git a/pages/application/DecisionTree/utils/upload_tree.py b/pages/application/DecisionTree/utils/upload_tree.py index 99ffa7a2fda20fd0460951c3233dbcb44861a5fa..e2babad927b807c6e11c8a5ed99f256ecb976fc1 100644 --- a/pages/application/DecisionTree/utils/upload_tree.py +++ b/pages/application/DecisionTree/utils/upload_tree.py @@ -15,6 +15,7 @@ import json import numpy as np import math import os +import six # @@ -123,133 +124,79 @@ class UploadedDecisionTree: self.n_nodes = None json_tree = get_json_tree(self.model, self.tool, maxdepth, fname) self.tree, self.n_nodes, self.depth = self.build_tree(json_tree, feature_names) - - print("c #nodes:", self.n_nodes) - print("c depth:", self.depth) - def print_tree(self): print("DT model:") walk_tree(self.tree) - - - def dump(self, fvmap, filename=None, maxdepth=None, output='temp', feat_names=None): - """ - save the dtree and data map in .dt/.map file + + def convert_dt(self, feat_names): """ - + save dtree in .dt format & generate dtree map from the tree + """ def walk_tree(node, domains, internal, terminal): """ extract internal (non-term) & terminal nodes """ - if (len(node.children) == 0): # leaf node + if not len(node.children): # leaf node terminal.append((node.id, node.values)) else: - assert (node.children[0].id == node.left_node_id) - assert (node.children[1].id == node.right_node_id) - - f = f"f{node.feature}" - - if self.tool == "DL85": - l,r = (1,0) - internal.append((node.id, f, l, node.children[0].id)) - internal.append((node.id, f, r, node.children[1].id)) - - elif self.tool == "ITI": - #l,r = (0,1) - if len(fvmap[f]) > 2: - n = 0 - for v in fvmap[f]: - if (fvmap[f][v][2] == node.threshold) and \ - (fvmap[f][v][1] == True): - l = v - n = n + 1 - if (fvmap[f][v][2] == node.threshold) and \ - (fvmap[f][v][1] == False): - r = v - n = n + 1 - - assert (n == 2) - - elif (fvmap[f][0][2] == node.threshold): - l,r = (0,1) + # internal node + f = f"f{node.feature}" + left, right = [], [] + for j in domains[f]: + if self.intvs[f][j] <= node.threshold: + left.append(j) else: - assert (fvmap[f][1][2] == node.threshold) - l,r = (1,0) - - internal.append((node.id, f, l, node.children[0].id)) - internal.append((node.id, f, r, node.children[1].id)) - - elif self.tool == "IAI": - left, right = [], [] - for p in fvmap[f]: - if fvmap[f][p][1] == True: - assert (fvmap[f][p][2] in node.split) - if node.split[fvmap[f][p][2]]: - left.append(p) - else: - right.append(p) - - internal.extend([(node.id, f, l, node.children[0].id) for l in left]) - internal.extend([(node.id, f, r, node.children[1].id) for r in right]) - - elif self.tool == 'SKL': - left, right = [], [] - for j in domains[f]: #[(j, fvmap[f][j][2]) for j in fvmap[f] if(fvmap[f][j][1])]: - if np.float32(fvmap[f][j][2]) <= node.threshold: - left.append(j) - else: - right.append(j) + right.append(j) - internal.extend([(node.id, f, l, node.children[0].id) for l in left]) - internal.extend([(node.id, f, r, node.children[1].id) for r in right]) + internal.extend([(node.id, f, l, node.children[0].id) for l in left]) + internal.extend([(node.id, f, r, node.children[1].id) for r in right]) - dom0, dom1 = dict(), dict() - dom0.update(domains) - dom1.update(domains) - dom0[f] = left - dom1[f] = right - - else: - assert False, 'Unhandled model type: {0}'.format(self.tool) - - + dom0, dom1 = dict(), dict() + dom0.update(domains) + dom1.update(domains) + dom0[f] = left + dom1[f] = right + # internal, terminal = walk_tree(node.children[0], dom0, internal, terminal) internal, terminal = walk_tree(node.children[1], dom1, internal, terminal) return internal, terminal - domains = {f:[j for j in fvmap[f] if((fvmap[f][j][1]))] for f in fvmap} - internal, terminal = walk_tree(self.tree, domains, [], []) - #.dt - dt="" - dt += f"{self.n_nodes}\n{self.tree.id}\n" + assert (self.tool == 'SKL') + domains = {f:[j for j in range(len(self.intvs[f]))] for f in self.intvs} + internal, terminal = walk_tree(self.tree, domains, [], []) + + dt = f"{self.n_nodes}\n{self.tree.id}\n" dt += f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n" - dt += f"T {' '.join([str(i) for i,_ in terminal ])}\n" + dt += f"T {' '.join([str(i) for i,_ in terminal ])}" for i,c in terminal: - dt += f"{i} T {c}\n" + dt += f"\n{i} T {c}" for i,f, j, n in internal: - dt += f"{i} {f} {j} {n}\n" + dt += f"\n{i} {f} {j} {n}" + + map = "Ordinal\n" + map += f"{len(self.intvs)}" + for f in self.intvs: + for j,t in enumerate(self.intvs[f][:-1]): + map += f"\n{f} {j} <={t}" + map += f"\n{f} {j+1} >{t}" - #.map - map="" - map+="Categorical\n" - map+=f"{len(fvmap)}\n" - for f in fvmap: - for v in fvmap[f]: - if (fvmap[f][v][1] == True): - map+=f"{f} {v} ={fvmap[f][v][2]}\n" - if (fvmap[f][v][1] == False) and self.tool == "ITI": - map+=f"{f} {v} !={fvmap[f][v][2]}\n" - - for i,fid in enumerate(feat_names): - f=f'f{i}' - map+f'{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n' - # - print(dt) - print(map) - return dt, map + features_names_mapping = {} + for i,fid in enumerate(feat_names): + f=f'f{i}' + if f in self.intvs: + features_names_mapping[f] = fid + #features_names_mapping += ",".join([f'{t}:{j}' for j,t in enumerate(self.intvs[f])])+'\n' + #thresholds = self.intvs[f][:-1]+[self.intvs[f][-2]] + #fp.write(",".join([f'{t}:{j}' for j,t in enumerate(thresholds)])+'\n') + print(dt) + print(map) + print(features_names_mapping) + return dt, map, features_names_mapping + + def build_tree(self, json_tree=None, feature_names=None): def extract_data(json_node, idx, depth=0, root=None, feature_names=None): @@ -369,7 +316,13 @@ class UploadedDecisionTree: values = tree_.value m = tree_.node_count assert (m > 0), "Empty tree" - + ## + self.intvs = {f'f{feature[i]}':set([]) for i in range(tree_.node_count) if not is_leaf[i]} + for i in range(tree_.node_count): + if not is_leaf[i]: + self.intvs[f'f{feature[i]}'].add(threshold[i]) + self.intvs = {f: sorted(self.intvs[f])+[math.inf] for f in six.iterkeys(self.intvs)} + def extract_data(idx, root = None, feature_names = None): i = idx assert (i < m), "Error index node" @@ -399,6 +352,8 @@ class UploadedDecisionTree: root, node_count, maxdepth = None, None, None if(self.tool == 'SKL'): + if "feature_names_in_" in dir(self.model): + feature_names = self.model.feature_names_in_ root, node_count, maxdepth = extract_skl(self.model.tree_, self.model.classes_, feature_names) if json_tree: @@ -408,5 +363,4 @@ class UploadedDecisionTree: root,_,maxdepth = extract_data(json_tree, 1, 0, None, feature_names) node_count = json.dumps(json_tree).count('feat') + json.dumps(json_tree).count('value') - return root, node_count, maxdepth - \ No newline at end of file + return root, node_count, maxdepth \ No newline at end of file diff --git a/utils.py b/utils.py index ed7d1b55c114e8f105d18f71ad01fed4e21caafa..4562909e785c8706f2a81df72ea74afcaf14befd 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,7 @@ import base64 import io import pickle +import joblib import numpy as np from dash import html @@ -11,7 +12,7 @@ def parse_contents_graph(contents, filename): decoded = base64.b64decode(content_string) try: if '.pkl' in filename: - data = pickle.load(io.BytesIO(decoded)) + data = joblib.load(io.BytesIO(decoded)) except Exception as e: print(e) return html.Div([