Skip to content
Snippets Groups Projects
Commit b43c3c05 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

revert to using .dt file, needs conversion from uploadeddecisiontree to decisiontree

parent 10a400d6
No related branches found
No related tags found
1 merge request!3Decision tree done
Showing with 1301 additions and 241 deletions
...@@ -57,8 +57,8 @@ def register_callbacks(page_home, page_course, page_application, app): ...@@ -57,8 +57,8 @@ def register_callbacks(page_home, page_course, page_application, app):
elif ihm_id == 'ml_pretrained_model_choice': elif ihm_id == 'ml_pretrained_model_choice':
if value_ml_model is None : if value_ml_model is None :
raise PreventUpdate raise PreventUpdate
tree, typ = parse_contents_graph(pretrained_model_contents, pretrained_model_filename) tree = parse_contents_graph(pretrained_model_contents, pretrained_model_filename)
model_application.update_pretrained_model(tree, typ) model_application.update_pretrained_model(tree, pretrained_model_filename)
return pretrained_model_filename, None, model_application.component.network, None return pretrained_model_filename, None, model_application.component.network, None
elif ihm_id == 'ml_instance_choice' : elif ihm_id == 'ml_instance_choice' :
......
from os import path from os import path
import base64
import dash_bootstrap_components as dbc import dash_bootstrap_components as dbc
import dash_interactive_graphviz import dash_interactive_graphviz
import numpy as np import numpy as np
from dash import dcc, html from dash import dcc, html
from pages.application.DecisionTree.utils.upload_tree import UploadedDecisionTree
from pages.application.DecisionTree.utils.dtree import DecisionTree from pages.application.DecisionTree.utils.dtree import DecisionTree
from pages.application.DecisionTree.utils.dtviz import (visualize, from pages.application.DecisionTree.utils.dtviz import (visualize,
visualize_expl, visualize_expl,
visualize_instance) visualize_instance)
...@@ -12,9 +15,19 @@ from pages.application.DecisionTree.utils.dtviz import (visualize, ...@@ -12,9 +15,19 @@ from pages.application.DecisionTree.utils.dtviz import (visualize,
class DecisionTreeComponent(): class DecisionTreeComponent():
def __init__(self, tree, typ_data): def __init__(self, tree, filename_tree):
try:
feature_names = tree.feature_names_in_
except:
print("You did not dump the model with the features names")
feature_names = [str(i) for i in range(tree.n_features_in_)]
self.uploaded_dt = UploadedDecisionTree(tree, 'SKL', filename_tree, maxdepth=3, feature_names=feature_names)
self.dt = DecisionTree(from_pickle = tree) #need a function that takes as input UploadedDecisionTree and gives DecisionTree
#self.dt = DecisionTree(from_dt=)
dt = open("pages/application/DecisionTree/meteo.dt", "r").read()
self.dt = DecisionTree(from_dt=dt)
dot_source = visualize(self.dt) dot_source = visualize(self.dt)
self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%", self.network = [dbc.Row(dash_interactive_graphviz.DashInteractiveGraphviz(dot_source=dot_source, style = {"width": "60%",
......
21
1
I 1 2 4 6 8 11 12 14 17 19
T 3 5 7 9 10 13 15 16 18 20 21
3 T 0
5 T 0
7 T 0
9 T 0
10 T 1
13 T 0
15 T 0
16 T 1
18 T 1
20 T 0
21 T 1
1 f5 17 2
1 f5 16 11
2 f4 17 3
2 f4 16 4
4 f2 17 5
4 f2 16 6
6 f2 19 7
6 f2 18 8
8 f0 17 9
8 f0 16 10
11 f4 17 12
11 f4 16 17
12 f1 17 13
12 f1 16 14
14 f7 17 15
14 f7 16 16
17 f5 15 18
17 f5 14 19
19 f2 17 20
19 f2 16 21
21
1
I 1 3 5 7 9 11 13 15 17 19
T 2 4 6 8 10 12 14 16 18 20 21
2 T Setosa
4 T Setosa
6 T Setosa
8 T Setosa
10 T Setosa
12 T Versicolor
14 T Versicolor
16 T Versicolor
18 T Versicolor
20 T Versicolor
21 T Virginica
1 f3 23 2
1 f3 22 3
3 f2 31 4
3 f2 30 5
5 f3 41 6
5 f3 40 7
7 f2 21 8
7 f2 20 9
9 f3 15 10
9 f3 14 11
11 f3 25 12
11 f3 24 13
13 f3 7 14
13 f3 6 15
15 f3 1 16
15 f3 0 17
17 f3 5 18
17 f3 4 19
19 f2 73 20
19 f2 72 21
\ No newline at end of file
3
1
I 1
T 2 3
2 T -
3 T +
1 f0 5 2
1 f0 4 3
...@@ -11,47 +11,32 @@ ...@@ -11,47 +11,32 @@
# #
#============================================================================== #==============================================================================
from __future__ import print_function from __future__ import print_function
import collections import collections
from functools import reduce from functools import reduce
import sklearn
from pysat.card import * from pysat.card import *
from pysat.examples.hitman import Hitman from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool from pysat.formula import CNF, IDPool
from pysat.solvers import Solver from pysat.solvers import Solver
from torch import threshold
try: # for Python2 try: # for Python2
from cStringIO import StringIO from cStringIO import StringIO
except ImportError: # for Python3 except ImportError: # for Python3
from io import StringIO from io import StringIO
import numpy as np
from dash import dcc, html
from sklearn.tree import _tree from sklearn.tree import _tree
import numpy as np
#
#==============================================================================
class Node(): class Node():
""" """
Node class. Node class.
""" """
def __init__(self, feat='', vals=None, threshold=None, children_left= None, children_right=None): def __init__(self, feat='', vals=[]):
""" """
Constructor. Constructor.
""" """
self.feat = feat self.feat = feat
if threshold is not None : self.vals = vals
self.threshold = threshold
self.children_left = 0
self.children_right = 0
else :
self.vals = {}
# #
#============================================================================== #==============================================================================
...@@ -60,13 +45,12 @@ class DecisionTree(): ...@@ -60,13 +45,12 @@ class DecisionTree():
Simple decision tree class. Simple decision tree class.
""" """
def __init__(self, from_pickle=None, verbose=0): def __init__(self, from_dt=None, verbose=0):
""" """
Constructor. Constructor.
""" """
self.verbose = verbose self.verbose = verbose
self.typ=""
self.nof_nodes = 0 self.nof_nodes = 0
self.nof_terms = 0 self.nof_terms = 0
...@@ -76,57 +60,79 @@ class DecisionTree(): ...@@ -76,57 +60,79 @@ class DecisionTree():
self.paths = {} self.paths = {}
self.feats = [] self.feats = []
self.feids = {} self.feids = {}
self.fdoms = {}
self.fvmap = {}
if from_pickle: # OHE mapping
self.typ="pkl" OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
self.tree_ = '' self.ohmap = OHEMap(dir={}, opp={})
self.from_pickle_file(from_pickle)
#problem de feature names et problem de vals dans node
def from_pickle_file(self, tree):
#help(_tree.Tree)
self.tree_ = tree.tree_
#print(sklearn.tree.export_text(tree))
try:
feature_names = tree.feature_names_in_
except:
print("You did not dump the model with the features names")
feature_names = [str(i) for i in range(tree.n_features_in_)]
class_names = tree.classes_
self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0)))
self.terms={}
self.nof_nodes = self.tree_.node_count
self.root_node = 0
self.feats = feature_names
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in self.tree_.feature]
def recurse(node):
if self.tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
val = self.tree_.threshold[node]
#faire une boucle for des vals ?
self.nodes[int(node)].feat = name
self.nodes[int(node)].threshold = np.round(val, 4)
self.nodes[int(node)].children_left = int(self.tree_.children_left[node])
self.nodes[int(node)].children_right = int(self.tree_.children_right[node])
recurse(self.tree_.children_left[node])
recurse(self.tree_.children_right[node])
else: if from_dt:
self.terms[node] = class_names[np.argmax(self.tree_.value[node])] self.from_dt(from_dt)
recurse(self.root_node) for f in self.feats:
for v in self.fdoms[f]:
self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)
def from_dt(self, data):
"""
Get the tree from a file pointer.
"""
contents = StringIO(data)
lines = contents.readlines()
# filtering out comment lines (those that start with '#')
lines = list(filter(lambda l: not l.startswith('#'), lines))
# number of nodes
self.nof_nodes = int(lines[0].strip())
# root node
self.root_node = int(lines[1].strip())
# number of terminal nodes (classes)
self.nof_terms = len(lines[3][2:].strip().split())
# the ordered list of terminal nodes
self.terms = {}
for i in range(self.nof_terms):
nd, _, t = lines[i + 4].strip().split()
self.terms[int(nd)] = t #int(t)
# finally, reading the nodes
self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
self.feats = set([])
self.fdoms = collections.defaultdict(lambda: set([]))
for line in lines[(4 + self.nof_terms):]:
# reading the tuple
nid, fid, fval, child = line.strip().split()
# inserting it in the nodes list
self.nodes[int(nid)].feat = fid
self.nodes[int(nid)].vals[int(fval)] = int(child)
# updating the list of features
self.feats.add(fid)
# updaing feature domains
self.fdoms[fid].add(int(fval))
# adding complex node connections into consideration
for n1 in self.nodes:
conns = collections.defaultdict(lambda: set([]))
for v, n2 in self.nodes[n1].vals.items():
conns[n2].add(v)
self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()}
# simplifying the features and their domains
self.feats = sorted(self.feats) self.feats = sorted(self.feats)
self.feids = {f: i for i, f in enumerate(self.feats)} self.feids = {f: i for i, f in enumerate(self.feats)}
self.nof_terms = len(self.terms) self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
self.nof_nodes -= len(self.terms)
# here we assume all features are present in the tree
# if not, this value will be rewritten by self.parse_mapping()
self.nof_feats = len(self.feats) self.nof_feats = len(self.feats)
self.paths = collections.defaultdict(lambda: []) self.paths = collections.defaultdict(lambda: [])
...@@ -137,23 +143,69 @@ class DecisionTree(): ...@@ -137,23 +143,69 @@ class DecisionTree():
Traverse the tree and extract explicit paths. Traverse the tree and extract explicit paths.
""" """
if root in self.terms.keys(): if root in self.terms:
# store the path # store the path
term = self.terms[root] term = self.terms[root]
self.paths[term].append(prefix) self.paths[term].append(prefix)
else: else:
# select next node # select next node
feat, threshold, children_left, children_right = self.nodes[root].feat, self.nodes[root].threshold, self.nodes[root].children_left, self.nodes[root].children_right feat, vals = self.nodes[root].feat, self.nodes[root].vals
self.extract_paths(children_left, prefix + [tuple([feat, "<=" + str(threshold)])]) for val in vals:
self.extract_paths(children_right, prefix + [tuple([feat, ">"+ str(threshold)])]) self.extract_paths(vals[val], prefix + [tuple([feat, val])])
def execute(self, inst): def execute(self, inst, pathlits=False):
inst = np.array([inst]) """
path = self.tree_.decision_path(inst) Run the tree and obtain the prediction given an input instance.
term_id_node = self.tree_.apply(inst) """
term_id_node = term_id_node[0]
path = path.indices[path.indptr[0] : path.indptr[0 + 1]] root = self.root_node
return path, term_id_node depth = 0
path = []
# this array is needed if we focus on the path's literals only
visited = [False for f in inst]
while not root in self.terms:
path.append(root)
feat, vals = self.nodes[root].feat, self.nodes[root].vals
visited[self.feids[feat]] = True
tval = inst[self.feids[feat]][1]
###############
# assert(len(vals) == 2)
next_node = root
neq = None
for vs, dest in vals.items():
if tval in vs:
next_node = dest
break
else:
for v in vs:
if '!=' in self.fvmap[(feat, v)]:
neq = dest
break
else:
next_node = neq
# if tval not in vals:
# # go to the False branch (!=)
# for i in vals:
# if "!=" in self.fvmap[(feat,i)]:
# next_node = vals[i]
# break
# else:
# next_node = vals[tval]
assert (next_node != root)
###############
root = next_node
depth += 1
if pathlits:
# filtering out non-visited literals
for i, v in enumerate(visited):
if not v:
inst[i] = None
return path, self.terms[root], depth
def prepare_sets(self, inst, term): def prepare_sets(self, inst, term):
""" """
...@@ -164,7 +216,7 @@ class DecisionTree(): ...@@ -164,7 +216,7 @@ class DecisionTree():
sets = [] sets = []
for t, paths in self.paths.items(): for t, paths in self.paths.items():
# ignoring the right class # ignoring the right class
if term in self.terms.keys() and self.terms[term] == t: if t == term:
continue continue
# computing the sets to hit # computing the sets to hit
...@@ -172,16 +224,21 @@ class DecisionTree(): ...@@ -172,16 +224,21 @@ class DecisionTree():
to_hit = [] to_hit = []
for item in path: for item in path:
# if the instance disagrees with the path on this item # if the instance disagrees with the path on this item
if ("<=" in item[1] and (inst[item[0]] > np.float32(item[1][2:]))) or (">" in item[1] and (inst[item[0]] <= np.float32(item[1][1:]))) : if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]:
if "<=" in item[1] : fv = inst[self.feids[item[0]]]
fv = tuple([item[0], str(inst[item[0]]), ">" , str(np.float32(item[1][2:]))]) if fv[0] in self.ohmap.opp:
else : to_hit.append(tuple([self.ohmap.opp[fv[0]], None]))
fv = tuple([item[0], str(inst[item[0]]) , "<=" , str(np.float32(item[1][1:]))]) else:
to_hit.append(fv) to_hit.append(fv)
to_hit = sorted(set(to_hit))
sets.append(tuple(to_hit))
if self.verbose:
if self.verbose > 1:
print('c trav. path: {0}'.format(path))
if len(to_hit)>0 : print('c set to hit: {0}'.format(to_hit))
to_hit = sorted(set(to_hit))
sets.append(tuple(to_hit))
# returning the set of sets with no duplicates # returning the set of sets with no duplicates
return list(dict.fromkeys(sets)) return list(dict.fromkeys(sets))
...@@ -191,11 +248,10 @@ class DecisionTree(): ...@@ -191,11 +248,10 @@ class DecisionTree():
Compute a given number of explanations. Compute a given number of explanations.
""" """
inst_values = [np.float32(i[1]) for i in inst]
inst_dic = {} inst_dic = {}
for i in range(len(inst)): for i in range(len(inst)):
inst_dic[inst[i][0]] = np.float32(inst[i][1]) inst_dic[inst[i][0]] = np.float32(inst[i][1])
path, term = self.execute(inst_values) path, term, depth = self.execute(inst)
#contaiins all the elements for explanation #contaiins all the elements for explanation
explanation_dic = {} explanation_dic = {}
...@@ -203,24 +259,12 @@ class DecisionTree(): ...@@ -203,24 +259,12 @@ class DecisionTree():
explanation_dic["Instance : "] = str(inst_dic) explanation_dic["Instance : "] = str(inst_dic)
#decision path #decision path
decision_path_str = "IF : " decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]] ] for n in path]), term)
for node_id in path:
# continue to the next node if it is a leaf node
if term == node_id:
continue
decision_path_str +="(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format(
feature=self.nodes[node_id].feat,
value=inst_dic[self.nodes[node_id].feat],
inequality="<=" if inst_dic[self.nodes[node_id].feat] <= self.nodes[node_id].threshold else ">" ,
threshold=self.nodes[node_id].threshold)
decision_path_str += "THEN " + str(self.terms[term])
explanation_dic["Decision path of instance : "] = decision_path_str explanation_dic["Decision path of instance : "] = decision_path_str
explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path)) explanation_dic["Decision path length : "] = 'Path length is :'+ str(len(path))
# computing the sets to hit # computing the sets to hit
to_hit = self.prepare_sets(inst_dic, term) to_hit = self.prepare_sets(inst, term)
for type in xtype : for type in xtype :
if type == "AXp": if type == "AXp":
...@@ -240,12 +284,9 @@ class DecisionTree(): ...@@ -240,12 +284,9 @@ class DecisionTree():
with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman: with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
expls = [] expls = []
for i, expl in enumerate(hitman.enumerate(), 1): for i, expl in enumerate(hitman.enumerate(), 1):
list_expls.append([ p[0] + p[2] + p[3] for p in expl]) list_expls.append([ str(p[0]) + "=" + str(p[1]) for p in expl])
list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(["(inst[{feature}] = {value}) {inequality} {threshold})".format(feature=p[0], list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term))
value=p[1],
inequality=p[2],
threshold=p[3])
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term])))
expls.append(expl) expls.append(expl)
if i == enum: if i == enum:
break break
...@@ -277,10 +318,8 @@ class DecisionTree(): ...@@ -277,10 +318,8 @@ class DecisionTree():
list_expls_str = [] list_expls_str = []
explanation = {} explanation = {}
for expl in expls: for expl in expls:
list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(["inst[{feature}] {inequality} {threshold})".format(feature=p[0], list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))
inequality="<=" if p[2]==">" else ">",
threshold=p[3])
for p in sorted(expl, key=lambda p: p[0])]), str(self.terms[term])))
explanation["List of contrastive explanation(s)"] = list_expls_str explanation["List of contrastive explanation(s)"] = list_expls_str
explanation["Number of contrastive explanation(s) : "]=str(len(expls)) explanation["Number of contrastive explanation(s) : "]=str(len(expls))
explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls])) explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls]))
......
...@@ -8,78 +8,58 @@ ...@@ -8,78 +8,58 @@
## E-mail: alexey.ignatiev@monash.edu ## E-mail: alexey.ignatiev@monash.edu
## ##
import numpy as np
import pygraphviz
# #
#============================================================================== #==============================================================================
def create_legend(g): from pages.application.DecisionTree.utils.dtree import DecisionTree
legend = g.subgraphs()[-1] import getopt
legend.add_node("a", style = "invis") import os
legend.add_node("b", style = "invis") import pygraphviz
legend.add_node("c", style = "invis") import sys
legend.add_node("d", style = "invis")
legend.add_edge("a","b")
edge = legend.get_edge("a","b")
edge.attr["label"] = "instance"
edge.attr["style"] = "dashed"
legend.add_edge("c","d")
edge = legend.get_edge("c","d")
edge.attr["label"] = "instance with explanation"
edge.attr["color"] = "blue"
edge.attr["style"] = "dashed"
#
#==============================================================================
def visualize(dt): def visualize(dt):
""" """
Visualize a DT with graphviz. Visualize a DT with graphviz.
""" """
g = pygraphviz.AGraph(name='root', rankdir="TB") g = pygraphviz.AGraph(directed=True, strict=True)
g.is_directed()
g.is_strict()
#g = pygraphviz.AGraph(name = "main", directed=True, strict=True)
g.edge_attr['dir'] = 'forward' g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=str(dt.nodes[n].feat)) g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=str(dt.terms[n])) g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# transitions
for n1 in dt.nodes: for n1 in dt.nodes:
threshold = dt.nodes[n1].threshold for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
children_left = dt.nodes[n1].children_left g.add_edge(n1, n2)
g.add_edge(n1, children_left) edge = g.get_edge(n1, n2)
edge = g.get_edge(n1, children_left) if len(v) == 1:
edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.attr['fontsize'] = 10 else:
edge.attr['arrowsize'] = 0.8 edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
edge.attr['fontsize'] = 10
children_right = dt.nodes[n1].children_right edge.attr['arrowsize'] = 0.8
g.add_edge(n1, children_right)
edge = g.get_edge(n1, children_right)
edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
g.add_subgraph(name='legend')
create_legend(g)
# saving file # saving file
g.in_edges
g.layout(prog='dot') g.layout(prog='dot')
return(g.string()) return(g.to_string())
# #
#============================================================================== #==============================================================================
...@@ -87,120 +67,111 @@ def visualize_instance(dt, instance): ...@@ -87,120 +67,111 @@ def visualize_instance(dt, instance):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
#path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
g = pygraphviz.AGraph(directed=True, strict=True) g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward' g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=str(dt.nodes[n].feat)) g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=str(dt.terms[n])) g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
#path that follows the instance - colored in blue # transitions
instance = [np.float32(i[1]) for i in instance]
path, term_id_node = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
for n1 in dt.nodes: for n1 in dt.nodes:
threshold = dt.nodes[n1].threshold for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
children_left = dt.nodes[n1].children_left n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, children_left) g.add_edge(n1, n2)
edge = g.get_edge(n1, children_left) edge = g.get_edge(n1, n2)
edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) if len(v) == 1:
edge.attr['fontsize'] = 10 edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.attr['arrowsize'] = 0.8 else:
#instance path in blue edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
if ((n1,children_left) in edges_instance):
edge.attr['style'] = 'dashed' #instance path in blue
if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
children_right = dt.nodes[n1].children_right edge.attr['color'] = 'blue'
g.add_edge(n1, children_right)
edge = g.get_edge(n1, children_right) edge.attr['fontsize'] = 10
edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold) edge.attr['arrowsize'] = 0.8
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_right) in edges_instance):
edge.attr['style'] = 'dashed'
g.add_subgraph(name='legend')
create_legend(g)
# saving file # saving file
g.layout(prog='dot') g.layout(prog='dot')
return(g.to_string()) return(g.to_string())
#
#============================================================================== #==============================================================================
def visualize_expl(dt, instance, expl): def visualize_expl(dt, instance, expl):
""" """
Visualize a DT with graphviz and plot the running instance. Visualize a DT with graphviz and plot the running instance.
""" """
if '=' in instance[0]:
instance = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in instance]))
else:
instance = list(map(lambda i : tuple(['f{0}'.format(i[0]), int(i[1])]), [(i, j) for i,j in enumerate(instance)]))
#path that follows the instance - colored in blue
path, term, depth = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
edges_instance.append((path[-1],"term:"+term))
g = pygraphviz.AGraph(directed=True, strict=True) g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward' g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB' g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes # non-terminal nodes
for n in dt.nodes: for n in dt.nodes:
g.add_node(n, label=str(dt.nodes[n].feat)) g.add_node(n, label='{0}\\n({1})'.format(dt.nodes[n].feat, n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'circle' node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
# terminal nodes # terminal nodes
for n in dt.terms: for n in dt.terms:
g.add_node(n, label=str(dt.terms[n])) g.add_node(n, label='{0}\\n({1})'.format(dt.terms[n], n))
node = g.get_node(n) node = g.get_node(n)
node.attr['shape'] = 'square' node.attr['shape'] = 'square'
node.attr['fontsize'] = 13 node.attr['fontsize'] = 13
#path that follows the instance - colored in blue # transitions
instance = [np.float32(i[1]) for i in instance]
path, term_id_node = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
for n1 in dt.nodes: for n1 in dt.nodes:
threshold = dt.nodes[n1].threshold for v in dt.nodes[n1].vals:
n2 = dt.nodes[n1].vals[v]
children_left = dt.nodes[n1].children_left n2_type = g.get_node(n2).attr['shape']
g.add_edge(n1, children_left) g.add_edge(n1, n2)
edge = g.get_edge(n1, children_left) edge = g.get_edge(n1, n2)
edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold) if len(v) == 1:
edge.attr['fontsize'] = 10 edge.attr['label'] = dt.fvmap[tuple([dt.nodes[n1].feat, tuple(v)[0]])]
edge.attr['arrowsize'] = 0.8 else:
#instance path in blue edge.attr['label'] = '{0}'.format('\n'.join([dt.fvmap[tuple([dt.nodes[n1].feat, val])] for val in tuple(v)]))
if ((n1,children_left) in edges_instance):
edge.attr['style'] = 'dashed' #instance path in blue
if edge.attr['label'] in expl : if ((n1,n2) in edges_instance) or (n2_type=='square' and (n1, "term:"+ dt.terms[n2]) in edges_instance):
edge.attr['color'] = 'blue' edge.attr['color'] = 'blue'
children_right = dt.nodes[n1].children_right edge.attr['fontsize'] = 10
g.add_edge(n1, children_right) edge.attr['arrowsize'] = 0.8
edge = g.get_edge(n1, children_right)
edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
#instance path in blue
if ((n1,children_right) in edges_instance):
edge.attr['style'] = 'dashed'
if edge.attr['label'] in expl :
edge.attr['color'] = 'blue'
g.add_subgraph(name='legend')
create_legend(g)
# saving file
g.layout(prog='dot') g.layout(prog='dot')
return(g.to_string()) return(g.to_string())
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtree.py
##
## Created on: Jul 6, 2020
## Author: Alexey Ignatiev
## E-mail: alexey.ignatiev@monash.edu
##
#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
import sklearn
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver
from torch import threshold
try: # for Python2
from cStringIO import StringIO
except ImportError: # for Python3
from io import StringIO
import numpy as np
from dash import dcc, html
from sklearn.tree import _tree
#
#==============================================================================
class Node():
    """
        Node class.

        A node comes in one of two flavours, selected by ``threshold``:

        * threshold node (``threshold`` given): carries ``threshold``,
          ``children_left`` and ``children_right``;
        * value node (``threshold`` omitted): carries a ``vals`` mapping.
    """

    def __init__(self, feat='', vals=None, threshold=None, children_left=None, children_right=None):
        """
            Constructor.

            :param feat: name of the feature tested in this node.
            :param vals: mapping stored on a value node; a fresh empty dict
                is created when omitted, so nodes never share one mapping.
            :param threshold: split threshold; passing any value other than
                ``None`` makes this a threshold node.
            :param children_left: id of the left child (0 when omitted).
            :param children_right: id of the right child (0 when omitted).
        """

        self.feat = feat

        if threshold is not None:
            self.threshold = threshold
            # honour the ids given by the caller; 0 stays the fallback value
            self.children_left = 0 if children_left is None else children_left
            self.children_right = 0 if children_right is None else children_right
        else:
            self.vals = {} if vals is None else vals
#
#==============================================================================
class DecisionTree():
"""
Simple decision tree class.
"""
def __init__(self, from_pickle=None, verbose=0):
"""
Constructor.
"""
self.verbose = verbose
self.typ=""
self.nof_nodes = 0
self.nof_terms = 0
self.root_node = None
self.terms = []
self.nodes = {}
self.paths = {}
self.feats = []
self.feids = {}
if from_pickle:
self.typ="pkl"
self.tree_ = ''
self.from_pickle_file(from_pickle)
# known problems: feature names, and the vals attribute of Node
def from_pickle_file(self, tree):
    """
        Populate this object from an already loaded tree model.

        The model is read through ``tree.tree_``, ``tree.classes_`` and
        ``tree.feature_names_in_`` (falling back to ``tree.n_features_in_``
        when the names are missing) -- presumably a fitted scikit-learn
        decision tree; confirm against the caller.

        Fills in ``nodes``, ``terms``, ``feats``, ``feids``, the node and
        feature counters and finally ``paths`` (via ``extract_paths``).
    """

    self.tree_ = tree.tree_

    try:
        feature_names = tree.feature_names_in_
    except AttributeError:
        # only a missing attribute means "no names were dumped"; anything
        # else should not be silenced
        print("You did not dump the model with the features names")
        feature_names = [str(i) for i in range(tree.n_features_in_)]

    class_names = tree.classes_
    self.nodes = collections.defaultdict(lambda: Node(feat='', threshold=int(0), children_left=int(0), children_right=int(0)))
    self.terms = {}
    self.nof_nodes = self.tree_.node_count
    self.root_node = 0
    self.feats = feature_names

    # per-node feature name; leaves carry the TREE_UNDEFINED marker
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in self.tree_.feature]

    def recurse(node):
        # plain ints everywhere, so nodes and terms use the same key type
        node = int(node)
        if self.tree_.feature[node] != _tree.TREE_UNDEFINED:
            # TODO (from the original author): loop over the vals here?
            entry = self.nodes[node]
            entry.feat = feature_name[node]
            entry.threshold = np.round(self.tree_.threshold[node], 4)
            entry.children_left = int(self.tree_.children_left[node])
            entry.children_right = int(self.tree_.children_right[node])
            recurse(entry.children_left)
            recurse(entry.children_right)
        else:
            # leaf: keep the class with the largest entry in value[node]
            self.terms[node] = class_names[np.argmax(self.tree_.value[node])]

    recurse(self.root_node)

    # NOTE(review): sorting reorders the feature names, so the indices in
    # feids may differ from the model's column order -- confirm this is
    # what the users of feids expect
    self.feats = sorted(self.feats)
    self.feids = {f: i for i, f in enumerate(self.feats)}
    self.nof_terms = len(self.terms)
    # node_count includes the leaves; nof_nodes keeps internal nodes only
    self.nof_nodes -= len(self.terms)
    self.nof_feats = len(self.feats)

    self.paths = collections.defaultdict(lambda: [])
    self.extract_paths(root=self.root_node, prefix=[])
