# -
# Caroline DE POURTALES authored
# dtree.py 13.58 KiB
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtree.py
##
## Created on: Jul 6, 2020
## Author: Alexey Ignatiev
## E-mail: alexey.ignatiev@monash.edu
##
#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver
# StringIO lives in different modules on Python 2 and Python 3; try the
# Python 2 location first and fall back to the Python 3 one.
try:  # for Python2
    from cStringIO import StringIO
except ImportError:  # for Python3
    from io import StringIO
from sklearn.tree import _tree
import numpy as np
#
#==============================================================================
class Node():
    """
    Internal (non-terminal) node of a decision tree.

    A node stores the name of the feature it tests in ``feat`` and then
    exactly one of the following attributes:

    * ``threshold`` -- set only when a threshold is passed;
    * ``vals`` -- set only when no threshold is passed; the container
      describing the outgoing branches of the node.
    """

    def __init__(self, feat='', vals=None, threshold=None):
        """
        Constructor.

        :param feat: name of the feature tested in this node.
        :param vals: branch container; a fresh empty list is created when
            omitted, so instances never share one default object.
        :param threshold: optional split threshold; when given, ``vals``
            is ignored and the ``vals`` attribute is not created.
        """

        self.feat = feat

        if threshold is not None:
            self.threshold = threshold
        else:
            # a new container per instance instead of a shared default
            self.vals = [] if vals is None else vals
