diff --git a/pages/application/RandomForest/RandomForestComponent.py b/pages/application/RandomForest/RandomForestComponent.py index 25f27d5e666a38a711f32ce18a3ff8ec5d83ae3e..c02c75ec6520c8bd14e4236ff7c5ceda620dcf43 100644 --- a/pages/application/RandomForest/RandomForestComponent.py +++ b/pages/application/RandomForest/RandomForestComponent.py @@ -36,8 +36,8 @@ class RandomForestComponent: options["n_classes"] = self.data.num_class options["feature_names"] = self.data.feature_names options["n_features"] = self.data.nb_features - if isinstance(model, RandomForestClassifier) or isinstance(model, VotingClassifier) or isinstance(model, - xrf.rndmforest.RF2001): + if isinstance(model, RandomForestClassifier) or isinstance(model, VotingClassifier)\ + or isinstance(model, xrf.rndmforest.RF2001): self.random_forest = XRF(model, self.data) elif isinstance(model, XGBRFClassifier): self.random_forest = XGBRandomForest(options, from_model=model) diff --git a/pages/application/RandomForest/utils/xgbooster/explain.py b/pages/application/RandomForest/utils/xgbooster/explain.py index a4e5f8b1e107d51263a616ce5c84588807bdce61..d522b38fe71c362a8cabd51390a8d00610c826f2 100644 --- a/pages/application/RandomForest/utils/xgbooster/explain.py +++ b/pages/application/RandomForest/utils/xgbooster/explain.py @@ -158,9 +158,7 @@ class SMTExplainer(object): else: self.preamble.append(v) - explanation_dic = {} - explanation_dic["explaning instance"] = ' explaining: "IF {0} THEN {1}"'.format(' AND '.join(self.preamble), self.output) - return explanation_dic + return "IF {0} THEN {1}".format(' AND '.join(self.preamble), self.output) def explain(self, sample, smallest, expl_ext=None, prefer_ext=False): """ @@ -168,7 +166,8 @@ class SMTExplainer(object): """ # adapt the solver to deal with the current sample - explanation_dic = self.prepare(sample) + explanation_dic = {} + explanation_dic["Instance :"] = self.prepare(sample) # saving external explanation to be minimized further if expl_ext == None or prefer_ext: @@ -186,11 +185,10 @@ class SMTExplainer(object): else: self.compute_smallest() - expl = sorted([self.sel2fid[h] for h in self.rhypos]) - explanation_dic["explanation brute "] = expl - self.preamble = [self.preamble[i] for i in expl] - explanation_dic["explanation"] = ' explanation: "IF {0} THEN {1}"'.format(' AND '.join(self.preamble), self.xgb.target_name[self.out_id]) - explanation_dic["Hyphothesis left"] = ' # hypos left:' + str(len(self.rhypos)) + explanation = sorted([self.sel2fid[h] for h in self.rhypos]) + self.preamble = [self.preamble[i] for i in explanation] + explanation_dic["explanation: "] = "IF {0} THEN {1}".format(' AND '.join(self.preamble), self.xgb.target_name[self.out_id]) + explanation_dic["Hyphothesis left"] = str(len(self.rhypos)) return explanation_dic diff --git a/pages/application/RandomForest/utils/xgbooster/xgbooster.py b/pages/application/RandomForest/utils/xgbooster/xgbooster.py index b7e62b449357fccc267d08442bd6e2fcc7388f3e..37e86c4f1f7f779a62b8e9ac3ad3a5e03121b575 100644 --- a/pages/application/RandomForest/utils/xgbooster/xgbooster.py +++ b/pages/application/RandomForest/utils/xgbooster/xgbooster.py @@ -56,7 +56,7 @@ class XGBooster(object): if test_on: encoder.test_sample(np.array(test_on)) - def explain(self, sample, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, + def explain_sample(self, sample, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, expl_ext=None, prefer_ext=False, nof_feats=5): """ Explain a prediction made for a given sample with a previously @@ -81,6 +81,15 @@ class XGBooster(object): return expl + def explain(self, samples, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, + expl_ext=None, prefer_ext=False, nof_feats=5): + explanations = [] + for sample in samples : + explanations.append(self.explain_sample(sample, smallest, solver, use_lime, use_anchor, use_shap, + expl_ext, prefer_ext, nof_feats)) + + return explanations + def validate(self, sample, expl): """ Make an attempt to show that a given explanation is optimistic. diff --git a/pages/application/RandomForest/utils/xgbrf/explain.py b/pages/application/RandomForest/utils/xgbrf/explain.py index 7487f0eb683ad30b5c19546a5eb5d8db9009b28e..dde8e46b6f312c42b47fd74f2f290eb827609c4a 100644 --- a/pages/application/RandomForest/utils/xgbrf/explain.py +++ b/pages/application/RandomForest/utils/xgbrf/explain.py @@ -150,28 +150,25 @@ class SMTExplainer(object): disj.append(GT(self.outs[i], self.outs[self.out_id])) self.oracle.add_assertion(Implies(self.selv, Or(disj))) - if self.verbose: - inpvals = self.xgb.readable_sample(sample) + inpvals = self.xgb.readable_sample(sample) - self.preamble = [] - for f, v in zip(self.xgb.feature_names, inpvals): - if f not in v: - self.preamble.append('{0} = {1}'.format(f, v)) - else: - self.preamble.append(v) + self.preamble = [] + for f, v in zip(self.xgb.feature_names, inpvals): + if f not in v: + self.preamble.append('{0} = {1}'.format(f, v)) + else: + self.preamble.append(v) - print('\n explaining: "IF {0} THEN {1}"'.format(' AND '.join(self.preamble), self.output)) + return "IF {0} THEN {1}".format(' AND '.join(self.preamble), self.output) def explain(self, sample, smallest, expl_ext=None, prefer_ext=False): """ Hypotheses minimization. """ - - self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \ - resource.getrusage(resource.RUSAGE_SELF).ru_utime + explanation_dic = {} # adapt the solver to deal with the current sample - self.prepare(sample) + explanation_dic["Instance : "] = self.prepare(sample) # saving external explanation to be minimized further if expl_ext == None or prefer_ext: @@ -182,28 +179,19 @@ class SMTExplainer(object): # if satisfiable, then the observation is not implied by the hypotheses if self.oracle.solve([self.selv] + [h for h, c in zip(self.rhypos, self.to_consider) if c]): - print(' no implication!') - print(self.oracle.get_model()) - sys.exit(1) - - if not smallest: - self.compute_minimal(prefer_ext=prefer_ext) + explanation_dic["no implication"] = self.oracle.get_model() else: - self.compute_smallest() - - self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \ - resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time - - expl = sorted([self.sel2fid[h] for h in self.rhypos]) - #print('expl >>>> : ', expl) + if not smallest: + self.compute_minimal(prefer_ext=prefer_ext) + else: + self.compute_smallest() - if self.verbose: - self.preamble = [self.preamble[i] for i in expl] - print(' explanation: "IF {0} THEN {1}"'.format(' AND '.join(self.preamble), self.xgb.target_name[self.out_id])) - print(' # hypos left:', len(self.rhypos)) - print(' time: {0:.2f}'.format(self.time)) + explanation = sorted([self.sel2fid[h] for h in self.rhypos]) + self.preamble = [self.preamble[i] for i in explanation] + explanation_dic['explanation: '] = "IF {0} THEN {1}".format(' AND '.join(self.preamble), self.xgb.target_name[self.out_id]) + explanation_dic['# hypos left:'] = str(len(self.rhypos)) - return expl + return explanation_dic def compute_minimal(self, prefer_ext=False): """ diff --git a/pages/application/RandomForest/utils/xgbrf/xgb_rf.py b/pages/application/RandomForest/utils/xgbrf/xgb_rf.py index 840def741bb12263a980354ad1f6449ff18a9a97..9c921742d744a8a0e819598cd129ff94ed0e1c01 100644 --- a/pages/application/RandomForest/utils/xgbrf/xgb_rf.py +++ b/pages/application/RandomForest/utils/xgbrf/xgb_rf.py @@ -70,7 +70,7 @@ class XGBRandomForest(object): if test_on: encoder.test_sample(np.array(test_on)) - def explain(self, sample, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, + def explain_sample(self, sample, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, expl_ext=None, prefer_ext=False, nof_feats=5): """ Explain a prediction made for a given sample with a previously @@ -93,6 +93,15 @@ class XGBRandomForest(object): return expl + def explain(self, samples, smallest, solver, use_lime=None, use_anchor=None, use_shap=None, + expl_ext=None, prefer_ext=False, nof_feats=5): + explanations = [] + for sample in samples: + explanations.append(self.explain_sample(sample, smallest, solver, use_lime, use_anchor, use_shap, + expl_ext, prefer_ext, nof_feats)) + + return explanations + def validate(self, sample, expl): """ Make an attempt to show that a given explanation is optimistic. diff --git a/pages/application/RandomForest/utils/xrf/checker.py b/pages/application/RandomForest/utils/xrf/checker.py index 5fb8650613bc4fe4e8d9d033e71476729a051164..314e52e99061467c78f69974d80d4522b0e7ec48 100644 --- a/pages/application/RandomForest/utils/xrf/checker.py +++ b/pages/application/RandomForest/utils/xrf/checker.py @@ -1,5 +1,5 @@ # -#============================================================================== +# ============================================================================== import numpy as np import math @@ -8,10 +8,12 @@ import six from pysat.formula import CNF, IDPool from pysat.solvers import Solver from pysat.card import CardEnc, EncType -#from itertools import combinations + + +# from itertools import combinations # -#============================================================================== +# ============================================================================== def predict_tree(node, sample): if (len(node.children) == 0): # leaf @@ -21,31 +23,30 @@ def predict_tree(node, sample): sample_value = sample[feat] if sample_value is None: return predict_tree(node.children[0], sample) - elif(sample_value < node.threshold): + elif (sample_value < node.threshold): return predict_tree(node.children[0], sample) else: return predict_tree(node.children[1], sample) - - + + # -#============================================================================== +# ============================================================================== class Checker: - + def __init__(self, forest, num_class, feature_names): self.forest = forest self.num_class = num_class self.feature_names = feature_names self.cnf = None self.vpool = IDPool() - self.intvs = None self.intvs = {'{0}'.format(f): set([]) for f in feature_names if '_' not in f} for tree in self.forest.trees: self.traverse_intervals(tree) - self.intvs = {f: sorted(self.intvs[f]) + - ([math.inf] if len(self.intvs[f]) else []) - for f in six.iterkeys(self.intvs)} + self.intvs = {f: sorted(self.intvs[f]) + + ([math.inf] if len(self.intvs[f]) else []) + for f in six.iterkeys(self.intvs)} self.imaps, self.ivars = {}, {} self.thvars = {} for feat, intvs in six.iteritems(self.intvs): @@ -59,8 +60,7 @@ class Checker: if ub != math.inf: thvar = self.newVar('{0}_th{1}'.format(feat, i)) self.thvars[feat].append(thvar) - - + self.cnf = CNF() #### cvars = [self.newVar('class{0}'.format(i)) for i in range(num_class)] @@ -68,47 +68,46 @@ class Checker: ctvars = [[] for t in range(num_tree)] for k in range(num_tree): for j in range(self.num_class): - var = self.newVar('class{0}_tr{1}'.format(j,k)) - ctvars[k].append(var) - ##### - for k, tree in enumerate(self.forest.trees): + var = self.newVar('class{0}_tr{1}'.format(j, k)) + ctvars[k].append(var) + ##### + for k, tree in enumerate(self.forest.trees): self.traverse(tree, k, []) - card = CardEnc.atmost(lits=ctvars[k], vpool=self.vpool,encoding=EncType.cardnetwrk) - self.cnf.extend(card.clauses) - ###### + card = CardEnc.atmost(lits=ctvars[k], vpool=self.vpool, encoding=EncType.cardnetwrk) + self.cnf.extend(card.clauses) + ###### for f, intvs in six.iteritems(self.ivars): if not len(intvs): continue - self.cnf.append(intvs) + self.cnf.append(intvs) card = CardEnc.atmost(lits=intvs, vpool=self.vpool, encoding=EncType.cardnetwrk) - self.cnf.extend(card.clauses) + self.cnf.extend(card.clauses) for f, threshold in six.iteritems(self.thvars): for j, thvar in enumerate(threshold): - d = j+1 - pos, neg = self.ivars[f][d:], self.ivars[f][:d] + d = j + 1 + pos, neg = self.ivars[f][d:], self.ivars[f][:d] if j == 0: self.cnf.append([thvar, neg[-1]]) self.cnf.append([-thvar, -neg[-1]]) else: - self.cnf.append([thvar, neg[-1], -threshold[j-1]]) - self.cnf.append([-thvar, threshold[j-1]]) + self.cnf.append([thvar, neg[-1], -threshold[j - 1]]) + self.cnf.append([-thvar, threshold[j - 1]]) self.cnf.append([-thvar, -neg[-1]]) - + if j == len(threshold) - 1: self.cnf.append([-thvar, pos[0]]) self.cnf.append([thvar, -pos[0]]) else: - self.cnf.append([-thvar, pos[0], threshold[j+1]]) + self.cnf.append([-thvar, pos[0], threshold[j + 1]]) self.cnf.append([thvar, -pos[0]]) - self.cnf.append([thvar, -threshold[j+1]]) - - - def newVar(self, name): - if name in self.vpool.obj2id: #var has been already created + self.cnf.append([thvar, -threshold[j + 1]]) + + def newVar(self, name): + if name in self.vpool.obj2id: # var has been already created return self.vpool.obj2id[name] var = self.vpool.id('{0}'.format(name)) - return var - + return var + def traverse(self, tree, k, clause): if tree.children: f = tree.name @@ -121,13 +120,11 @@ class Checker: var = self.newVar(tree.name) pos, neg = var, -var self.traverse(tree.children[0], k, clause + [-neg]) - self.traverse(tree.children[1], k, clause + [-pos]) + self.traverse(tree.children[1], k, clause + [-pos]) else: # leaf node - cvar = self.newVar('class{0}_tr{1}'.format(tree.values,k)) + cvar = self.newVar('class{0}_tr{1}'.format(tree.values, k)) self.cnf.append(clause + [cvar]) - #self.printLits(clause + [cvar]) - - + # self.printLits(clause + [cvar]) def traverse_intervals(self, tree): if tree.children: @@ -136,198 +133,195 @@ class Checker: if f in self.intvs: self.intvs[f].add(v) self.traverse_intervals(tree.children[0]) - self.traverse_intervals(tree.children[1]) - + self.traverse_intervals(tree.children[1]) def check(self, sample, expl): print("check PI-expl") slv = Solver(name="glucose3") slv.append_formula(self.cnf) - - pred = self.forest.predict_inst(sample) + + pred = self.forest.predict_inst(sample) num_tree = len(self.forest.trees) ##### cvars = [self.newVar('class{0}'.format(i)) for i in range(self.num_class)] ctvars = [[] for t in range(num_tree)] for k in range(num_tree): for j in range(self.num_class): - var = self.newVar('class{0}_tr{1}'.format(j,k)) - ctvars[k].append(var) - # + var = self.newVar('class{0}_tr{1}'.format(j, k)) + ctvars[k].append(var) + # rhs = num_tree - 1 for j in range(pred): - lhs = [ctvars[k][j] for k in range(num_tree)] + [ - ctvars[k][pred] for k in range(num_tree)] - atms = CardEnc.atmost(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) - #add maj class selector to activate/deactivate eq atmsk - #self.cnf.extend([cl + [-cvars[pred]] for cl in atms]) - slv.append_formula([cl + [-cvars[pred]] for cl in atms]) - rhs = num_tree + lhs = [ctvars[k][j] for k in range(num_tree)] + [- ctvars[k][pred] for k in range(num_tree)] + atms = CardEnc.atmost(lits=lhs, bound=rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) + # add maj class selector to activate/deactivate eq atmsk + # self.cnf.extend([cl + [-cvars[pred]] for cl in atms]) + slv.append_formula([cl + [-cvars[pred]] for cl in atms]) + rhs = num_tree for j in range(pred + 1, self.num_class): - lhs = [ctvars[k][j] for k in range(num_tree)] + [ - ctvars[k][pred] for k in range(num_tree)] - atms = CardEnc.atmost(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) - #self.cnf.extend([cl + [-cvars[pred]] for cl in atms]) - slv.append_formula([cl + [-cvars[pred]] for cl in atms]) - ######## + lhs = [ctvars[k][j] for k in range(num_tree)] + [- ctvars[k][pred] for k in range(num_tree)] + atms = CardEnc.atmost(lits=lhs, bound=rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) + # self.cnf.extend([cl + [-cvars[pred]] for cl in atms]) + slv.append_formula([cl + [-cvars[pred]] for cl in atms]) + ######## ######## rhs = num_tree for j in range(pred): - lhs = [ - ctvars[k][j] for k in range(num_tree)] + [ctvars[k][pred] for k in range(num_tree)] - atms = CardEnc.atmost(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) - #self.cnf.extend([cl+[-cvars[j]] for cl in atms]) - slv.append_formula([cl+[-cvars[j]] for cl in atms]) - rhs = num_tree - 1 - for j in range(pred + 1, self.num_class): - lhs = [ - ctvars[k][j] for k in range(num_tree)] + [ctvars[k][pred] for k in range(num_tree)] - atms = CardEnc.atmost(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) - #self.cnf.extend([cl+[-cvars[j]] for cl in atms]) - slv.append_formula([cl+[-cvars[j]] for cl in atms]) + lhs = [- ctvars[k][j] for k in range(num_tree)] + [ctvars[k][pred] for k in range(num_tree)] + atms = CardEnc.atmost(lits=lhs, bound=rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) + # self.cnf.extend([cl+[-cvars[j]] for cl in atms]) + slv.append_formula([cl + [-cvars[j]] for cl in atms]) + rhs = num_tree - 1 + for j in range(pred + 1, self.num_class): + lhs = [- ctvars[k][j] for k in range(num_tree)] + [ctvars[k][pred] for k in range(num_tree)] + atms = CardEnc.atmost(lits=lhs, bound=rhs, vpool=self.vpool, encoding=EncType.cardnetwrk) + # self.cnf.extend([cl+[-cvars[j]] for cl in atms]) + slv.append_formula([cl + [-cvars[j]] for cl in atms]) ############ - #self.cnf.append(cvars) - card = CardEnc.atmost(lits=cvars, vpool=self.vpool, encoding=EncType.cardnetwrk) #AtMostOne constraint - #self.cnf.extend(card.clauses) + # self.cnf.append(cvars) + card = CardEnc.atmost(lits=cvars, vpool=self.vpool, encoding=EncType.cardnetwrk) # AtMostOne constraint + # self.cnf.extend(card.clauses) slv.add_clause(cvars) - slv.append_formula(card.clauses) - + slv.append_formula(card.clauses) + assums = [] # var selectors to be used as assumptions - #sel2fid = {} # selectors to original feature ids - #sel2vid = {} # selectors to categorical feature ids - #sel2v = {} # selectors to (categorical/interval) values + # sel2fid = {} # selectors to original feature ids + # sel2vid = {} # selectors to categorical feature ids + # sel2v = {} # selectors to (categorical/interval) values sel_expl = [] - - #inps = ['f{0}'.format(f) for f in range(len(sample))] # works only with pure continuous feats + + # inps = ['f{0}'.format(f) for f in range(len(sample))] # works only with pure continuous feats inps = self.feature_names - + for i, (inp, val) in enumerate(zip(inps, sample)): - if len(self.intvs[inp]): - v = next((intv for intv in self.intvs[inp] if intv > val), None) - assert(v is not None) - selv = self.newVar('selv_{0}'.format(inp)) + if len(self.intvs[inp]): + v = next((intv for intv in self.intvs[inp] if intv > val), None) + assert (v is not None) + selv = self.newVar('selv_{0}'.format(inp)) assums.append(selv) ## if i in expl: sel_expl.append(selv) - #print('{0}={1}'.format('selv_{0}'.format(inp), val)) + # print('{0}={1}'.format('selv_{0}'.format(inp), val)) ## - for j,p in enumerate(self.ivars[inp]): + for j, p in enumerate(self.ivars[inp]): cl = [-selv] if j == self.imaps[inp][v]: cl += [p] - #self.sel2v[selv] = p + # self.sel2v[selv] = p else: - cl += [-p] - #self.cnf.append(cl) + cl += [-p] + # self.cnf.append(cl) slv.add_clause(cl) - assums = sorted(set(assums)) - #print(sel_expl, assums) + assums = sorted(set(assums)) + # print(sel_expl, assums) sel_pred = cvars[pred] - - #slv = Solver(name="glucose3") - #slv.append_formula(self.cnf) - - - assert (slv.solve(assumptions=sel_expl+[sel_pred])), '{0} is not an explanation.'.format(expl) + + # slv = Solver(name="glucose3") + # slv.append_formula(self.cnf) + + assert (slv.solve(assumptions=sel_expl + [sel_pred])), '{0} is not an explanation.'.format(expl) print('expl:{0} is valid'.format(expl)) - + for i, p in enumerate(sel_expl): - #print(i,p) + # print(i,p) to_test = sel_expl[:i] + sel_expl[(i + 1):] + [-sel_pred] print(to_test) assert slv.solve(assumptions=to_test), '{0} is not minimal explanation.'.format(expl) - + # delete sat solver slv.delete() slv = None - + print('expl:{0} is minimal'.format(expl)) print() - + def check_expl(sample, expl, forest, intvs): - print("check PI-expl") - + pred = forest.predict_inst(sample) - - sample_expl = [None]*len(sample) + + sample_expl = [None] * len(sample) for p in expl: sample_expl[p] = sample[p] - + # initializing the intervals - #intvs = {'f{0}'.format(f): set([]) for f in range(len(sample))} - #for tree in forest.trees: + # intvs = {'f{0}'.format(f): set([]) for f in range(len(sample))} + # for tree in forest.trees: # traverse_intervals(tree) - + # first, check if expl is an explanation scores = [predict_tree(dt, sample_expl) for dt in forest.trees] scores = np.asarray(scores) maj = np.argmax(np.bincount(scores)) - + assert maj == pred, '{0} is not an explanation.'.format(expl) - + print('expl:{0} is valid'.format(expl)) print("pred = ", pred) - + sample_expl = sample - + feats = ['f{0}'.format(f) for f in expl] univ = [(i, f) for i, f in enumerate(intvs) if (len(intvs[f]) and (f not in feats))] - + # Now, check if expl is a minimal for p, f in zip(expl, feats): print("{0}={1}".format(f, sample_expl[p])) - print([-math.inf]+intvs[f]) - assert(len(intvs[f])) - + print([-math.inf] + intvs[f]) + assert (len(intvs[f])) + # calculate possible values for f possible_val = [] d = next((i for i, v in enumerate(intvs[f]) if v > sample_expl[p]), None) - assert(d is not None) - print("d=",d) - + assert (d is not None) + print("d=", d) + if d: - #possible_val.append(intvs[f][0] - 1) + # possible_val.append(intvs[f][0] - 1) possible_val.append(-math.inf) - print(intvs[f][:d-1]) - for i, v in enumerate(intvs[f][:d-1]): + print(intvs[f][:d - 1]) + for i, v in enumerate(intvs[f][:d - 1]): possible_val.append((v + intvs[f][i + 1]) * 0.5) - - for i, v in enumerate(intvs[f][d+1:]): - #print('{0} + {1}'.format(v , intvs[f][d+i])) - possible_val.append((v + intvs[f][d+i]) * 0.5) - #if v == math.inf: + + for i, v in enumerate(intvs[f][d + 1:]): + # print('{0} + {1}'.format(v , intvs[f][d+i])) + possible_val.append((v + intvs[f][d + i]) * 0.5) + # if v == math.inf: # assert(v == intvs[f][-1]) # possible_val.append(v + 1) - #else: + # else: # possible_val.append((v + intvs[f][i - 1]) * 0.5) - - - ## print("{0} => {1} | {2} , {3}".format(f,sample_expl[p], [-math.inf]+intvs[f], possible_val)) + + ## print("{0} => {1} | {2} , {3}".format(f,sample_expl[p], [-math.inf]+intvs[f], possible_val)) for v in possible_val: sample_expl[p] = v for uf in univ: - for x in ([-math.inf]+intvs[uf[1]]): + for x in ([-math.inf] + intvs[uf[1]]): print('{0}={1}'.format(uf[1], x)) - sample_expl[uf[0]] = x + sample_expl[uf[0]] = x scores = [predict_tree(dt, sample_expl) for dt in forest.trees] scores = np.asarray(scores) maj = np.argmax(np.bincount(scores)) - #print("maj: {0} | {1}={2}".format( maj, f, v)) + # print("maj: {0} | {1}={2}".format( maj, f, v)) if maj != pred: - break + break sample_expl[uf[0]] = sample[p] - - print("maj: {0} | {1}={2}".format( maj, f, v)) - - else: + + print("maj: {0} | {1}={2}".format(maj, f, v)) + + else: assert False, '{0} is not minimal explanation.'.format(expl) - + sample_expl[p] = sample[p] - + print('expl:{0} is minimal'.format(expl)) print() - - return True - + + return True + + ''' def traverse_intervals(tree, intvs): if tree.children: @@ -342,5 +336,4 @@ def traverse_intervals(tree, intvs): else: return intvs -''' - +'''