def extract_paths(self, root, prefix):
    """
        Traverse the tree and extract explicit paths.

        Every root-to-leaf path is appended to ``self.paths[<leaf class>]``
        as a list of ``(feature, condition)`` tuples, where condition is
        the string ``"<=" + threshold`` or ``">" + threshold``.

        :param root: id of the node to start from.
        :param prefix: literals collected on the way down to ``root``.
        :raises KeyError: if ``root`` is neither a leaf nor a known node.
    """

    if root in self.terms:
        # leaf reached: store the path under its class
        self.paths[self.terms[root]].append(prefix)
        return

    if root not in self.nodes:
        # the node table may be a defaultdict; indexing an unknown id would
        # fabricate a node whose children point at 0 and recurse endlessly
        raise KeyError(root)

    node = self.nodes[root]
    bound = str(node.threshold)

    self.extract_paths(node.children_left, prefix + [(node.feat, "<=" + bound)])
    self.extract_paths(node.children_right, prefix + [(node.feat, ">" + bound)])
def execute(self, inst):
    """
        Run the tree on a single instance.

        :param inst: sequence of feature values for one sample.
        :returns: ``(path, leaf)`` -- the ids of the nodes visited for the
            sample (sliced out of the ``indices``/``indptr`` of the object
            returned by ``tree_.decision_path``) and the first element of
            ``tree_.apply``.
    """

    # one-row batch; forced to float32 because the low-level scikit-learn
    # tree rejects any other dtype (plain Python floats would give float64)
    sample = np.asarray([inst], dtype=np.float32)

    visited = self.tree_.decision_path(sample)
    leaf = self.tree_.apply(sample)[0]

    # keep only the slice that belongs to row 0
    path = visited.indices[visited.indptr[0]:visited.indptr[1]]

    return path, leaf
