# Skip to content
# Snippets Groups Projects
# dtree.py 13.58 KiB
#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## dtree.py
##
##  Created on: Jul 6, 2020
##      Author: Alexey Ignatiev
##      E-mail: alexey.ignatiev@monash.edu
##

#
#==============================================================================
from __future__ import print_function
import collections
from functools import reduce
from pysat.card import *
from pysat.examples.hitman import Hitman
from pysat.formula import CNF, IDPool
from pysat.solvers import Solver

try:  # for Python2
    from cStringIO import StringIO
except ImportError:  # for Python3
    from io import StringIO
from sklearn.tree import _tree
import numpy as np

#
#==============================================================================
class Node():
    """
        Node class.

        A node stores the name of the feature it tests in ``feat`` and
        either a ``threshold`` (when one is given) or the container
        ``vals``; exactly one of the two attributes is set.
    """

    def __init__(self, feat='', vals=None, threshold=None):
        """
            Constructor.

            feat      -- name of the feature tested in this node
            vals      -- container of the node's values; a fresh empty
                         list is created when omitted
            threshold -- optional threshold; when it is not None only
                         ``threshold`` is stored and ``vals`` is ignored
        """

        self.feat = feat
        if threshold is not None:
            self.threshold = threshold
        else:
            # a literal [] default in the signature would be one single
            # list shared (and mutated) by all nodes created without vals
            self.vals = [] if vals is None else vals


