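# xforest.py
# NOTE: the first lines of this file were repository-viewer navigation text
# ("Skip to content", "Snippets Groups Projects", "32.45 KiB"), not code.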

#from sklearn.ensemble._voting import VotingClassifier
#from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
#from sklearn.metrics import accuracy_score
import numpy as np
import sys
import os
import resource

import collections
from itertools import combinations
from six.moves import range
import six
import math

from data import Data
from .rndmforest import RF2001, VotingRF
from .tree import Forest, predict_tree

#from .encode import SATEncoder
from pysat.formula import CNF, WCNF, IDPool
from pysat.solvers import Solver
from pysat.card import CardEnc, EncType
from pysat.examples.lbx import LBX
from pysat.examples.mcsls import MCSls
from pysat.examples.rc2 import RC2


    

#
#==============================================================================
class Dataset(Data):
    """
        Class for representing dataset (transactions).
    """
    def __init__(self, filename=None, fpointer=None, mapfile=None,
            separator=' ', use_categorical = False):
        """
            Load a dataset and split it into features and labels.

            All arguments are forwarded unchanged to the parent Data
            constructor. The last column of the loaded samples is used as
            the class label, the preceding columns as features.

            Attributes set here: feature_names, nb_features, X, y,
            num_class, target_name, binarizer and (through
            mapping_features()) the extended feature maps.
        """
        super().__init__(filename, fpointer, mapfile, separator, use_categorical)

        # split data into X and y
        # (self.names / self.samps are presumably populated by Data - TODO confirm)
        self.feature_names = self.names[:-1]
        self.nb_features = len(self.feature_names)
        self.use_categorical = use_categorical

        samples = np.asarray(self.samps)
        if not all(c.isnumeric() for c in samples[:, -1]):
            # non-numeric labels: replace them by integer codes and keep
            # the original label names
            le = LabelEncoder()
            le.fit(samples[:, -1])
            samples[:, -1] = le.transform(samples[:, -1])
            self.class_names = le.classes_
            print(le.classes_)
            print(samples[1:4, :])

        samples = np.asarray(samples, dtype=np.float32)
        self.X = samples[:, 0: self.nb_features]
        self.y = samples[:, self.nb_features]
        self.num_class = len(set(self.y))
        self.target_name = list(range(self.num_class))

        print("c nof features: {0}".format(self.nb_features))
        print("c nof classes: {0}".format(self.num_class))
        print("c nof samples: {0}".format(len(self.samps)))

        # check if we have info about categorical features
        if self.use_categorical:
            # class_names is only assigned above when the labels had to be
            # label-encoded; keep the numeric target names otherwise
            # instead of failing with AttributeError
            self.target_name = getattr(self, 'class_names', self.target_name)

            self.binarizer = {}
            for i in self.categorical_features:
                try:
                    # keyword used by scikit-learn >= 1.2
                    encoder = OneHotEncoder(categories='auto', sparse_output=False)
                except TypeError:
                    # older scikit-learn releases only know 'sparse'
                    encoder = OneHotEncoder(categories='auto', sparse=False)
                self.binarizer[i] = encoder
                encoder.fit(self.X[:, [i]])
        else:
            self.categorical_features = []
            self.categorical_names = []
            self.binarizer = []
        #feat map
        self.mapping_features()
        
        
            
    def train_test_split(self, test_size=0.2, seed=0):
        """
            Split (self.X, self.y) with sklearn's train_test_split, using
            'seed' as the random state, and return its result unchanged.
        """
        parts = train_test_split(self.X, self.y,
                                 random_state=seed, test_size=test_size)
        return parts
           

    def transform(self, x):
        """
            One-hot encode the categorical columns of x.

            A 1-D input is first promoted to a single-row 2-D array. Empty
            input, or a dataset without categorical information, is
            returned as is (after the promotion, if any).
        """
        if len(x) == 0:
            return x
        if len(x.shape) == 1:
            x = np.expand_dims(x, axis=0)
        if not self.use_categorical:
            return x

        assert self.binarizer != []
        columns = []
        for idx in range(self.nb_features):
            column = x[:, [idx]]
            if idx in self.categorical_features:
                encoder = self.binarizer[idx]
                encoder.drop = None
                column = np.vstack(encoder.transform(column))
            columns.append(column)
        return np.hstack(columns)

    def transform_inverse(self, x):
        """
            Undo transform(): map one-hot encoded rows back to one value
            per original feature.

            With categorical information the result is a list holding one
            1-D array per input row; otherwise x is returned (promoted to
            2-D when it was 1-D).
        """
        if len(x) == 0:
            return x
        if len(x.shape) == 1:
            x = np.expand_dims(x, axis=0)
        if not self.use_categorical:
            return x

        assert self.binarizer != []
        decoded = []
        for row in x:
            original = np.zeros(self.nb_features)
            offset = 0  # position of the next unread encoded column
            for f in range(self.nb_features):
                if f in self.categorical_features:
                    width = len(self.categorical_names[f])
                    onehot = np.expand_dims(row[offset:offset + width], axis=0)
                    original[f] = self.binarizer[f].inverse_transform(onehot)
                    offset += width
                else:
                    original[f] = row[offset]
                    offset += 1
            decoded.append(original)
        return decoded

    def transform_inverse_by_index(self, idx):
        """
            Return the (feature name, value index) pair stored for the
            extended feature index idx, or None (with a warning printed)
            when the index is unknown.
        """
        if idx not in self.extended_feature_names:
            print("Warning there is no feature {} in the internal mapping".format(idx))
            return None
        return self.extended_feature_names[idx]

    def transform_by_value(self, feat_value_pair):
        """
            Reverse lookup in the extended feature map: return the first
            index mapped to feat_value_pair, or None (with a warning
            printed) when no entry matches.
        """
        mapping = self.extended_feature_names
        if feat_value_pair not in mapping.values():
            print("Warning there is no value {} in the internal mapping".format(feat_value_pair))
            return None
        indices = list(mapping.keys())
        pairs = list(mapping.values())
        return indices[pairs.index(feat_value_pair)]

    def mapping_features(self):
        """
            Build the maps describing the (possibly one-hot extended)
            feature space:

            extended_feature_names: index -> (feature name, value index),
                the value index being None for non-categorical features;
            extended_feature_names_as_array_strings: 'f<i>' for plain
                features and 'f<i>_<j>' for the j-th value of categorical
                feature i.
        """
        names = self.extended_feature_names = {}
        labels = self.extended_feature_names_as_array_strings = []

        for i in range(self.nb_features):
            if self.use_categorical and i in self.categorical_features:
                nof_values = len(self.binarizer[i].categories_[0])
                for j in range(nof_values):
                    # keys are consecutive, so len(names) is the next index
                    names[len(names)] = (self.feature_names[i], j)
                    labels.append("f{}_{}".format(i, j))
            else:
                names[len(names)] = (self.feature_names[i], None)
                labels.append("f{}".format(i))

    def readable_sample(self, x):
        """
            Return x as an array in which the values of categorical
            features are replaced by their entries in categorical_names.
        """
        return np.asarray([
            self.categorical_names[i][int(v)] if i in self.categorical_features else v
            for i, v in enumerate(x)
        ])

    
    def test_encoding_transformes(self, X_train):
        """
            Round-trip check of the encoding: transform the first row of
            X_train, invert the transformation and assert that the
            original row is recovered. Intermediate results are printed.
        """
        sample = X_train[[0], :]
        print("Sample of length", len(sample[0]), " : ", sample)

        encoded = self.transform(sample)
        print("Encoded sample of length", len(encoded[0]), " : ", encoded)

        decoded = self.transform_inverse(encoded)
        print("Back to sample", decoded)
        print("Readable sample", self.readable_sample(decoded[0]))

        same = (decoded == sample)
        assert same.all()

        # Further manual checks of the index <-> (feature, value) maps:
        #   for i in range(len(self.extended_feature_names)):
        #       print(i, self.transform_inverse_by_index(i))
        #   for key, value in self.extended_feature_names.items():
        #       print(value, self.transform_by_value(value))
#
#==============================================================================
class XRF(object):
    """
        class to encode and explain Random Forest classifiers.
    """
    
    def __init__(self, model, feature_names, class_names, verb=0):
        """
            Wrap a trained model: keep the model and the names, and build
            a Forest from the model with internal feature names
            'f0', 'f1', ...

            With verb > 0 some statistics of the forest are printed; with
            verb > 2 the trees are printed as well.
        """
        self.cls = model
        self.verbose = verb
        self.feature_names = feature_names
        self.class_names = class_names
        self.fnames = ['f{0}'.format(i) for i in range(len(feature_names))]
        self.f = Forest(model, self.fnames)

        if self.verbose:
            if self.verbose > 2:
                self.f.print_trees()
            print("c RF sz:", self.f.sz)
            print('c max-depth:', self.f.md)
            print('c nof DTs:', len(self.f.trees))
        
    def __del__(self):
        if 'enc' in dir(self):
            del self.enc
        if 'x' in dir(self):
            if self.x.slv is not None:
                self.x.slv.delete()
            del self.x
        del self.f
        self.f = None
        del self.cls
        self.cls = None
        
    def encode(self, inst):
        """
            Build the SAT encoding of the forest for the instance 'inst'
            and keep the encoder in self.enc. Statistics of the formula
            and the CPU time spent are printed in verbose mode.
        """
        def cpu_time():
            # user CPU time of this process and of its children
            return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime

        if 'f' not in dir(self):
            self.f = Forest(self.cls, self.fnames)

        started = cpu_time()

        self.enc = SATEncoder(self.f, self.feature_names, len(self.class_names), self.fnames)
        formula, _, _, _ = self.enc.encode(np.array(inst))

        elapsed = cpu_time() - started

        if self.verbose:
            print('c nof vars:', formula.nv) # number of variables 
            print('c nof clauses:', len(formula.clauses)) # number of clauses    
            print('c encoding time: {0:.3f}'.format(elapsed))
        
    def explain(self, inst, xtype='abd'):
        """
            Compute one explanation for the prediction made on 'inst'.

            xtype 'abd' requests an abductive explanation, any other
            value a contrastive one. The value returned by
            SATExplainer.explain() is returned.
        """
        def cpu_time():
            # user CPU time of this process and of its children
            return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime

        started = cpu_time()

        # NOTE(review): an existing encoding is reused as is, although
        # SATEncoder.encode() depends on the instance - confirm that one
        # XRF object is meant to explain a single instance only
        if 'enc' not in dir(self):
            self.encode(inst)

        # textual literals 'feature = value' used when printing rules
        preamble = [
            v if f in str(v) else '{0} = {1}'.format(f, v)
            for f, v in zip(self.feature_names, np.asarray(inst))
        ]

        self.x = SATExplainer(self.enc, self.fnames, preamble, self.class_names, verb=self.verbose)
        expl = self.x.explain(np.array(inst), xtype)

        elapsed = cpu_time() - started

        if self.verbose:
            print("c Total time: {0:.3f}".format(elapsed))

        return expl
    
    def enumerate(self, inst, xtype='con', smallest=True):
        """
            Generator listing explanations (XPs) of the prediction made
            for 'inst'; the items are those produced by
            SATExplainer.enumerate().

            The encoder and the explainer are created on first use and
            reused afterwards.
        """
        if 'enc' not in dir(self):
            self.encode(inst)

        if 'x' not in dir(self):
            preamble = []
            for f, v in zip(self.feature_names, np.asarray(inst)):
                if f not in str(v):
                    preamble.append('{0} = {1}'.format(f, v))
                else:
                    preamble.append(v)

            # pass the verbosity level on, as explain() does; otherwise
            # the explainer falls back to its own default (verb=1)
            self.x = SATExplainer(self.enc, self.fnames, preamble,
                                  self.class_names, verb=self.verbose)

        yield from self.x.enumerate(np.array(inst), xtype, smallest)
        
#
#==============================================================================
class SATEncoder(object):
    """
        Encoder of Random Forest classifier into SAT.
    """
    
    def __init__(self, forest, feats, nof_classes, extended_feature_names,  from_file=None):
        """
            Store the forest, the number of classes and the extended
            feature names, and create a fresh variable pool.

            'feats' and 'from_file' are accepted but not used here.
        """
        # encoding formula and interval-based encoding data; filled in
        # by encode() / compute_intervals()
        self.cnf = None
        self.intvs = self.imaps = self.ivars = self.thvars = None

        self.vpool = IDPool()
        self.forest = forest
        self.num_class = nof_classes
        self.extended_feature_names = extended_feature_names
       
        
    def newVar(self, name):
        """
            If a variable named 'name' already exists then
            return its id; otherwise create a new var
        """
        if name in self.vpool.obj2id: #var has been already created 
            return self.vpool.obj2id[name]
        var = self.vpool.id('{0}'.format(name))
        return var
    
    def nameVar(self, vid):
        """
            input a var id and return a var name
        """
        return self.vpool.obj(abs(vid))
    
    def printLits(self, lits):
        print(["{0}{1}".format("-" if p<0 else "",self.vpool.obj(abs(p))) for p in lits])
    
    def traverse(self, tree, k, clause):
        """
            Traverse a tree and encode each node.
        """

        if tree.children:
            f = tree.name
            v = tree.threshold
            pos = neg = []
            if f in self.intvs:
                d = self.imaps[f][v]
                pos, neg = self.thvars[f][d], -self.thvars[f][d]
            else:
                var = self.newVar(tree.name)
                pos, neg = var, -var
                #print("{0} => {1}".format(tree.name, var))
                
            assert (pos and neg)
            self.traverse(tree.children[0], k, clause + [-neg])
            self.traverse(tree.children[1], k, clause + [-pos])            
        else:  # leaf node
            cvar = self.newVar('class{0}_tr{1}'.format(tree.values,k))
            self.cnf.append(clause + [cvar])
            #self.printLits(clause + [cvar])

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            At this point, the method only works for numerical datasets!
        """

        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                if f in self.intvs:
                    self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {'{0}'.format(f): set([]) for f in self.extended_feature_names if '_' not in f}

        for tree in self.forest.trees:
            traverse_intervals(tree)
                
        # OK, we got all intervals; let's sort the values
        self.intvs = {f: sorted(self.intvs[f]) + ([math.inf] if len(self.intvs[f]) else []) for f in six.iterkeys(self.intvs)}

        self.imaps, self.ivars = {}, {}
        self.thvars = {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            self.thvars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = self.newVar('{0}_intv{1}'.format(feat, i))
                self.ivars[feat].append(ivar)
                #print('{0}_intv{1}'.format(feat, i))
                
                if ub != math.inf:
                    #assert(i < len(intvs)-1)
                    thvar = self.newVar('{0}_th{1}'.format(feat, i))
                    self.thvars[feat].append(thvar)
                    #print('{0}_th{1}'.format(feat, i))



    def encode(self, sample):
        """
            Build the CNF encoding of the forest for the given sample.

            The formula (stored in self.cnf) consists of the clauses of
            every tree, a cardinality constraint over the class variables
            of the trees that depends on the class predicted for 'sample'
            (stored in self.cmaj), and the constraints linking
            categorical, interval and threshold variables.

            :param sample: instance passed to self.forest.predict_inst()

            :return: the tuple (self.cnf, self.intvs, self.imaps, self.ivars)
        """
        
        ###print('Encode RF into SAT ...')

        self.cnf = CNF()
        # getting a tree ensemble
        #self.forest = Forest(self.model, self.extended_feature_names)
        num_tree = len(self.forest.trees)
        # NOTE(review): the result of this call is not used; the prediction
        # is computed again below and stored in self.cmaj
        self.forest.predict_inst(sample)

        #introducing class variables
        #cvars = [self.newVar('class{0}'.format(i)) for i in range(self.num_class)]
        
        # define Tautology var (asserted true by a unit clause)
        vtaut = self.newVar('Tautology')
        self.cnf.append([vtaut])
            
        # introducing class-tree variables:
        # ctvars[k][j] is the variable 'class<j>_tr<k>' of class j in tree k
        ctvars = [[] for t in range(num_tree)]
        for k in range(num_tree):
            for j in range(self.num_class):
                var = self.newVar('class{0}_tr{1}'.format(j,k))
                ctvars[k].append(var)       

        # traverse all trees and extract all possible intervals
        # for each feature
        ###print("compute intervarls ...")
        self.compute_intervals()
        
        #print(self.intvs)
        #print([len(self.intvs[f]) for f in self.intvs])
        #print(self.imaps) 
        #print(self.ivars)
        #print(self.thvars)
        #print(ctvars)
        
        
        ##print("encode trees ...")
        # traversing and encoding each tree
        for k, tree in enumerate(self.forest.trees):
            #print("Encode tree#{0}".format(k))
            # encoding the tree     
            self.traverse(tree, k, [])
            # at most one class var of the tree is true (the clauses added
            # by traverse() force the class var of the reached leaf)
            #self.printLits(ctvars[k])
            card = CardEnc.atmost(lits=ctvars[k], vpool=self.vpool,encoding=EncType.cardnetwrk) 
            self.cnf.extend(card.clauses)
        
        
            
        # calculate the majority class   
        self.cmaj = self.forest.predict_inst(sample)       
        
        ##print("encode majority class ...")                
        # AtLeast cardinality constraints over the class vars of the trees
        
        if(self.num_class == 2):
            # binary case: at least 'rhs' trees vote for the class other
            # than cmaj; rhs is a strict majority, except that half of the
            # trees suffices when cmaj is 1 and the number of trees is even
            # (presumably because ties are resolved towards class 0 -
            # TODO confirm against Forest.predict_inst)
            rhs = math.floor(num_tree / 2) + 1
            if(self.cmaj==1 and not num_tree%2):
                rhs = math.floor(num_tree / 2)      
            lhs = [ctvars[k][1 - self.cmaj] for k in range(num_tree)]
            atls = CardEnc.atleast(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(atls)
        else: 
            # multi-class case: two vectors of auxiliary vars, one per tree
            zvars = []
            zvars.append([self.newVar('z_0_{0}'.format(k)) for k in range (num_tree) ])
            zvars.append([self.newVar('z_1_{0}'.format(k)) for k in range (num_tree) ])
            ##
            # (#true z_0) + (#trees not voting cmaj) >= num_tree
            rhs = num_tree
            lhs0 = zvars[0] + [ - ctvars[k][self.cmaj] for k in range(num_tree)]
            ##self.printLits(lhs0)
            atls = CardEnc.atleast(lits = lhs0, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(atls)
            ##
            #rhs = num_tree - 1
            # (#true z_1) + (#trees not voting cmaj) >= num_tree + 1
            rhs = num_tree + 1
            ###########
            lhs1 =  zvars[1] + [ - ctvars[k][self.cmaj] for k in range(num_tree)]
            ##self.printLits(lhs1)
            atls = CardEnc.atleast(lits = lhs1, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(atls)            
            #
            # selector vars p_0 .. p_<num_class>: for k other than 0 and
            # cmaj+1, p_k makes the z vars equivalent to the votes for
            # class k-1 (z_0 for classes below cmaj, z_1 for classes above)
            pvars = [self.newVar('p_{0}'.format(k)) for k in range(self.num_class + 1)]
            ##self.printLits(pvars)
            for k,p in enumerate(pvars):
                for i in range(num_tree):
                    if k == 0:
                        # p_0 forces every z_0 var to true
                        z = zvars[0][i]
                        #self.cnf.append([-p, -z, vtaut])
                        self.cnf.append([-p, z, -vtaut])       
                        #self.printLits([-p, z, -vtaut])
                        #print()
                    elif k == self.cmaj+1:
                        # p_<cmaj+1> forces every z_1 var to true
                        z = zvars[1][i]
                        self.cnf.append([-p, z, -vtaut])       
                        
                        #self.printLits([-p, z, -vtaut])
                        #print()                       
                        
                    else:
                        # p_k -> (z <-> tree i votes for class k-1)
                        z = zvars[0][i] if (k<self.cmaj+1) else zvars[1][i]
                        self.cnf.append([-p, -z, ctvars[i][k-1] ])
                        self.cnf.append([-p, z, -ctvars[i][k-1] ])  
                        
                        #self.printLits([-p, -z, ctvars[i][k-1] ])
                        #self.printLits([-p, z, -ctvars[i][k-1] ])
                        #print()
                        
            #
            # p_0 and p_<cmaj+1> cannot both be selected
            self.cnf.append([-pvars[0], -pvars[self.cmaj+1]])
            ##
            # exactly one selector among p_0 .. p_<cmaj>
            lhs1 =  pvars[:(self.cmaj+1)]
            ##self.printLits(lhs1)
            eqls = CardEnc.equals(lits = lhs1, bound = 1, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(eqls)
            
            
            # exactly one selector among p_<cmaj+1> .. p_<num_class>
            lhs2 = pvars[(self.cmaj + 1):]
            ##self.printLits(lhs2)
            eqls = CardEnc.equals(lits = lhs2, bound = 1, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(eqls)
                
        
            
        ##print("exactly-one feat const ...")
        # enforce exactly one of the feature values to be chosen
        # (for categorical features, i.e. names of the form '<feat>_<value>')
        categories = collections.defaultdict(lambda: [])
        for f in self.extended_feature_names:
            if '_' in f:
                categories[f.split('_')[0]].append(self.newVar(f))        
        for c, feats in six.iteritems(categories):
            # exactly-one feat is True: at-least-one clause + at-most-one
            self.cnf.append(feats)
            card = CardEnc.atmost(lits=feats, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(card.clauses)
        # lits of intervals: exactly one interval var per feature is true
        for f, intvs in six.iteritems(self.ivars):
            if not len(intvs):
                continue
            self.cnf.append(intvs) 
            card = CardEnc.atmost(lits=intvs, vpool=self.vpool, encoding=EncType.cardnetwrk)
            self.cnf.extend(card.clauses)
            #self.printLits(intvs)
        
            
        
        # link threshold vars and interval vars: the clauses below make
        # threshold var j true iff the selected interval has an index > j
        for f, threshold in six.iteritems(self.thvars):
            for j, thvar in enumerate(threshold):
                d = j+1
                # pos: interval vars above threshold j; neg: those up to it
                pos, neg = self.ivars[f][d:], self.ivars[f][:d] 
                
                if j == 0:
                    # thvar_0 <-> not ivar_0
                    assert(len(neg) == 1)
                    self.cnf.append([thvar, neg[-1]])
                    self.cnf.append([-thvar, -neg[-1]])
                else:
                    # thvar_j <-> (thvar_{j-1} and not ivar_j)
                    self.cnf.append([thvar, neg[-1], -threshold[j-1]])
                    self.cnf.append([-thvar, threshold[j-1]])
                    self.cnf.append([-thvar, -neg[-1]])
                
                if j == len(threshold) - 1:
                    # last threshold: thvar_j <-> last interval var
                    assert(len(pos) == 1)
                    self.cnf.append([-thvar, pos[0]])
                    self.cnf.append([thvar, -pos[0]])
                else:
                    # thvar_j <-> (ivar_{j+1} or thvar_{j+1})
                    self.cnf.append([-thvar, pos[0], threshold[j+1]])
                    self.cnf.append([thvar, -pos[0]])
                    self.cnf.append([thvar, -threshold[j+1]])
          

        
        return self.cnf, self.intvs, self.imaps, self.ivars


#
#==============================================================================
class SATExplainer(object):
    """
        An SAT-inspired minimal explanation extractor for Random Forest models.
    """

    def __init__(self, sat_enc, inps, preamble, target_name, verb=1):
        """
            Constructor.
        """
        self.enc = sat_enc
        self.inps = inps  # input (feature value) variables
        self.target_name = target_name
        self.preamble = preamble
        self.verbose = verb
        self.slv = None    
      
    def prepare_selectors(self, sample):
        # adapt the solver to deal with the current sample
        #self.csel = []
        self.assums = []  # var selectors to be used as assumptions
        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids
        self.sel2v = {} # selectors to (categorical/interval) values
        
        #for i in range(self.enc.num_class):
        #    self.csel.append(self.enc.newVar('class{0}'.format(i)))
        #self.csel = self.enc.newVar('class{0}'.format(self.enc.cmaj))
               
        # preparing the selectors
        for i, (inp, val) in enumerate(zip(self.inps, sample), 1):
            if '_' in inp:
                # binarized (OHE) features
                assert (inp not in self.enc.intvs)
                
                feat = inp.split('_')[0]
                selv = self.enc.newVar('selv_{0}'.format(feat))
            
                self.assums.append(selv)   
                if selv not in self.sel2fid:
                    self.sel2fid[selv] = int(feat[1:])
                    self.sel2vid[selv] = [i - 1]
                else:
                    self.sel2vid[selv].append(i - 1)
                    
                p = self.enc.newVar(inp) 
                if not val:
                    p = -p
                else:
                    self.sel2v[selv] = p
                    
                self.enc.cnf.append([-selv, p])
                #self.enc.printLits([-selv, p])
                    
            elif len(self.enc.intvs[inp]):
                #v = None
                #for intv in self.enc.intvs[inp]:
                #    if intv > val:
                #        v = intv
                #        break         
                v = next((intv for intv in self.enc.intvs[inp] if intv > val), None)     
                assert(v is not None)
                
                selv = self.enc.newVar('selv_{0}'.format(inp))     
                self.assums.append(selv)  
                
                assert (selv not in self.sel2fid)
                self.sel2fid[selv] = int(inp[1:])
                self.sel2vid[selv] = [i - 1]
                            
                for j,p in enumerate(self.enc.ivars[inp]):
                    cl = [-selv]
                    if j == self.enc.imaps[inp][v]:
                        cl += [p]
                        self.sel2v[selv] = p
                    else:
                        cl += [-p]
                    
                    self.enc.cnf.append(cl)
                    #self.enc.printLits(cl)

        
    
    def explain(self, sample, xtype='abd', smallest=False):
        """
            Compute one explanation for the sample: an abductive one if
            xtype is 'abd', a contrastive one otherwise. The CPU time
            spent is kept in self.time and the solver is deleted before
            returning.

            NOTE(review): 'smallest' is accepted but not forwarded to
            compute_axp() / compute_cxp(), which use their own defaults.
        """
        def cpu_time():
            # user CPU time of this process and of its children
            return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime

        if self.verbose:
            rule = ' AND '.join(self.preamble)
            print('  explaining:  "IF {0} THEN {1}"'.format(rule, self.target_name[self.enc.cmaj]))

        self.time = cpu_time()

        self.prepare_selectors(sample)

        # abductive (PI-) explanation vs. contrastive explanation
        expl = self.compute_axp() if xtype == 'abd' else self.compute_cxp()

        self.time = cpu_time() - self.time

        # delete sat solver
        self.slv.delete()
        self.slv = None

        if self.verbose:
            print('  time: {0:.3f}'.format(self.time))

        return expl

    def compute_axp(self, smallest=False):
        """
            Compute an Abductive eXplanation: a subset-minimal set of
            selectors obtained by deletion-based linear search. Returns
            the sorted original ids of the features left.

            Raises NotImplementedError if 'smallest' is requested.
        """
        self.assums = sorted(set(self.assums))
        if self.verbose:
            print('  # hypos:', len(self.assums))

        # create a SAT solver and pass it the CNF formula
        self.slv = Solver(name="glucose3")
        self.slv.append_formula(self.enc.cnf)

        if smallest:
            raise NotImplementedError('Smallest explanation is not yet implemented.')

        vtaut = self.enc.newVar('Tautology')
        # try to drop the selectors one at a time; a selector whose removal
        # keeps the formula unsatisfiable is dropped (stored negated)
        for i, p in enumerate(self.assums):
            others = self.assums[:i] + self.assums[(i + 1):]
            to_test = [vtaut] + others + [-p, -self.sel2v[p]]
            if not self.slv.solve(assumptions=to_test):
                self.assums[i] = -p

        expl = sorted(self.sel2fid[h] for h in self.assums if h > 0)
        assert len(expl), 'Abductive explanation cannot be an empty-set! otherwise RF fcn is const, i.e. predicts only one class'

        if self.verbose:
            print("expl-selctors: ", expl)
            kept = [self.preamble[i] for i in expl]
            print('  explanation: "IF {0} THEN {1}"'.format(' AND '.join(kept), self.target_name[self.enc.cmaj]))
            print('  # hypos left:', len(expl))

        return expl
        
    def compute_cxp(self, smallest=True):
        """
            Compute a Contrastive eXplanation. The formula is hard, the
            selectors are soft clauses of weight 1. With 'smallest' a
            MaxSAT solver (RC2) is used, otherwise an MCS extractor (LBX).
            Returns the sorted original ids of the features involved.
        """
        self.assums = sorted(set(self.assums))
        if self.verbose:
            print('  # hypos:', len(self.assums))

        wcnf = WCNF()
        for cl in self.enc.cnf:
            wcnf.append(cl)
        for p in self.assums:
            wcnf.append([p], weight=1)

        if smallest:
            # mxsat solver
            self.slv = RC2(wcnf)
            model = self.slv.compute()
            # selectors falsified by the model
            dropped = [-p for p in model if p < 0 and -p in self.assums]
            expl = sorted(self.sel2fid[s] for s in dropped)
        else:
            # mcs solver; the MCS lists 1-based indices of soft clauses
            self.slv = LBX(wcnf, use_cld=True, solver_name='g3')
            mcs = self.slv.compute()
            expl = sorted(self.sel2fid[self.assums[i - 1]] for i in mcs)

        assert len(expl), 'Contrastive explanation cannot be an empty-set!'
        if self.verbose:
            print("expl-selctors: ", expl)
            pred = self.target_name[self.enc.cmaj]
            negated = " AND ".join(["!({0})".format(self.preamble[i]) for i in expl])
            print('  explanation: "IF {0} THEN !(class = {1})"'.format(negated, pred))

        return expl
    
    def enumerate(self, sample, xtype='con', smallest=True):
        """
            Generator listing contrastive explanations (CXp's).

            Each item is a list of (feature, interval variable name)
            pairs: for every feature of the explanation, the name of the
            interval variable that is true in the witnessing model.

            With 'smallest' the explanations are enumerated with RC2,
            otherwise with LBX. An existing solver (self.slv) and existing
            selectors are reused. Once the enumeration is exhausted, the
            time is reported (verbose mode) and the solver is deleted.

            Raises NotImplementedError for xtype 'abd' and AssertionError
            when a feature has no unique interval in a model or the
            oracle finds no model for an MCS.
        """
        if xtype == 'abd':
            raise NotImplementedError('Enumerate abductive explanations is not yet implemented.')

        def cpu_time():
            # user CPU time of this process and of its children
            return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime

        def adversarial(cxp_feats, model):
            # pair every feature with the name of its interval variable
            # that is true in the model
            advx = []
            for f in cxp_feats:
                ps = [p for p in model if (p > 0 and (p in self.enc.ivars[f]))]
                assert (len(ps) == 1)
                advx.append((f, self.enc.nameVar(ps[0])))
            return advx

        started = cpu_time()

        if 'assums' not in dir(self):
            self.prepare_selectors(sample)
            self.assums = sorted(set(self.assums))

        # compute CXp's/AE's
        if self.slv is None:
            wcnf = WCNF()
            for cl in self.enc.cnf:
                wcnf.append(cl)
            for p in self.assums:
                wcnf.append([p], weight=1)
            if smallest:
                # incremental maxsat solver
                self.slv = RC2(wcnf, adapt=True, exhaust=True, minz=True)
            else:
                # mcs solver
                self.slv = LBX(wcnf, use_cld=True, solver_name='g3')

        if smallest:
            print('smallest')
            for model in self.slv.enumerate(block=-1):
                expl = sorted([self.sel2fid[-p] for p in model if (p < 0 and (-p in self.assums))])
                cxp_feats = ['f{0}'.format(j) for j in expl]
                advx = adversarial(cxp_feats, model)
                print(cxp_feats, advx)
                yield advx
        else:
            print('LBX')
            for mcs in self.slv.enumerate():
                expl = sorted([self.sel2fid[self.assums[i - 1]] for i in mcs])
                assumptions = [-p if (i in mcs) else p for i, p in enumerate(self.assums, 1)]
                # the solver call must not live inside an assert statement:
                # asserts are removed under "python -O", which would leave
                # get_model() returning a stale model
                if not self.slv.oracle.solve(assumptions):
                    raise AssertionError('no model found for the computed MCS')
                model = self.slv.oracle.get_model()
                cxp_feats = ['f{0}'.format(j) for j in expl]
                yield adversarial(cxp_feats, model)
                self.slv.block(mcs)

        elapsed = cpu_time() - started
        if self.verbose:
            print('c expl time: {0:.3f}'.format(elapsed))
        #
        self.slv.delete()
        self.slv = None