def prepare_sets(self, inst, term):
    """
        Hitting set based encoding of the problem.
        (currently not incremental -- should be fixed later)

        `inst` maps a feature name to its value and `term` is the id of
        the leaf reached by the instance. For every stored path that ends
        in another class, collect the literals on which the instance
        disagrees with the path. Each literal is reported as
        ``(feature, str(value), relation, str(threshold))`` where
        `relation` is the one the instance actually satisfies.
    """

    collected = []

    for label, label_paths in self.paths.items():
        # paths leading to the class of the instance are not to be hit
        if term in self.terms and self.terms[term] == label:
            continue

        for branch in label_paths:
            disagreements = []

            for item in branch:
                feat, cond = item[0], item[1]

                if "<=" in cond:
                    bound = np.float32(cond[2:])
                    # path wants value <= bound, instance is above it
                    if inst[feat] > bound:
                        disagreements.append(
                            (feat, str(inst[feat]), ">", str(bound)))
                elif ">" in cond:
                    bound = np.float32(cond[1:])
                    # path wants value > bound, instance is not above it
                    if inst[feat] <= bound:
                        disagreements.append(
                            (feat, str(inst[feat]), "<=", str(bound)))

            if disagreements:
                collected.append(tuple(sorted(set(disagreements))))

    # dict.fromkeys drops duplicates while keeping first-seen order
    return list(dict.fromkeys(collected))