#
#==============================================================================
class DecisionTree():
    """
    Simple decision tree class.

    The tree is stored as a mapping from node ids to ``Node`` objects
    (``self.nodes``) plus a mapping from terminal node ids to class labels
    (``self.terms``). Every root-to-leaf path is additionally kept in
    ``self.paths``, grouped by the class predicted at the leaf.
    """

    # Placeholder domain value used for the right child of a scikit-learn
    # split node (the left child is keyed by the truncated threshold).
    # NOTE(review): a threshold that truncates to this very value would
    # collide with the placeholder -- confirm this cannot happen in practice.
    _RIGHT_BRANCH_VAL = 4854

    def __init__(self, from_dt=None, from_pickle=None,
                 mapfile=None, verbose=0):
        """
        Constructor.

        :param from_dt: textual description of a tree, handed over to
            :meth:`from_dt`.
        :param from_pickle: fitted tree object, handed over to
            :meth:`from_pickle_file` (used only if ``from_dt`` is falsy).
        :param mapfile: if truthy, passed to ``self.parse_mapping``;
            otherwise a default ``'feature=value'`` name is generated for
            every (feature, value) pair of the tree.
        :param verbose: verbosity level.
        """

        self.verbose = verbose

        self.nof_nodes = 0
        self.nof_terms = 0
        self.nof_feats = 0
        self.root_node = None
        self.terms = {}
        self.nodes = {}
        self.paths = {}
        self.feats = []
        self.feids = {}
        self.fdoms = {}
        self.fvmap = {}

        # OHE mapping
        OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
        self.ohmap = OHEMap(dir={}, opp={})

        if from_dt:
            self.from_dt(from_dt)
        elif from_pickle:
            self.from_pickle_file(from_pickle)

        if mapfile:
            # NOTE(review): parse_mapping() is not defined in the visible
            # part of this class -- confirm it exists before passing mapfile.
            self.parse_mapping(mapfile)
        else:  # no mapping is given
            for f in self.feats:
                for v in self.fdoms[f]:
                    self.fvmap[(f, v)] = '{0}={1}'.format(f, v)

    def _finalise_tree(self):
        """
        Post-process freshly loaded nodes: merge branches, index features
        and domains, and extract all root-to-leaf paths.
        """

        # adding complex node connections into consideration: all values
        # leading to the same child are merged into a single frozenset key
        for nid in self.nodes:
            conns = collections.defaultdict(set)
            for val, child in self.nodes[nid].vals.items():
                conns[child].add(val)
            self.nodes[nid].vals = {frozenset(vs): child
                                    for child, vs in conns.items()}

        # simplifying the features and their domains
        self.feats = sorted(self.feats)
        self.feids = {f: i for i, f in enumerate(self.feats)}
        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}

        # here we assume all features are present in the tree
        self.nof_feats = len(self.feats)

        self.paths = collections.defaultdict(list)
        self.extract_paths(root=self.root_node, prefix=[])

    # NOTE: known open problems of this importer (translated from the
    # original remark): handling of feature names and of node values.
    def from_pickle_file(self, tree):
        """
        Build the tree from a fitted scikit-learn decision tree object.

        Each split node gets two branches: one keyed by the threshold
        rounded to 4 decimals and truncated to ``int`` (left child) and one
        keyed by ``_RIGHT_BRANCH_VAL`` (right child).

        :param tree: object exposing ``tree_``, ``classes_``,
            ``n_features_in_`` and, optionally, ``feature_names_in_``.
        """

        tree_ = tree.tree_

        try:
            feature_names = tree.feature_names_in_
        except AttributeError:
            # the attribute is absent when no feature names were recorded
            print("You did not dump the model with the features names")
            feature_names = [str(i) for i in range(tree.n_features_in_)]

        class_names = tree.classes_
        right_val = self._RIGHT_BRANCH_VAL

        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
        self.terms = {}
        self.nof_nodes = tree_.node_count
        self.nof_terms = 0
        self.root_node = 0

        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature]

        feats = set()
        fdoms = collections.defaultdict(set)

        def recurse(node):
            """Record the subtree rooted at ``node`` (depth-first)."""
            node = int(node)

            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                left_val = int(np.round(tree_.threshold[node], 4))
                left = int(tree_.children_left[node])
                right = int(tree_.children_right[node])

                # TODO (from the original author): loop over the values?
                split = self.nodes[node]
                split.feat = name
                split.vals[left_val] = left
                split.vals[right_val] = right

                feats.add(name)
                fdoms[name].update((left_val, right_val))

                recurse(left)
                recurse(right)
            else:
                # leaf: predict the class with the largest stored value
                self.terms[node] = class_names[np.argmax(tree_.value[node])]
                if self.verbose:
                    print("leaf {}".format(tree_.value[node]))

        recurse(self.root_node)
        self.feats, self.fdoms = feats, fdoms
        self.nof_terms = len(self.terms)

        self._finalise_tree()

    def from_dt(self, data):
        """
        Get the tree from its textual description (a string).

        Lines starting with '#' are ignored. Of the remaining lines, the
        first holds the number of nodes, the second the root node id, and
        the fourth lists the terminal nodes after a 2-character prefix.
        They are followed by one ``<node> <x> <class>`` line per terminal
        node and then by ``<node> <feature> <value> <child>`` lines.
        """

        lines = StringIO(data).readlines()

        # filtering out comment lines (those that start with '#')
        lines = [l for l in lines if not l.startswith('#')]

        # number of nodes
        self.nof_nodes = int(lines[0].strip())

        # root node
        self.root_node = int(lines[1].strip())

        # number of terminal nodes (classes)
        self.nof_terms = len(lines[3][2:].strip().split())

        # the ordered list of terminal nodes
        self.terms = {}
        for i in range(self.nof_terms):
            nd, _, t = lines[i + 4].strip().split()
            self.terms[int(nd)] = t

        # finally, reading the nodes
        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
        self.feats = set()
        self.fdoms = collections.defaultdict(set)
        for line in lines[(4 + self.nof_terms):]:
            fields = line.strip().split()
            if not fields:
                # tolerate blank lines (e.g. trailing newlines)
                continue

            # reading the tuple
            nid, fid, fval, child = fields

            # inserting it in the nodes list
            self.nodes[int(nid)].feat = fid
            self.nodes[int(nid)].vals[int(fval)] = int(child)

            # updating the list of features and the feature domains
            self.feats.add(fid)
            self.fdoms[fid].add(int(fval))

        self._finalise_tree()

    def extract_paths(self, root, prefix):
        """
        Traverse the tree and extract explicit paths.

        Every path is a list of ``(feature, values)`` pairs and is appended
        to ``self.paths`` under the class of the leaf it ends in.
        """

        if root in self.terms:
            # store the path (works for plain dicts and defaultdicts alike)
            self.paths.setdefault(self.terms[root], []).append(prefix)
        else:
            # descend into every child of the node
            node = self.nodes[root]
            for val, child in node.vals.items():
                self.extract_paths(child, prefix + [(node.feat, val)])

    def execute(self, inst, pathlits=False):
        """
        Run the tree and obtain the prediction given an input instance.

        :param inst: list of ``(feature, value)`` pairs indexed by feature
            id (``self.feids``).
        :param pathlits: if true, entries of ``inst`` for features not
            tested along the path are replaced by ``None`` *in place*.
        :returns: ``(path, term, depth)`` -- visited internal node ids, the
            predicted class and the number of internal nodes traversed.
        :raises KeyError: if no branch of a node accepts the instance value.
        """

        root = self.root_node
        depth = 0
        path = []

        # this array is needed if we focus on the path's literals only
        visited = [False] * len(inst)

        while root not in self.terms:
            path.append(root)
            node = self.nodes[root]
            feat, vals = node.feat, node.vals
            fid = self.feids[feat]
            visited[fid] = True
            tval = inst[fid][1]

            next_node, neq = None, None
            for vs, dest in vals.items():
                if tval in vs:
                    next_node = dest
                    break

                # remember a branch labelled as '!=' to fall back to it
                for v in vs:
                    if '!=' in self.fvmap[(feat, v)]:
                        neq = dest
                        break
            else:
                # no branch lists the value explicitly
                next_node = neq

            if next_node is None:
                # fail here rather than continue with a non-existent node
                raise KeyError('no branch of node {0} accepts {1}={2}'.format(
                    root, feat, tval))
            assert next_node != root

            root = next_node
            depth += 1

        if pathlits:
            # filtering out non-visited literals
            for i, seen in enumerate(visited):
                if not seen:
                    inst[i] = None

        return path, self.terms[root], depth

    def prepare_sets(self, inst, term):
        """
        Hitting set based encoding of the problem.
        (currently not incremental -- should be fixed later)

        For every path leading to a class other than ``term``, collect the
        literals of ``inst`` that disagree with the path. Returns the list
        of these sets (as sorted tuples) without duplicates.
        """

        sets = []
        for t, paths in self.paths.items():
            # ignoring the right class
            if t == term:
                continue

            # computing the sets to hit
            for path in paths:
                to_hit = set()
                for feat, vals in path:
                    fv = inst[self.feids[feat]]

                    # if the instance disagrees with the path on this item
                    if fv and fv[1] not in vals:
                        if fv[0] in self.ohmap.opp:
                            to_hit.add((self.ohmap.opp[fv[0]], None))
                        else:
                            to_hit.add(fv)

                to_hit = sorted(to_hit)
                sets.append(tuple(to_hit))

                if self.verbose:
                    if self.verbose > 1:
                        print('c trav. path: {0}'.format(path))
                    print('c set to hit: {0}'.format(to_hit))

        # returning the set of sets with no duplicates
        return list(dict.fromkeys(sets))

    def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'):
        """
        Compute a given number of explanations.

        :param inst: iterable of ``'feature=value'`` strings with integer
            values, ordered by feature id.
        :param enum: number of abductive explanations to enumerate.
        :param pathlits: restrict the instance to literals on its path.
        :param xtype: explanation kinds to report; ``"AXp"`` selects
            abductive explanations, any other entry contrastive ones.
            The default list is never modified.
        :param solver: SAT solver name passed to Hitman.
        :param htype: hitting set type passed to Hitman.
        :returns: the textual report.
        """

        parsed = []
        for lit in inst:
            parts = lit.split('=')
            parsed.append((parts[0], int(parts[1])))
        inst = parsed
        inst_orig = inst[:]

        path, term, depth = self.execute(inst, pathlits)

        tested = [self.fvmap[inst_orig[self.feids[self.nodes[n].feat]]]
                  for n in path]
        report = [
            str(inst) + "\n \n",
            'c instance: IF {0} THEN class={1}'.format(' AND '.join(tested), term) + "\n",
            'c path len:' + str(depth) + "\n \n \n",
        ]

        if self.ohmap.dir:
            # use the untouched copy: with pathlits, inst may hold None
            f2v = {fv[0]: fv[1] for fv in inst_orig}

            # updating fvmap for printing ohe features
            for fo, fis in self.ohmap.dir.items():
                self.fvmap[(fo, None)] = '(' + ' AND '.join(
                    [self.fvmap[(fi, f2v[fi])] for fi in fis]) + ')'

        # computing the sets to hit
        to_hit = self.prepare_sets(inst, term)

        for kind in xtype:
            if kind == "AXp":
                report.append("Abductive explanation : " + "\n \n")
                report.append(self.enumerate_abductive(to_hit, enum, solver, htype, term) + "\n \n")
            else:
                report.append("Contrastive explanation : " + "\n \n")
                report.append(self.enumerate_contrastive(to_hit, term) + "\n \n")

        return ''.join(report)

    @staticmethod
    def _stats_lines(expls):
        """
        Format the count / min / max / average size summary of ``expls``;
        all figures are reported as zero when there are no explanations.
        """

        sizes = [len(e) for e in expls]
        if sizes:
            smallest, largest = min(sizes), max(sizes)
            # float() keeps true division on Python 2 as well
            average = float(sum(sizes)) / len(sizes)
        else:
            smallest, largest, average = 0, 0, 0.0

        return ''.join([
            'c nof expls:' + str(len(sizes)) + "\n",
            'c min expl:' + str(smallest) + "\n",
            'c max expl:' + str(largest) + "\n",
            'c avg expl: {0:.2f}'.format(average) + "\n",
        ])

    def enumerate_abductive(self, to_hit, enum, solver, htype, term):
        """
        Enumerate abductive explanations.

        Up to ``enum`` hitting sets of ``to_hit`` are requested from
        Hitman; each is reported as a rule, followed by size statistics.
        """

        rules = []
        expls = []
        with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman:
            for i, expl in enumerate(hitman.enumerate(), 1):
                lits = [self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]
                rules.append('c expl: IF {0} THEN class={1}'.format(' AND '.join(lits), term) + "\n")
                expls.append(expl)

                if i == enum:
                    break

        return ''.join(rules) + self._stats_lines(expls) + " \n \n"

    def enumerate_contrastive(self, to_hit, term):
        """
        Enumerate contrastive explanations.

        The explanations are the subset-minimal sets among ``to_hit``;
        each is reported as a rule, followed by size statistics.
        """

        # visiting smaller sets first guarantees that any subset of a
        # candidate has already been kept when the candidate is examined
        expls = []
        for cand in sorted((set(s) for s in to_hit), key=len):
            if not any(kept <= cand for kept in expls):
                expls.append(cand)

        rules = []
        for expl in expls:
            lits = ['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]
            rules.append('c expl: IF {0} THEN class!={1}'.format(' OR '.join(lits), term) + "\n")

        return ''.join(rules) + self._stats_lines(expls)