Skip to content
Snippets Groups Projects
dtree.py 13.95 KiB
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtree.py
##
##  Created on: Jul 6, 2020
##      Author: Alexey Ignatiev
##      E-mail: alexey.ignatiev@monash.edu
##

#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver

try:  # for Python2
    from cStringIO import StringIO
except ImportError:  # for Python3
    from io import StringIO
from sklearn.tree import _tree
import numpy as np

class Node():
    """
        Node class.

        Plain record with two attributes: ``feat`` (the feature attached to
        the node) and ``vals`` (the container of the node's values).
    """

    def __init__(self, feat='', vals=None):
        """
            Constructor.

            If ``vals`` is omitted, every node gets its own fresh empty list.
            A mutable default in the signature would be a single object
            shared by all nodes created without an explicit value.
        """

        self.feat = feat
        self.vals = [] if vals is None else vals

#
#==============================================================================
class DecisionTree():
    """
        Simple decision tree class.
    """

    def __init__(self, from_dt=None, mapfile=None, verbose=0):
        """
            Constructor.
        """

        self.verbose = verbose

        self.nof_nodes = 0
        self.nof_terms = 0
        self.root_node = None
        self.terms = []
        self.nodes = {}
        self.paths = {}
        self.feats = []
        self.feids = {}
        self.fdoms = {}
        self.fvmap = {}

        # OHE mapping
        OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
        self.ohmap = OHEMap(dir={}, opp={})

        if from_dt:
            self.from_dt(from_dt)     

        if mapfile:
            self.parse_mapping(mapfile)
        else:  # no mapping is given
            for f in self.feats:
                for v in self.fdoms[f]:
                    self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)

    def from_dt(self, data):
        """
            Build the tree from its textual description.

            ``data`` is the description itself (a string), not a file
            pointer. Lines starting with '#' are dropped; the remaining
            lines are read as follows:

              line 0      -- number of nodes
              line 1      -- id of the root node
              line 2      -- not read
              line 3      -- after its first two characters, one
                             whitespace-separated token per terminal node
              next lines  -- one '<node> <ignored> <class>' triple per
                             terminal node
              the rest    -- '<node> <feature> <value> <child>' tuples

            Blank lines among the tuples are skipped. On return the tree
            structure, feature domains and the explicit paths are set.
        """

        lines = StringIO(data).readlines()

        # filtering out comment lines (those that start with '#')
        lines = [l for l in lines if not l.startswith('#')]

        # number of nodes
        self.nof_nodes = int(lines[0].strip())

        # root node
        self.root_node = int(lines[1].strip())

        # number of terminal nodes (classes)
        self.nof_terms = len(lines[3][2:].strip().split())

        # terminal nodes: node id -> class label (kept as a string)
        self.terms = {}
        for i in range(self.nof_terms):
            nd, _, t = lines[i + 4].strip().split()
            self.terms[int(nd)] = t

        # finally, reading the nodes
        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
        self.feats = set()
        self.fdoms = collections.defaultdict(set)
        for line in lines[(4 + self.nof_terms):]:
            fields = line.split()
            if not fields:
                # a blank (e.g. trailing) line carries no tuple
                continue

            # reading the tuple
            nid, fid, fval, child = fields

            # inserting it in the nodes list
            self.nodes[int(nid)].feat = fid
            self.nodes[int(nid)].vals[int(fval)] = int(child)

            # updating the list of features
            self.feats.add(fid)

            # updating feature domains
            self.fdoms[fid].add(int(fval))

        # adding complex node connections into consideration: all values
        # leading to the same child are merged into one frozenset-keyed edge
        for n1 in self.nodes:
            conns = collections.defaultdict(set)
            for v, n2 in self.nodes[n1].vals.items():
                conns[n2].add(v)
            self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()}

        # simplifying the features and their domains
        self.feats = sorted(self.feats)
        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}

        # here we assume all features are present in the tree
        # if not, this value will be rewritten by self.parse_mapping()
        self.nof_feats = len(self.feats)

        self.paths = collections.defaultdict(list)
        self.extract_paths(root=self.root_node, prefix=[])

    def parse_mapping(self, mapfile):
        """
            Parse feature-value mapping from a file.
        """

        self.fvmap = {}

        lines = mapfile.split('\n')

        if lines[0].startswith('OHE'):
            for i in range(int(lines[1])):
                feats = lines[i + 2].strip().split(',')
                orig, ohe = feats[0], tuple(feats[1:])
                self.ohmap.dir[orig] = tuple(ohe)
                for f in ohe:
                    self.ohmap.opp[f] = orig

            lines = lines[(int(lines[1]) + 2):]

        elif lines[0].startswith('Categorical'):
            # skipping the first comment line if necessary
            lines = lines[1:]

        elif lines[0].startswith('Ordinal'):
            # skipping the first comment line if necessary
            lines = lines[1:]

        # number of features
        self.nof_feats = int(lines[0].strip())
        self.feids = {}

        for line in lines[1:]:
            feat, val, real = line.split()
            self.fvmap[tuple([feat, int(val)])] = '{0}{1}'.format(feat, real)
            #if feat not in self.feids:
            #    self.feids[feat] = len(self.feids)

        #assert len(self.feids) == self.nof_feats

    def convert_to_multiedges(self):
        """
            Convert ITI trees with '!=' edges to multi-edges.
        """

        # new feature domains
        fdoms = collections.defaultdict(lambda: [])

        # tentative mapping relating negative and positive values
        nemap = collections.defaultdict(lambda: collections.defaultdict(lambda: [None, None]))

        for fv, tval in self.fvmap.items():
            if '!=' in tval:
                nemap[fv[0]][tval.split('=')[1]][0] = fv[1]
            else:
                fdoms[fv[0]].append(fv[1])
                nemap[fv[0]][tval.split('=')[1]][1] = fv[1]

        # a mapping from negative values to sets
        fnmap = collections.defaultdict(lambda: {})
        for f in nemap:
            for t, vals in nemap[f].items():
                if vals[0] != None:
                    fnmap[(f, frozenset({vals[0]}))] = frozenset(set(fdoms[f]).difference({vals[1]}))

        # updating node connections
        for n in self.nodes:
            vals = {}
            for v in self.nodes[n].vals.keys():
                fn = (self.nodes[n].feat, v)
                if fn in fnmap:
                    vals[fnmap[fn]] = self.nodes[n].vals[v]
                else:
                    vals[v] = self.nodes[n].vals[v]
            self.nodes[n].vals = vals

        # updating the domains
        self.fdoms = fdoms

        # extracting the paths again
        self.paths = collections.defaultdict(lambda: [])
        self.extract_paths(root=self.root_node, prefix=[])

    def extract_paths(self, root, prefix):
        """
            Traverse the tree and extract explicit paths.
        """

        if root in self.terms:
            # store the path
            term = self.terms[root]
            self.paths[term].append(prefix)
        else:
            # select next node
            feat, vals = self.nodes[root].feat, self.nodes[root].vals
            for val in vals:
                self.extract_paths(vals[val], prefix + [tuple([feat, val])])

    def execute(self, inst, pathlits=False):
        """
            Run the tree and obtain the prediction given an input instance.
        """

        root = self.root_node
        depth = 0
        path = []

        # this array is needed if we focus on the path's literals only
        visited = [False for f in inst]

        while not root in self.terms:
            path.append(root)
            feat, vals = self.nodes[root].feat, self.nodes[root].vals
            visited[self.feids[feat]] = True
            tval = inst[self.feids[feat]][1]
            ###############
            # assert(len(vals) == 2)
            next_node = root
            neq = None
            for vs, dest in vals.items():
                if tval in vs:
                    next_node = dest
                    break
                else:
                    for v in vs:
                        if '!=' in self.fvmap[(feat, v)]:
                            neq = dest
                            break
            else:
                next_node = neq
            # if tval not in vals:
            #     # go to the False branch (!=)
            #     for i in vals:
            #         if "!=" in self.fvmap[(feat,i)]:
            #             next_node = vals[i]
            #             break
            # else:
            #     next_node = vals[tval]

            assert (next_node != root)
            ###############
            root = next_node
            depth += 1

        if pathlits:
            # filtering out non-visited literals
            for i, v in enumerate(visited):
                if not v:
                    inst[i] = None

        return path, self.terms[root], depth

    def prepare_sets(self, inst, term):
        """
            Hitting set based encoding of the problem.
            (currently not incremental -- should be fixed later)
        """

        sets = []
        for t, paths in self.paths.items():
            # ignoring the right class
            if t == term:
                continue

            # computing the sets to hit
            for path in paths:
                to_hit = []
                for item in path:
                    fv = inst[self.feids[item[0]]]
                    # if the instance disagrees with the path on this item
                    if fv and not fv[1] in item[1]:
                        if fv[0] in self.ohmap.opp:
                            to_hit.append(tuple([self.ohmap.opp[fv[0]], None]))
                        else:
                            to_hit.append(fv)

                to_hit = sorted(set(to_hit))
                sets.append(tuple(to_hit))

                if self.verbose:
                    if self.verbose > 1:
                        print('c trav. path: {0}'.format(path))

                    print('c set to hit: {0}'.format(to_hit))

        # returning the set of sets with no duplicates
        return list(dict.fromkeys(sets))

    def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'):
        """
            Compute a given number of explanations.
        """
        
        self.feids = {f[0]: i for i, f in enumerate(inst)}

        path, term, depth = self.execute(inst, pathlits)
    
        #contaiins all the elements for explanation
        explanation_dic = {}
        #instance plotting
        explanation_dic["Instance : "] = str([self.fvmap[inst[i]] for i in range (len(inst))])

        #decision path
        decision_path_str = 'IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[inst[self.feids[self.nodes[n].feat]]] for n in path]), term)
        explanation_dic["Decision path of instance : "] = decision_path_str
        explanation_dic["Decision path length : "] = 'Path length is :'+ str(depth)


        if self.ohmap.dir:
            f2v = {fv[0]: fv[1] for fv in inst}

            # updating fvmap for printing ohe features
            for fo, fis in self.ohmap.dir.items():
                self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')'

        # computing the sets to hit
        to_hit = self.prepare_sets(inst, term)
        for type in xtype :
            if type == "AXp":
                explanation_dic.update(self.enumerate_abductive(to_hit, enum, solver, htype, term))
            else :
                explanation_dic.update(self.enumerate_contrastive(to_hit, term))

        return explanation_dic

    def enumerate_abductive(self, to_hit, enum, solver, htype, term):
        """
            Enumerate abductive explanations.

            The hitting sets of ``to_hit`` produced by the Hitman enumerator
            are reported as explanations for class ``term``; enumeration
            stops after ``enum`` of them (it never stops early if ``enum``
            is 0). Literals are rendered through self.fvmap and ordered by
            feature.

            Returns a dictionary with the explanations (as lists of literals
            and as rules) and their number, minimal, maximal and average
            size. All statistics are reported as zero if nothing is
            enumerated.
        """

        list_expls = []
        list_expls_str = []
        sizes = []

        # NOTE(review): ``solver`` is accepted but not used -- the oracle is
        # pinned to 'm22'; confirm whether the argument should be honoured
        with Hitman(bootstrap_with=to_hit, solver='m22', htype=htype) as hitman:
            for i, expl in enumerate(hitman.enumerate(), 1):
                # rendering the literals once for both representations
                lits = [self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]
                list_expls.append(lits)
                list_expls_str.append('Explanation: IF {0} THEN class={1}'.format(' AND '.join(lits), term))

                sizes.append(len(expl))
                if i == enum:
                    break

        # float() keeps the average exact under Python 2 as well
        average = float(sum(sizes)) / len(sizes) if sizes else 0.0

        explanation = {}
        explanation["List of path explanation(s)"] = list_expls
        explanation["List of abductive explanation(s)"] = list_expls_str
        explanation["Number of abductive explanation(s) : "] = str(len(sizes))
        explanation["Minimal abductive explanation : "] = str(min(sizes) if sizes else 0)
        explanation["Maximal abductive explanation : "] = str(max(sizes) if sizes else 0)
        explanation["Average abductive explanation : "] = '{0:.2f}'.format(average)

        return explanation

    def enumerate_contrastive(self, to_hit, term):
        """
            Enumerate contrastive explanations.
        """

        def process_set(done, target):
            for s in done:
                if s <= target:
                    break
            else:
                done.append(target)
            return done

        to_hit = [set(s) for s in to_hit]
        to_hit.sort(key=lambda s: len(s))
        expls = list(reduce(process_set, to_hit, []))
        list_expls_str = []
        explanation = {}
        for expl in expls:
            list_expls_str.append('Contrastive: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term))

        explanation["List of contrastive explanation(s)"] = list_expls_str
        explanation["Number of contrastive explanation(s) : "]=str(len(expls))
        explanation["Minimal contrastive explanation : "]= str( min([len(e) for e in expls]))
        explanation["Maximal contrastive explanation : "]= str( max([len(e) for e in expls]))
        explanation["Average contrastive explanation : "]='{0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))

        return explanation