def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'):
    """
        Compute a given number of explanations.

        `inst` is a sequence of ``(feature, value)`` pairs; the values are
        converted to ``np.float32``. The result is a dictionary holding the
        instance, its decision path and path length, followed by the
        entries produced for every kind listed in `xtype` ("AXp" selects
        abductive explanations, anything else contrastive ones).
        `pathlits` is not used by this method.
    """

    inst_values = [np.float32(pair[1]) for pair in inst]
    inst_dic = {pair[0]: np.float32(pair[1]) for pair in inst}

    path, term = self.execute(inst_values)

    # gathers every element of the explanation
    report = {"Instance : ": str(inst_dic)}

    # one clause per internal node on the path; the leaf itself is skipped
    clauses = []
    for node_id in path:
        if term == node_id:
            continue
        split = self.nodes[node_id]
        value = inst_dic[split.feat]
        clauses.append("(inst[{feature}] = {value}) {inequality} {threshold}) AND ".format(
            feature=split.feat,
            value=value,
            inequality="<=" if value <= split.threshold else ">",
            threshold=split.threshold))

    report["Decision path of instance : "] = "IF : " + "".join(clauses) + "THEN " + str(self.terms[term])
    report["Decision path length : "] = 'Path length is :' + str(len(path))

    # computing the sets to hit
    to_hit = self.prepare_sets(inst_dic, term)

    for kind in xtype:
        if kind == "AXp":
            report.update(self.enumerate_abductive(to_hit, enum, solver, htype, term))
        else:
            report.update(self.enumerate_contrastive(to_hit, term))

    return report