#
#==============================================================================
class DecisionTree():
    """
        Simple decision tree class.
    """

    def __init__(self, from_dt=None, from_pickle=None,
            mapfile=None, verbose=0):
        """
            Constructor.
        """

        self.verbose = verbose

        self.nof_nodes = 0
        self.nof_terms = 0
        self.root_node = None
        self.terms = []
        self.nodes = {}
        self.paths = {}
        self.feats = []
        self.feids = {}
        self.fdoms = {}
        self.fvmap = {}

        # OHE mapping
        OHEMap = collections.namedtuple('OHEMap', ['dir', 'opp'])
        self.ohmap = OHEMap(dir={}, opp={})

        if from_dt:
            self.from_dt(from_dt)     
        elif from_pickle:
            self.from_pickle_file(from_pickle)

        if mapfile:
            self.parse_mapping(mapfile)
        else:  # no mapping is given
            for f in self.feats:
                for v in self.fdoms[f]:
                    self.fvmap[tuple([f, v])] = '{0}={1}'.format(f, v)

    # known issues: handling of the feature names, and of the vals stored in a node
    def from_pickle_file(self, tree):
        """
            Build the tree from an already loaded model object
            (presumably a fitted scikit-learn decision tree -- the code
            reads its ``tree_``, ``classes_``, ``n_features_in_`` and,
            when present, ``feature_names_in_`` attributes).

            Every internal node gets two outgoing values: the node's
            threshold (rounded, then converted to int) leading to the left
            child, and a fixed placeholder value leading to the right
            child. Leaves are mapped to the class with the largest entry
            in ``tree_.value``.

            Fills in nodes, terms, feats, feids, fdoms, the counters and
            finally the explicit paths.
        """

        # placeholder domain value attached to the right branch
        # NOTE(review): arbitrary constant; a threshold that rounds to the
        # same number would be overwritten by it -- to be replaced
        right_value = 4854

        tree_ = tree.tree_
        try:
            feature_names = tree.feature_names_in_
        except AttributeError:
            # no names stored in the model: fall back to positional names
            print("You did not dump the model with the features names")
            feature_names = [str(i) for i in range(tree.n_features_in_)]

        class_names = tree.classes_

        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
        self.terms = {}
        self.nof_nodes = tree_.node_count
        self.nof_terms = 0
        self.root_node = 0

        # feature name tested in each node (leaves have no feature)
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature]

        def recurse(feats, fdoms, node):
            nid = int(node)

            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]

                # NOTE(review): int() drops the fractional part of the
                # threshold -- confirm that thresholds are integral
                left_value = int(np.round(tree_.threshold[node], 4))
                left = tree_.children_left[node]
                right = tree_.children_right[node]

                # TODO: loop over the values instead?
                entry = self.nodes[nid]
                entry.feat = name
                entry.vals[left_value] = int(left)
                entry.vals[right_value] = int(right)

                feats.add(name)
                fdoms[name].add(left_value)
                fdoms[name].add(right_value)

                recurse(feats, fdoms, left)
                recurse(feats, fdoms, right)
            else:
                self.terms[nid] = class_names[np.argmax(tree_.value[node])]
                if self.verbose:
                    print("leaf {}".format(tree_.value[node]))

            return feats, fdoms

        self.feats, self.fdoms = recurse(set([]), collections.defaultdict(set), self.root_node)

        # grouping the values that lead to the same child
        for parent in self.nodes:
            conns = collections.defaultdict(set)
            for val, child in self.nodes[parent].vals.items():
                conns[child].add(val)
            self.nodes[parent].vals = {frozenset(vs): child for child, vs in conns.items()}

        # simplifying the features and their domains
        self.feats = sorted(self.feats)
        self.feids = {f: i for i, f in enumerate(self.feats)}
        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}
        self.nof_terms = len(self.terms)
        self.nof_feats = len(self.feats)

        self.paths = collections.defaultdict(list)
        self.extract_paths(root=self.root_node, prefix=[])

    def from_dt(self, data):
        """
            Get the tree from a file pointer.
        """

        contents = StringIO(data)

        lines = contents.readlines()

        # filtering out comment lines (those that start with '#')
        lines = list(filter(lambda l: not l.startswith('#'), lines))

        # number of nodes
        self.nof_nodes = int(lines[0].strip())

        # root node
        self.root_node = int(lines[1].strip())

        # number of terminal nodes (classes)
        self.nof_terms = len(lines[3][2:].strip().split())

        # the ordered list of terminal nodes
        self.terms = {}
        for i in range(self.nof_terms):
            nd, _, t = lines[i + 4].strip().split()
            self.terms[int(nd)] = t #int(t)

        # finally, reading the nodes
        self.nodes = collections.defaultdict(lambda: Node(feat='', vals={}))
        self.feats = set([])
        self.fdoms = collections.defaultdict(lambda: set([]))
        for line in lines[(4 + self.nof_terms):]:
            # reading the tuple
            nid, fid, fval, child = line.strip().split()

            # inserting it in the nodes list
            self.nodes[int(nid)].feat = fid
            self.nodes[int(nid)].vals[int(fval)] = int(child)

            # updating the list of features
            self.feats.add(fid)

            # updaing feature domains
            self.fdoms[fid].add(int(fval))

        # adding complex node connections into consideration
        for n1 in self.nodes:
            conns = collections.defaultdict(lambda: set([]))
            for v, n2 in self.nodes[n1].vals.items():
                conns[n2].add(v)
            self.nodes[n1].vals = {frozenset(v): n2 for n2, v in conns.items()}

        # simplifying the features and their domains
        self.feats = sorted(self.feats)
        self.feids = {f: i for i, f in enumerate(self.feats)}
        self.fdoms = {f: sorted(self.fdoms[f]) for f in self.fdoms}

        # here we assume all features are present in the tree
        # if not, this value will be rewritten by self.parse_mapping()
        self.nof_feats = len(self.feats)

        self.paths = collections.defaultdict(lambda: [])
        self.extract_paths(root=self.root_node, prefix=[])

    def extract_paths(self, root, prefix):
        """
            Traverse the tree and extract explicit paths.
        """

        if root in self.terms:
            # store the path
            term = self.terms[root]
            self.paths[term].append(prefix)
        else:
            # select next node
            feat, vals = self.nodes[root].feat, self.nodes[root].vals
            for val in vals:
                self.extract_paths(vals[val], prefix + [tuple([feat, val])])

    def execute(self, inst, pathlits=False):
        """
            Run the tree and obtain the prediction given an input instance.
        """

        root = self.root_node
        depth = 0
        path = []

        # this array is needed if we focus on the path's literals only
        visited = [False for f in inst]

        while not root in self.terms:
            path.append(root)
            feat, vals = self.nodes[root].feat, self.nodes[root].vals
            visited[self.feids[feat]] = True
            tval = inst[self.feids[feat]][1]
            ###############
            # assert(len(vals) == 2)
            next_node = root
            neq = None
            for vs, dest in vals.items():
                if tval in vs:
                    next_node = dest
                    break
                else:
                    for v in vs:
                        if '!=' in self.fvmap[(feat, v)]:
                            neq = dest
                            break
            else:
                next_node = neq
            # if tval not in vals:
            #     # go to the False branch (!=)
            #     for i in vals:
            #         if "!=" in self.fvmap[(feat,i)]:
            #             next_node = vals[i]
            #             break
            # else:
            #     next_node = vals[tval]

            assert (next_node != root)
            ###############
            root = next_node
            depth += 1

        if pathlits:
            # filtering out non-visited literals
            for i, v in enumerate(visited):
                if not v:
                    inst[i] = None

        return path, self.terms[root], depth

    def prepare_sets(self, inst, term):
        """
            Hitting set based encoding of the problem.
            (currently not incremental -- should be fixed later)
        """

        sets = []
        for t, paths in self.paths.items():
            # ignoring the right class
            if t == term:
                continue

            # computing the sets to hit
            for path in paths:
                to_hit = []
                for item in path:
                    # if the instance disagrees with the path on this item
                    if inst[self.feids[item[0]]] and not inst[self.feids[item[0]]][1] in item[1]:
                        fv = inst[self.feids[item[0]]]
                        if fv[0] in self.ohmap.opp:
                            to_hit.append(tuple([self.ohmap.opp[fv[0]], None]))
                        else:
                            to_hit.append(fv)

                to_hit = sorted(set(to_hit))
                sets.append(tuple(to_hit))

                if self.verbose:
                    if self.verbose > 1:
                        print('c trav. path: {0}'.format(path))

                    print('c set to hit: {0}'.format(to_hit))

        # returning the set of sets with no duplicates
        return list(dict.fromkeys(sets))

    def explain(self, inst, enum=1, pathlits=False, xtype = ["AXp"], solver='g3', htype='sorted'):
        """
            Compute a given number of explanations.
        """

        inst = list(map(lambda i: tuple([i[0], int(i[1])]), [i.split('=') for i in inst]))
        inst_orig = inst[:]
        path, term, depth = self.execute(inst, pathlits)

        explanation = str(inst) + "\n \n"
        #print('c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[p] for p in inst_orig]), term))
        #print(term)
        explanation += 'c instance: IF {0} THEN class={1}'.format(' AND '.join([self.fvmap[ inst_orig[self.feids[self.nodes[n].feat]] ] for n in path]), term) + "\n"
        explanation +='c path len:'+ str(depth)+ "\n \n \n"

        if self.ohmap.dir:
            f2v = {fv[0]: fv[1] for fv in inst}

            # updating fvmap for printing ohe features
            for fo, fis in self.ohmap.dir.items():
                self.fvmap[tuple([fo, None])] = '(' + ' AND '.join([self.fvmap[tuple([fi, f2v[fi]])] for fi in fis]) + ')'

        # computing the sets to hit
        to_hit = self.prepare_sets(inst, term)

        for type in xtype :
            if type == "AXp":
                explanation += "Abductive explanation : " + "\n \n"
                explanation += self.enumerate_abductive(to_hit, enum, solver, htype, term)+ "\n \n"
            else :
                explanation += "Contrastive explanation : "+ "\n \n"
                explanation += self.enumerate_contrastive(to_hit, term)+ "\n \n"

        return explanation 

    def enumerate_abductive(self, to_hit, enum, solver, htype, term):
        """
            Enumerate abductive explanations.

            A Hitman object is bootstrapped with ``to_hit`` and its
            hitting sets are reported one per line until ``enum`` of them
            have been produced (a non-positive ``enum`` never matches the
            counter, i.e. all of them are reported).

            Returns the report as a string: one 'c expl' line per
            explanation followed by their number and the minimum, maximum
            and average explanation size. If nothing is enumerated the
            statistics are reported as zeros instead of failing.
        """

        report = []
        expls = []
        with Hitman(bootstrap_with=to_hit, solver=solver, htype=htype) as hitman:
            for i, expl in enumerate(hitman.enumerate(), 1):
                lits = [self.fvmap[p] for p in sorted(expl, key=lambda p: p[0])]
                report.append('c expl: IF {0} THEN class={1}'.format(' AND '.join(lits), term) + "\n")

                expls.append(expl)
                if i == enum:
                    break

        # the loop counter is unbound when nothing was enumerated, so the
        # statistics are derived from the collected explanations instead
        sizes = [len(e) for e in expls]
        if sizes:
            smallest, largest = min(sizes), max(sizes)
            average = sum(sizes) / len(sizes)
        else:
            smallest, largest, average = 0, 0, 0.0

        report.append('c nof expls:' + str(len(expls)) + "\n")
        report.append('c min expl:' + str(smallest) + "\n")
        report.append('c max expl:' + str(largest) + "\n")
        report.append('c avg expl: {0:.2f}'.format(average) + "\n \n \n")

        return ''.join(report)

    def enumerate_contrastive(self, to_hit, term):
        """
            Enumerate contrastive explanations.
        """

        def process_set(done, target):
            for s in done:
                if s <= target:
                    break
            else:
                done.append(target)
            return done

        to_hit = [set(s) for s in to_hit]
        to_hit.sort(key=lambda s: len(s))
        expls = list(reduce(process_set, to_hit, []))
        explanation = ""
        for expl in expls:
            explanation += 'c expl: IF {0} THEN class!={1}'.format(' OR '.join(['!{0}'.format(self.fvmap[p]) for p in sorted(expl, key=lambda p: p[0])]), term)+ "\n"


        explanation +='c nof expls:'+ str(len(expls))+ "\n"
        explanation +='c min expl:'+ str( min([len(e) for e in expls]))+ "\n"
        explanation +='c max expl:'+ str( max([len(e) for e in expls]))+ "\n"
        explanation +='c avg expl: {0:.2f}'.format(sum([len(e) for e in expls]) / len(expls))+ "\n"

        return explanation