#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtree.py
##
## Created on: Jul 6, 2020
## Author: Alexey Ignatiev
## E-mail: alexey.ignatiev@monash.edu
##
#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver
try: # for Python2
from cStringIO import StringIO
except ImportError: # for Python3
from io import StringIO
from sklearn.tree import _tree
import numpy as np
class Node():
"""
Node class.
"""
def __init__(self, feat='', vals=None):
"""
Constructor.
"""
self.feat = feat
self.vals = vals if vals is not None else {}
#
#==============================================================================
class DecisionTree():
"""
Simple decision tree class.
"""
def __init__(self, from_dt=None, mapfile=None, verbose=0):
"""
Constructor.
"""
self.verbose = verbose
self.nof_nodes = 0
self.nof_terms = 0
self.root_node = None
self.terms = []
self.nodes = {}
self.paths = {}
self.feats = []
self.feids = {}
self.fdoms = {}
self.fvmap = {}
# OHE mapping
OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
self.ohmap = OHEMap(dir={}, opp={})
if from_dt:
self.from_dt(from_dt)
if mapfile:
self.parse_mapping(mapfile)
else: # no mapping is given
for f in self.feats:
for v in self.fdoms[f]:
self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)
def from_dt(self, data):
"""
Get the tree from a file pointer.
"""
contents = StringIO(data)
lines = contents.readlines()
# filtering out comment lines (those that start with '#')
lines = list(filter(lambda l: not l.startswith('#'), lines))
# number of nodes
self.nof_nodes = int(lines[0].strip())
# root node
self.root_node = int(lines[1].strip())
# number of terminal nodes (classes)
self.nof_terms = len(lines[3][2:].strip().split())
# the ordered list of terminal nodes
self.terms = {}
for i in range(self.nof_terms):
nd, _, t = lines[i + 4].strip().split()
self.terms[int(nd)] = t #int(t)
# finally, reading the nodes
self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
self.feats = set([])
self.fdoms = collections.defaultdict(lambda: set([]))
for line in lines[(4 + self.nof_terms):]:
# reading the tuple
nid, fid, fval, child = line.strip().split()
# inserting it in the nodes list
self.nodes[int(nid)].feat = fid
self.nodes[int(nid)].vals[int(fval)] = int(child)
# updating the list of features
self.feats.add(fid)
# updating feature domains
self.fdoms[fid].add(int(fval))
# grouping edges that lead to the same child node into multi-valued edges
for n1 in self.nodes:
conns = collections.defaultdict(lambda: set([]))
for v, n2 in self.nodes[n1].vals.items():
conns[n2].add(v)
self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()}
# simplifying the features and their domains
self.feats = sorted(self.feats)
#self.feids = {f: i for i, f in enumerate(self.feats)}
self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
# here we assume all features are present in the tree
# if not, this value will be rewritten by self.parse_mapping()
self.nof_feats = len(self.feats)
self.paths = collections.defaultdict(lambda: [])
self.extract_paths(root=self.root_node, prefix=[])
def parse_mapping(self, mapfile):
"""
Parse feature-value mapping from a file.
"""
self.fvmap = {}
lines = mapfile.split('\n')
if lines[0].startswith('OHE'):
for i in range(int(lines[1])):
feats = lines[i + 2].strip().split(',')
orig, ohe = feats[0], tuple(feats[1:])
self.ohmap.dir[orig] = tuple(ohe)
for f in ohe:
self.ohmap.opp[f] = orig
lines = lines[(int(lines[1]) + 2):]
elif lines[0].startswith('Categorical'):
# skipping the first comment line if necessary
lines = lines[1:]
elif lines[0].startswith('Ordinal'):
# skipping the first comment line if necessary
lines = lines[1:]
# number of features
self.nof_feats = int(lines[0].strip())
self.feids = {}
for line in lines[1:]:
feat, val, real = line.split()
self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(feat, real)
#if feat not in self.feids:
# self.feids[feat] = len(self.feids)
#assert len(self.feids) == self.nof_feats
def convert_to_multiedges(self):
"""
Convert ITI trees with '!=' edges to multi-edges.
"""
# new feature domains
fdoms = collections.defaultdict(lambda: [])
# tentative mapping relating negative and positive values
nemap = collections.defaultdict(lambda: collections.defaultdict(lambda: [None, None]))
for fv, tval in self.fvmap.items():
if '!=' in tval:
nemap[fv[0]][tval.split('=')[1]][0] = fv[1]
else:
fdoms[fv[0]].append(fv[1])
nemap[fv[0]][tval.split('=')[1]][1] = fv[1]
# a mapping from negative values to sets
fnmap = collections.defaultdict(lambda: {})
for f in nemap:
for t, vals in nemap[f].items():
if vals[0] is not None:
fnmap[(f, frozenset({vals[0]}))] = frozenset(set(fdoms[f]).difference({vals[1]}))
# updating node connections
for n in self.nodes:
vals = {}
for v in self.nodes[n].vals.keys():
fn = (self.nodes[n].feat, v)
if fn in fnmap:
vals[fnmap[fn]] = self.nodes[n].vals[v]
else:
vals[v] = self.nodes[n].vals[v]
self.nodes[n].vals = vals
# updating the domains
self.fdoms = fdoms
# extracting the paths again
self.paths = collections.defaultdict(lambda: [])
self.extract_paths(root=self.root_node, prefix=[])
def extract_paths(self, root, prefix):
"""
Traverse the tree and extract explicit paths.
"""
if root in self.terms:
# store the path
term = self.terms[root]
self.paths[term].append(prefix)
else:
# select next node
feat, vals = self.nodes[root].feat, self.nodes[root].vals
for val in vals:
self.extract_paths(vals[val], prefix + [tuple([feat, val])])
def execute(self, inst, pathlits=False):
"""
Run the tree and obtain the prediction given an input instance.
"""
root = self.root_node
depth = 0
path = []
# this array is needed if we focus on the path's literals only
visited = [False for f in inst]
while root not in self.terms:
path.append(root)
feat, vals = self.nodes[root].feat, self.nodes[root].vals
visited[self.feids[feat]] = True
tval = inst[self.feids[feat]][1]
###############
# assert(len(vals) == 2)
next_node = root
neq = None
for vs, dest in vals.items():
if tval in vs:
next_node = dest
break
else:
for v in vs:
if '!=' in self.fvmap[(feat, v)]:
neq = dest
break
else:
next_node = neq
# if tval not in vals:
# # go to the False branch (!=)
# for i in vals:
# if "!=" in self.fvmap[(feat,i)]:
# next_node = vals[i]
# break
# else:
# next_node = vals[tval]
assert (next_node != root)
###############
root = next_node
depth += 1
if pathlits:
# filtering out non-visited literals
for i, v in enumerate(visited):
if not v:
inst[i] = None
return path, self.terms[root], depth
def prepare_sets(self, inst, term):
"""
Hitting set based encoding of the problem.
(currently not incremental -- should be fixed later)
"""
sets = []
for t, paths in self.paths.items():
# skipping the paths that predict the same class as the instance
if t == term:
continue
# computing the sets to hit
for path in paths:
to_hit = []
for item in path:
fv = inst[self.feids[item[0]]]
# if the instance disagrees with the path on this item
if fv and not fv[1] in item[1]:
if fv[0] in self.ohmap.opp:
to_hit.append(tuple([self.ohmap.opp[fv[0]], None]))
else:
to_hit.append(fv)
to_hit = sorted(set(to_hit))
sets.append(tuple(to_hit))
if self.verbose:
if self.verbose > 1:
print('c trav. path: {0}'.format(path))
print('c set to hit: {0}'.format(to_hit))
# returning the set of sets with no duplicates
return list(dict.fromkeys(sets))
def explain(self, inst, enum=1, pathlits=False, xtype=["AXp"], solver='g3', htype='sorted'):
"""
Compute a given number of explanations.
"""
self.feids = {f[0]: i for i, f in enumerate(inst)}
path, term, depth = self.execute(inst, pathlits)
# contains all the elements of the explanation
explanation_dic = {}
# instance representation
explanation_dic["Instance : "] = str([self.fvmap[p] for p in inst])
#decision path
decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term)
explanation_dic["Decision path of instance : "] = decision_path_str
explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth)
if self.ohmap.dir:
f2v = {fv[0]: fv[1] for fv in inst}
# updating fvmap for printing ohe features
for fo, fis in self.ohmap.dir.items():
self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')'
# computing the sets to hit
to_hit = self.prepare_sets(inst, term)
for xt in xtype:
if xt == "AXp":
explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term))
else:
explanation_dic.update(self.enumerate_contrastive(to_hit, term))
return explanation_dic
def enumerate_abductive(self, to_hit, enum, solver, htype, term):
"""
Enumerate abductive explanations.
"""
list_expls = []
list_expls_str = []
explanation = {}
with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman:
expls = []
for i, expl in enumerate(hitman.enumerate(), 1):
list_expls.append([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])])
list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]), term))
expls.append(expl)
if i == enum:
break
explanation["List of path explanation(s)"] = list_expls
explanation["List of abductive explanation(s)"] = list_expls_str
explanation["Number of abductive explanation(s) : "] = str(i)
explanation["Minimal abductive explanation : "] = str( min([len(e) for e in expls]))
explanation["Maximal abductive explanation : "] = str( max([len(e) for e in expls]))
explanation["Average abductive explanation : "] = '{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))
return explanation
def enumerate_contrastive(self, to_hit, term):
"""
Enumerate contrastive explanations.
"""
def process_set(done, target):
for s in done:
if s <= target:
break
else:
done.append(target)
return done
to_hit = [set(s) for s in to_hit]
to_hit.sort(key=lambda s: len(s))
expls = list(reduce(process_set, to_hit, []))
list_expls_str = []
explanation = {}
for expl in expls:
list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))
explanation["List of contrastive explanation(s)"] = list_expls_str
explanation["Number of contrastive explanation(s) : "]=str(len(expls))
explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls]))
explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls]))
explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))
return explanation
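#
#==============================================================================
# A minimal usage sketch (hypothetical: the file name, feature names and
# values below are placeholders; the tree description is assumed to follow
# the format parsed by DecisionTree.from_dt()):
#
#   with open('tree.dt', 'r') as fp:
#       tree = DecisionTree(from_dt=fp.read(), verbose=1)
#
#   # an instance is an ordered list of (feature, value) pairs
#   inst = [('f1', 0), ('f2', 1), ('f3', 0)]
#
#   expl = tree.explain(inst, enum=1, xtype=['AXp', 'CXp'])
#   for key, value in expl.items():
#       print(key, value)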