def enumerate_abductive(self, to_hit, enum, solver, htype, term):
    """
        Enumerate abductive explanations.

        Every explanation is a hitting set of `to_hit` produced by a
        Hitman oracle. Enumeration stops as soon as `enum` explanations
        have been collected. Each literal `p` of an explanation is read as
        ``(feature, value, relation, threshold)``.

        Returns a dictionary with the explanations (as edge labels and as
        readable rules) and their count, minimal, maximal and average
        size. When the oracle yields nothing the lists are empty and the
        statistics are reported as zero instead of raising.
    """

    list_expls = []
    list_expls_str = []
    expls = []

    # NOTE(review): the oracle is always created with the 'm22' solver; the
    # `solver` argument is accepted but not used here -- confirm whether
    # that is intended before wiring it through.
    with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
        for count, expl in enumerate(hitman.enumerate(), 1):
            # same text as the edge labels drawn by the visualisation
            list_expls.append([p[0] + p[2] + p[3] for p in expl])

            literals = ' AND '.join(
                "(inst[{feature}] = {value}) {inequality} {threshold})".format(
                    feature=p[0], value=p[1], inequality=p[2], threshold=p[3])
                for p in sorted(expl, key=lambda p: p[0]))
            list_expls_str.append(
                'Explanation: IF {0} THEN class={1}'.format(literals, str(self.terms[term])))

            expls.append(expl)
            if count == enum:
                break

    sizes = [len(e) for e in expls]

    explanation = {}
    explanation["List of path explanation(s)"] = list_expls
    explanation["List of abductive explanation(s)"] = list_expls_str
    explanation["Number of abductive explanation(s) : "] = str(len(expls))
    # default / guard: no explanation at all must not crash the report
    explanation["Minimal abductive explanation : "] = str(min(sizes, default=0))
    explanation["Maximal abductive explanation : "] = str(max(sizes, default=0))
    explanation["Average abductive explanation : "] = '{0:.2f}'.format(
        sum(sizes) / len(sizes) if sizes else 0.0)

    return explanation
def enumerate_contrastive(self, to_hit, term):
    """
        Enumerate contrastive explanations.

        The contrastive explanations are the subset-minimal sets of
        `to_hit`: sets are visited by increasing size and one is kept only
        if no set kept before it is included in it. Each literal `p` is
        read as ``(feature, value, relation, threshold)`` and is printed
        with the opposite relation.

        Returns a dictionary with the readable rules and their count,
        minimal, maximal and average size. With nothing to hit the list is
        empty and the statistics are reported as zero instead of raising.
    """

    # stable sort by size, so equally sized sets keep their input order
    candidates = sorted((set(s) for s in to_hit), key=len)

    expls = []
    for cand in candidates:
        # drop any set that contains an explanation already kept
        if not any(kept <= cand for kept in expls):
            expls.append(cand)

    list_expls_str = []
    for expl in expls:
        # sorting on the whole tuple keeps the feature-first order and makes
        # the order of literals sharing a feature deterministic
        literals = ' OR '.join(
            "inst[{feature}] {inequality} {threshold})".format(
                feature=p[0],
                inequality="<=" if p[2] == ">" else ">",
                threshold=p[3])
            for p in sorted(expl))
        list_expls_str.append(
            'Contrastive: IF {0} THEN class!={1}'.format(literals, str(self.terms[term])))

    sizes = [len(e) for e in expls]

    explanation = {}
    explanation["List of contrastive explanation(s)"] = list_expls_str
    explanation["Number of contrastive explanation(s) : "] = str(len(expls))
    # default / guard: an empty `to_hit` must not crash the report
    explanation["Minimal contrastive explanation : "] = str(min(sizes, default=0))
    explanation["Maximal contrastive explanation : "] = str(max(sizes, default=0))
    explanation["Average contrastive explanation : "] = '{0:.2f}'.format(
        sum(sizes) / len(sizes) if sizes else 0.0)

    return explanation
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtviz.py
##
## Created on: Jul 7, 2020
## Author: Alexey Ignatiev
## E-mail: alexey.ignatiev@monash.edu
##
import numpy as np
import pygraphviz
#
#==============================================================================
def create_legend(g):
    """
        Fill the last subgraph of `g` with a two-entry edge legend.

        Four invisible nodes carry two sample edges: a dashed one labelled
        "instance" and a blue dashed one labelled
        "instance with explanation".
    """

    legend = g.subgraphs()[-1]

    # invisible anchors: only the sample edges are meant to show
    for anchor in ("a", "b", "c", "d"):
        legend.add_node(anchor, style = "invis")

    legend.add_edge("a", "b")
    plain = legend.get_edge("a", "b")
    plain.attr["label"] = "instance"
    plain.attr["style"] = "dashed"

    legend.add_edge("c", "d")
    explained = legend.get_edge("c", "d")
    explained.attr["label"] = "instance with explanation"
    explained.attr["color"] = "blue"
    explained.attr["style"] = "dashed"
def visualize(dt):
    """
        Visualize a DT with graphviz.

        `dt` must expose `nodes` (internal nodes, each with `feat`,
        `threshold`, `children_left` and `children_right`) and `terms`
        (leaf id -> class). Returns the laid-out graph in DOT format.
    """

    # NOTE(review): no `directed=` / `strict=` argument is given, so the
    # graph keeps pygraphviz's defaults; arrows come from the 'dir' edge
    # attribute below. The former `g.is_directed()` / `g.is_strict()` calls
    # were queries whose result was discarded and configured nothing --
    # confirm whether `directed=True, strict=True` (as used by the other
    # visualisers) was the intent.
    g = pygraphviz.AGraph(name='root', rankdir="TB")
    g.edge_attr['dir'] = 'forward'

    def draw_node(key, text, shape):
        # add one node and style it
        g.add_node(key, label=text)
        node = g.get_node(key)
        node.attr['shape'] = shape
        node.attr['fontsize'] = 13

    def draw_edge(parent, child, text):
        # add one labelled edge from a split node to one of its children
        g.add_edge(parent, child)
        edge = g.get_edge(parent, child)
        edge.attr['label'] = text
        edge.attr['fontsize'] = 10
        edge.attr['arrowsize'] = 0.8

    # non-terminal nodes
    for n in dt.nodes:
        draw_node(n, str(dt.nodes[n].feat), 'circle')

    # terminal nodes
    for n in dt.terms:
        draw_node(n, str(dt.terms[n]), 'square')

    # every split node gets a "<=" edge to its left child and a ">" edge
    # to its right child
    for n1 in dt.nodes:
        split = dt.nodes[n1]
        draw_edge(n1, split.children_left, str(split.feat) + "<=" + str(split.threshold))
        draw_edge(n1, split.children_right, str(split.feat) + ">" + str(split.threshold))

    g.add_subgraph(name='legend')
    create_legend(g)

    # computing the layout and returning the DOT source
    g.layout(prog='dot')
    return(g.string())
#
#==============================================================================
def visualize_instance(dt, instance):
    """
        Visualize a DT with graphviz and plot the running instance.

        `instance` is a sequence of ``(feature, value)`` pairs. The edges
        followed by the instance, as reported by ``dt.execute``, are drawn
        dashed. Returns the laid-out graph in DOT format.
    """

    g = pygraphviz.AGraph(directed=True, strict=True)
    g.edge_attr['dir'] = 'forward'
    g.graph_attr['rankdir'] = 'TB'

    def draw_node(key, text, shape):
        # add one node and style it
        g.add_node(key, label=text)
        node = g.get_node(key)
        node.attr['shape'] = shape
        node.attr['fontsize'] = 13

    # non-terminal nodes
    for n in dt.nodes:
        draw_node(n, str(dt.nodes[n].feat), 'circle')

    # terminal nodes
    for n in dt.terms:
        draw_node(n, str(dt.terms[n]), 'square')

    # consecutive node pairs along the path followed by the instance
    values = [np.float32(pair[1]) for pair in instance]
    path, term_id_node = dt.execute(values)
    edges_instance = [(path[k], path[k + 1]) for k in range(len(path) - 1)]

    def draw_edge(parent, child, text):
        # add one labelled edge; dashed when the instance goes through it
        g.add_edge(parent, child)
        edge = g.get_edge(parent, child)
        edge.attr['label'] = text
        edge.attr['fontsize'] = 10
        edge.attr['arrowsize'] = 0.8
        if (parent, child) in edges_instance:
            edge.attr['style'] = 'dashed'

    for n1 in dt.nodes:
        split = dt.nodes[n1]
        draw_edge(n1, split.children_left, str(split.feat) + "<=" + str(split.threshold))
        draw_edge(n1, split.children_right, str(split.feat) + ">" + str(split.threshold))

    g.add_subgraph(name='legend')
    create_legend(g)

    # computing the layout and returning the DOT source
    g.layout(prog='dot')
    return(g.to_string())
#
#==============================================================================
def visualize_expl(dt, instance, expl):
"""
Visualize a DT with graphviz, plot the running instance and highlight
the edges of an explanation.

`dt` must expose `nodes`, `terms` and `execute`; `instance` is a
sequence of (feature, value) pairs; `expl` is a collection of edge
labels (feature + relation + threshold, as text) to be drawn in blue.
Returns the laid-out graph in DOT format.
"""
# directed, strict graph drawn from top to bottom
g = pygraphviz.AGraph(directed=True, strict=True)
g.edge_attr['dir'] = 'forward'
g.graph_attr['rankdir'] = 'TB'
# non-terminal nodes: circles labelled with the tested feature
for n in dt.nodes:
g.add_node(n, label=str(dt.nodes[n].feat))
node = g.get_node(n)
node.attr['shape'] = 'circle'
node.attr['fontsize'] = 13
# terminal nodes: squares labelled with the class
for n in dt.terms:
g.add_node(n, label=str(dt.terms[n]))
node = g.get_node(n)
node.attr['shape'] = 'square'
node.attr['fontsize'] = 13
# path followed by the instance, turned into (parent, child) pairs;
# these edges are drawn dashed below
instance = [np.float32(i[1]) for i in instance]
path, term_id_node = dt.execute(instance)
edges_instance = []
for i in range (len(path)-1) :
edges_instance.append((path[i], path[i+1]))
for n1 in dt.nodes:
threshold = dt.nodes[n1].threshold
# edge to the left child, taken when feature <= threshold
children_left = dt.nodes[n1].children_left
g.add_edge(n1, children_left)
edge = g.get_edge(n1, children_left)
edge.attr['label'] = str(dt.nodes[n1].feat) + "<=" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
# dashed when the instance follows this edge
if ((n1,children_left) in edges_instance):
edge.attr['style'] = 'dashed'
# blue when the edge label is part of the explanation
# NOTE(review): confirm whether this test is meant to apply to every edge
# or only to edges on the instance path
if edge.attr['label'] in expl :
edge.attr['color'] = 'blue'
# edge to the right child, taken when feature > threshold
children_right = dt.nodes[n1].children_right
g.add_edge(n1, children_right)
edge = g.get_edge(n1, children_right)
edge.attr['label'] = str(dt.nodes[n1].feat) + ">" + str(threshold)
edge.attr['fontsize'] = 10
edge.attr['arrowsize'] = 0.8
# dashed when the instance follows this edge
if ((n1,children_right) in edges_instance):
edge.attr['style'] = 'dashed'
# blue when the edge label is part of the explanation (same note as above)
if edge.attr['label'] in expl :
edge.attr['color'] = 'blue'
# legend subgraph, then layout and DOT output
g.add_subgraph(name='legend')
create_legend(g)
g.layout(prog='dot')
return(g.to_string())
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## tree.py (reuses parts of the code of SHAP)
##
## Created on: Dec 7, 2018
## Author: Nina Narodytska
## E-mail: narodytska@vmware.com
##
#
#==============================================================================
from anytree import Node, RenderTree,AsciiStyle
import json
import numpy as np
import math
import os
#
#==============================================================================
class xgnode(Node):
    """
        Tree node carrying the data of one decision-tree node: ids of both
        children, tested feature, optional feature name, threshold, leaf
        value and (IAI only) categorical split.
    """

    def __init__(self, id, parent = None):
        super().__init__(id, parent)
        self.id = id  # node value
        self.name = None

        # -1 / None mean "not set"
        self.left_node_id = -1  # left child
        self.right_node_id = -1  # right child
        self.feature = -1
        self.threshold = None
        self.values = -1

        # iai
        self.split = None

    def __str__(self):
        # indentation reflects the depth of the node in the tree
        pref = ' ' * self.depth

        if len(self.children) == 0:
            return pref + f"leaf:{self.id} {self.values}"

        # fall back to the feature index when the node has no name
        what = f"f{self.feature}" if self.name is None else f"\"{self.name}\""
        text = f"({self.id}) {what}"
        if self.threshold is not None:
            text += f" = {self.threshold}"
        return pref + text
#
#==============================================================================
def walk_tree(node):
    """
        Print `node`, then recurse into its first two children (if any).
    """

    print(node)

    if len(node.children) != 0:
        # internal node: left subtree first, then right subtree
        walk_tree(node.children[0])
        walk_tree(node.children[1])
#
#==============================================================================
def scores_tree(node, sample):
    """
        Return the value of the leaf reached by `sample` from `node`.

        At each internal node the sample value of the tested feature is
        compared with the threshold: strictly smaller goes to the first
        child, anything else to the second one.
    """

    current = node

    while len(current.children) != 0:
        sample_value = sample[current.feature]
        assert(sample_value is not None)

        side = 0 if sample_value < current.threshold else 1
        current = current.children[side]

    # leaf
    return current.values
#
#==============================================================================
def get_json_tree(model, tool, maxdepth=None, fname=None):
    """
        returns the dtree in JSON format

        - "DL85": the tree is read from ``model.tree_``.
        - "IAI": the model is written to ``temp/<tool><maxdepth>/<name>.json``
          (name taken from `fname`, which is required here) and read back.
        - "ITI": the tree is loaded from the file ``model.json_name``.
        - any other tool: None is returned.
    """

    jt = None

    if tool == "DL85":
        jt = model.tree_
    elif tool == "IAI":
        fname = os.path.splitext(os.path.basename(fname))[0]
        dir_name = os.path.join("temp", f"{tool}{maxdepth}")
        # create the directory if needed, without hiding unrelated errors
        os.makedirs(dir_name, exist_ok=True)

        iai_json = os.path.join(dir_name, fname+'.json')
        model.write_json(iai_json)
        print(f'load JSON tree from {iai_json} ...')
        with open(iai_json) as fp:
            jt = json.load(fp)
    elif tool == "ITI":
        print(f'load JSON tree from {model.json_name} ...')
        with open(model.json_name) as fp:
            jt = json.load(fp)

    return jt
#
#==============================================================================
class UploadedDecisionTree:
""" A decision tree.
This object provides a common interface to many different types of models.
"""
def __init__(self, model, tool, fname, maxdepth, feature_names=None, nb_classes = 0):
    """
        Wrap `model` (produced by `tool`) and build its tree.

        `fname` and `maxdepth` are forwarded to `get_json_tree` and
        `feature_names` to `build_tree`; `nb_classes` is not used here.
        The number of nodes and the depth are printed once the tree is
        built.
    """

    self.tool = tool
    self.model = model
    self.tree = self.depth = self.n_nodes = None

    json_tree = get_json_tree(self.model, self.tool, maxdepth, fname)
    built = self.build_tree(json_tree, feature_names)
    self.tree, self.n_nodes, self.depth = built

    print("c #nodes:", self.n_nodes)
    print("c depth:", self.depth)
def print_tree(self):
"""
Print a header line, then the tree through the module-level `walk_tree`.
"""
print("DT model:")
walk_tree(self.tree)
def dump(self, fvmap, filename=None, maxdepth=None, output='temp', feat_names=None):
    """
        Save the dtree and data map in .dt/.map files.

        `fvmap` maps a feature key ("f<index>") to a mapping from value id
        to an entry whose item [1] is used as a flag and item [2] as the
        value it stands for. The target directories are derived from
        `output`, `filename`, `maxdepth` and the tool name; without
        `filename` both files are written to `output` as tree.dt and
        tree.map. When `feat_names` is given, a feature map (.txt) is
        written next to the .map file.
    """

    def walk_tree(node, domains, internal, terminal):
        """
            extract internal (non-term) & terminal nodes
        """

        if len(node.children) == 0:  # leaf node
            terminal.append((node.id, node.values))
            return internal, terminal

        assert (node.children[0].id == node.left_node_id)
        assert (node.children[1].id == node.right_node_id)

        f = f"f{node.feature}"

        # Both children inherit the current domains unless a branch below
        # narrows them (only SKL does). Without this default the recursive
        # calls failed with an unbound name for DL85, ITI and IAI trees.
        dom0 = dom1 = domains

        if self.tool == "DL85":
            l, r = (1, 0)
            internal.append((node.id, f, l, node.children[0].id))
            internal.append((node.id, f, r, node.children[1].id))

        elif self.tool == "ITI":
            #l,r = (0,1)
            if len(fvmap[f]) > 2:
                n = 0
                for v in fvmap[f]:
                    if (fvmap[f][v][2] == node.threshold) and \
                            (fvmap[f][v][1] == True):
                        l = v
                        n = n + 1
                    if (fvmap[f][v][2] == node.threshold) and \
                            (fvmap[f][v][1] == False):
                        r = v
                        n = n + 1
                # exactly one "true" and one "false" entry must match
                assert (n == 2)
            elif (fvmap[f][0][2] == node.threshold):
                l, r = (0, 1)
            else:
                assert (fvmap[f][1][2] == node.threshold)
                l, r = (1, 0)

            internal.append((node.id, f, l, node.children[0].id))
            internal.append((node.id, f, r, node.children[1].id))

        elif self.tool == "IAI":
            left, right = [], []
            for p in fvmap[f]:
                if fvmap[f][p][1] == True:
                    assert (fvmap[f][p][2] in node.split)
                    if node.split[fvmap[f][p][2]]:
                        left.append(p)
                    else:
                        right.append(p)

            internal.extend([(node.id, f, l, node.children[0].id) for l in left])
            internal.extend([(node.id, f, r, node.children[1].id) for r in right])

        elif self.tool == 'SKL':
            left, right = [], []
            for j in domains[f]:
                if np.float32(fvmap[f][j][2]) <= node.threshold:
                    left.append(j)
                else:
                    right.append(j)

            internal.extend([(node.id, f, l, node.children[0].id) for l in left])
            internal.extend([(node.id, f, r, node.children[1].id) for r in right])

            # each child only keeps the values compatible with its branch
            dom0, dom1 = dict(domains), dict(domains)
            dom0[f] = left
            dom1[f] = right

        else:
            assert False, 'Unhandled model type: {0}'.format(self.tool)

        internal, terminal = walk_tree(node.children[0], dom0, internal, terminal)
        internal, terminal = walk_tree(node.children[1], dom1, internal, terminal)

        return internal, terminal

    # initial domains: value ids whose flag is set, per feature
    domains = {f: [j for j in fvmap[f] if ((fvmap[f][j][1]))] for f in fvmap}
    internal, terminal = walk_tree(self.tree, domains, [], [])

    # ---- .dt file -------------------------------------------------------
    if filename and maxdepth:
        fname = os.path.splitext(os.path.basename(filename))[0]
        dir_name = os.path.join(output, 'tree', fname)
        dir_name = os.path.join(dir_name, f"{self.tool}{maxdepth}")
        if self.tool == 'ITI':
            dir_name = os.path.join(dir_name, self.tool)
    elif filename:
        fname = os.path.splitext(os.path.basename(filename))[0]
        dir_name = os.path.join(output, f'tree/{fname}/{self.tool}')
    else:
        fname = "tree"
        dir_name = os.path.join(output)

    # create the directory if needed, without hiding unrelated errors
    os.makedirs(dir_name, exist_ok=True)

    fname = os.path.join(dir_name, fname+'.dt')
    print("saving dtree to ", fname)

    with open(fname, 'w') as fp:
        fp.write(f"{self.n_nodes}\n{self.tree.id}\n")
        # dict.fromkeys: each internal node id once, in first-seen order
        fp.write(f"I {' '.join(dict.fromkeys([str(i) for i,_,_,_ in internal]))}\n")
        fp.write(f"T {' '.join([str(i) for i,_ in terminal ])}\n")
        for i, c in terminal:
            fp.write(f"{i} T {c}\n")
        for i, f, j, n in internal:
            fp.write(f"{i} {f} {j} {n}\n")

    # ---- .map file ------------------------------------------------------
    if filename and maxdepth:
        fname = os.path.splitext(os.path.basename(filename))[0]
        dir_name = os.path.join(output, 'map', fname)
        if self.tool == "ITI":
            dir_name = os.path.join(dir_name, self.tool)
        else:
            dir_name = os.path.join(dir_name, f'{self.tool}{maxdepth}')
    elif filename:
        fname = os.path.splitext(os.path.basename(filename))[0]
        dir_name = os.path.join(output, f'map/{fname}/{self.tool}')
    else:
        fname = "tree"
        dir_name = os.path.join(output)

    os.makedirs(dir_name, exist_ok=True)

    fname = os.path.join(dir_name, fname+'.map')
    print("saving dtree map to ", fname)

    with open(fname, 'w') as fp:
        fp.write("Categorical\n")
        fp.write(f"{len(fvmap)}\n")
        for f in fvmap:
            for v in fvmap[f]:
                if (fvmap[f][v][1] == True):
                    fp.write(f"{f} {v} ={fvmap[f][v][2]}\n")
                if (fvmap[f][v][1] == False) and self.tool == "ITI":
                    fp.write(f"{f} {v} !={fvmap[f][v][2]}\n")

    # ---- optional feature map, next to the .map file --------------------
    if feat_names is not None:
        if filename:
            fname = os.path.splitext(os.path.basename(filename))[0]
            fname = os.path.join(dir_name, fname+'.txt')
        else:
            fname = os.path.join(dir_name, 'map.txt')

        print("saving feature map to ", fname)

        with open(fname, 'w') as fp:
            for i, fid in enumerate(feat_names):
                f = f'f{i}'
                fp.write(f'{fid}:{f},'+",".join([f'{fvmap[f][v][2]}:{v}' for v in fvmap[f] if(fvmap[f][v][1])])+'\n')

    print('Done')
def build_tree(self, json_tree=None, feature_names=None):
# Build the xgnode tree of the wrapped model and return the triple
# (root, node_count, maxdepth). Depending on self.tool the tree is read
# from the scikit-learn model itself ('SKL'), from an IAI learner plus its
# JSON dump ("IAI"), or from a nested JSON tree (any other tool, when
# `json_tree` is not empty). All three values stay None when no branch
# applies.
def extract_data(json_node, idx, depth=0, root=None, feature_names=None):
"""
Incremental Tree Inducer / DL8.5
"""
# Recursively turns a nested JSON node ('feat' + 'left'/'right' for a
# split, 'value' for a leaf) into xgnode objects numbered in pre-order
# from `idx`. Returns (node, last index used, depth).
if (root is None):
node = xgnode(idx)
else:
node = xgnode(idx, parent = root)
if "feat" in json_node:
if self.tool == "ITI": #f0, f1, ...,fn
# the leading character of the feature key is dropped; the result is
# still a string
node.feature = json_node["feat"][1:]
else:
node.feature = json_node["feat"] #json DL8.5
if (feature_names is not None):
# NOTE(review): for ITI node.feature is a string here -- confirm that
# `feature_names` can be indexed with it
node.name = feature_names[node.feature]
if self.tool == "ITI":
# the threshold is stored under the key named after the feature
node.threshold = json_node[json_node["feat"]]
node.left_node_id = idx + 1
_, idx, d1 = extract_data(json_node['left'], idx+1, depth+1, node, feature_names)
# the right child is numbered right after the last node of the left subtree
node.right_node_id = idx + 1
_, idx, d2 = extract_data(json_node['right'], idx+1, depth+1, node, feature_names)
depth = max(d1, d2)
elif "value" in json_node:
node.values = json_node["value"]
return node, idx, depth
def extract_iai(lnr, json_tree, feature_names = None):
"""
Interpretable AI tree
"""
# Walks the flat node list of the JSON dump; labels, depths, split
# features, children and split categories are queried from the learner
# `lnr`. Returns (root, node count read from the JSON, max leaf depth).
json_tree = json_tree['tree_']
nodes = []
depth = 0
for i, json_node in enumerate(json_tree["nodes"]):
# a parent of -2 marks the root
if json_node["parent"] == -2:
node = xgnode(json_node["id"])
else:
# ids are used as 1-based positions in `nodes` (checked just below)
root = nodes[json_node["parent"] - 1]
node = xgnode(json_node["id"], parent = root)
assert (json_node["parent"] > 0)
assert (root.id == json_node["parent"])
if json_node["split_type"] == "LEAF":
#node.values = target[json_node["fit"]["class"] - 1]
##assert json_node["fit"]["probs"][node.values] == 1.0
node.values = lnr.get_classification_label(node.id)
depth = max(depth, lnr.get_depth(node.id))
assert (json_node["lower_child"] == -2 and json_node["upper_child"] == -2)
elif json_node["split_type"] == "MIXED":
#node.feature = json_node["split_mixed"]["categoric_split"]["feature"] - 1
#node.left_node_id = json_node["lower_child"]
#node.right_node_id = json_node["upper_child"]
node.feature = lnr.get_split_feature(node.id)
node.left_node_id = lnr.get_lower_child(node.id)
node.right_node_id = lnr.get_upper_child(node.id)
node.split = lnr.get_split_categories(node.id)
assert (json_node["split_mixed"]["categoric_split"]["feature"] > 0)
assert (json_node["lower_child"] > 0)
assert (json_node["upper_child"] > 0)
else:
# any other split type is rejected
assert False, 'Split feature is not \"categoric_split\"'
nodes.append(node)
return nodes[0], json_tree["node_count"], depth
def extract_skl(tree_, classes_, feature_names=None):
"""
scikit-learn tree
"""
# Converts the array-based `tree_` structure into xgnode objects; leaf
# values are the entries of `classes_` selected by the argmax of the leaf
# value array.
def get_CART_tree(tree_):
# Iterative traversal computing, for every node id, its depth and
# whether it is a leaf.
n_nodes = tree_.node_count
children_left = tree_.children_left
children_right = tree_.children_right
#feature = tree_.feature
#threshold = tree_.threshold
#values = tree_.value
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaf = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)] # seed is the root node id and its parent depth
while len(stack) > 0:
node_id, parent_depth = stack.pop()
node_depth[node_id] = parent_depth + 1
# If we have a test node
if (children_left[node_id] != children_right[node_id]):
stack.append((children_left[node_id], parent_depth + 1))
stack.append((children_right[node_id], parent_depth + 1))
else:
is_leaf[node_id] = True
return children_left, children_right, is_leaf, node_depth
children_left, children_right, is_leaf, node_depth = get_CART_tree(tree_)
feature = tree_.feature
threshold = tree_.threshold
values = tree_.value
m = tree_.node_count
assert (m > 0), "Empty tree"
def extract_data(idx, root = None, feature_names = None):
# Recursively builds the xgnode for node id `idx` and its subtree.
i = idx
assert (i < m), "Error index node"
if (root is None):
node = xgnode(i)
else:
node = xgnode(i, parent = root)
if is_leaf[i]:
node.values = classes_[np.argmax(values[i])]
else:
node.feature = feature[i]
# names are only attached when a non-empty `feature_names` is given
if (feature_names):
node.name = feature_names[feature[i]]
node.threshold = threshold[i]
node.left_node_id = children_left[i]
node.right_node_id = children_right[i]
extract_data(node.left_node_id, node, feature_names)
extract_data(node.right_node_id, node, feature_names)
return node
root = extract_data(0, None, feature_names)
return root, tree_.node_count, tree_.max_depth
root, node_count, maxdepth = None, None, None
if(self.tool == 'SKL'):
root, node_count, maxdepth = extract_skl(self.model.tree_, self.model.classes_, feature_names)
if json_tree:
if self.tool == "IAI":
root, node_count, maxdepth = extract_iai(self.model, json_tree, feature_names)
else:
# nested JSON tree: nodes are numbered from 1
root,_,maxdepth = extract_data(json_tree, 1, 0, None, feature_names)
# the node count is estimated by counting the substrings 'feat' and
# 'value' in the serialized JSON
# NOTE(review): confirm that this count is meant for the nested-JSON
# branch only and not for IAI
node_count = json.dumps(json_tree).count('feat') + json.dumps(json_tree).count('value')
return root, node_count, maxdepth
\ No newline at end of file
...@@ -19,7 +19,6 @@ class Model(): ...@@ -19,7 +19,6 @@ class Model():
self.ml_model = '' self.ml_model = ''
self.pretrained_model = '' self.pretrained_model = ''
self.typ_data = ''
self.instance = '' self.instance = ''
...@@ -34,10 +33,9 @@ class Model(): ...@@ -34,10 +33,9 @@ class Model():
self.component_class = self.dict_components[self.ml_model] self.component_class = self.dict_components[self.ml_model]
self.component_class = globals()[self.component_class] self.component_class = globals()[self.component_class]
def update_pretrained_model(self, pretrained_model_update, typ_data): def update_pretrained_model(self, pretrained_model_update, filename_model):
self.pretrained_model = pretrained_model_update self.pretrained_model = pretrained_model_update
self.typ_data = typ_data self.component = self.component_class(self.pretrained_model, filename_model)
self.component = self.component_class(self.pretrained_model, self.typ_data)
def update_instance(self, instance, enum, xtype, solver="g3"): def update_instance(self, instance, enum, xtype, solver="g3"):
self.instance = instance self.instance = instance
......
...@@ -11,4 +11,5 @@ scipy>=1.2.1 ...@@ -11,4 +11,5 @@ scipy>=1.2.1
dash_bootstrap_components dash_bootstrap_components
dash_interactive_graphviz dash_interactive_graphviz
python-sat[pblib,aiger] python-sat[pblib,aiger]
pygraphviz pygraphviz
\ No newline at end of file anytree
\ No newline at end of file
...@@ -12,14 +12,14 @@ def parse_contents_graph(contents, filename): ...@@ -12,14 +12,14 @@ def parse_contents_graph(contents, filename):
try: try:
if '.pkl' in filename: if '.pkl' in filename:
data = pickle.load(io.BytesIO(decoded)) data = pickle.load(io.BytesIO(decoded))
typ = 'pkl'
except Exception as e: except Exception as e:
print(e) print(e)
return html.Div([ return html.Div([
'There was an error processing this file.' 'There was an error processing this file.'
]) ])
return data, typ return data
def parse_contents_instance(contents, filename): def parse_contents_instance(contents, filename):
content_type, content_string = contents.split(',') content_type, content_string = contents.split(',')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment