#from sklearn.ensemble._voting import VotingClassifier
#from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
#from sklearn.metrics import accuracy_score
import numpy as np
import sys
import os
import resource
import collections
from itertools import combinations
from six.moves import range
import six
import math
from data import Data
from .rndmforest import RF2001, VotingRF
from .tree import Forest, predict_tree
#from .encode import SATEncoder
from pysat.formula import CNF, WCNF, IDPool
from pysat.solvers import Solver
from pysat.card import CardEnc, EncType
from pysat.examples.lbx import LBX
from pysat.examples.mcsls import MCSls
from pysat.examples.rc2 import RC2
#
#==============================================================================
class Dataset(Data):
"""
Class for representing a dataset (transactions).
"""
def __init__(self, filename=None, fpointer=None, mapfile=None,
separator=' ', use_categorical = False):
super().__init__(filename, fpointer, mapfile, separator, use_categorical)
# split data into X and y
self.feature_names = self.names[:-1]
self.nb_features = len(self.feature_names)
self.use_categorical = use_categorical
samples = np.asarray(self.samps)
if not all(c.isnumeric() for c in samples[:, -1]):
le = LabelEncoder()
le.fit(samples[:, -1])
samples[:, -1]= le.transform(samples[:, -1])
self.class_names = le.classes_
print(le.classes_)
print(samples[1:4, :])
samples = np.asarray(samples, dtype=np.float32)
self.X = samples[:, 0: self.nb_features]
self.y = samples[:, self.nb_features]
self.num_class = len(set(self.y))
self.target_name = list(range(self.num_class))
print("c nof features: {0}".format(self.nb_features))
print("c nof classes: {0}".format(self.num_class))
print("c nof samples: {0}".format(len(self.samps)))
# check if we have info about categorical features
if (self.use_categorical):
self.target_name = self.class_names
self.binarizer = {}
for i in self.categorical_features:
self.binarizer.update({i: OneHotEncoder(categories='auto', sparse=False)})#,
self.binarizer[i].fit(self.X[:,[i]])
else:
self.categorical_features = []
self.categorical_names = []
self.binarizer = []
#feat map
self.mapping_features()
def train_test_split(self, test_size=0.2, seed=0):
return train_test_split(self.X, self.y, test_size=test_size, random_state=seed)
def transform(self, x):
if(len(x) == 0):
return x
if (len(x.shape) == 1):
x = np.expand_dims(x, axis=0)
if (self.use_categorical):
assert(self.binarizer != [])
tx = []
for i in range(self.nb_features):
#self.binarizer[i].drop = None
if (i in self.categorical_features):
self.binarizer[i].drop = None
tx_aux = self.binarizer[i].transform(x[:,[i]])
tx_aux = np.vstack(tx_aux)
tx.append(tx_aux)
else:
tx.append(x[:,[i]])
tx = np.hstack(tx)
return tx
else:
return x
def transform_inverse(self, x):
if(len(x) == 0):
return x
if (len(x.shape) == 1):
x = np.expand_dims(x, axis=0)
if (self.use_categorical):
assert(self.binarizer != [])
inverse_x = []
for i, xi in enumerate(x):
inverse_xi = np.zeros(self.nb_features)
for f in range(self.nb_features):
if f in self.categorical_features:
nb_values = len(self.categorical_names[f])
v = xi[:nb_values]
v = np.expand_dims(v, axis=0)
iv = self.binarizer[f].inverse_transform(v)
inverse_xi[f] =iv
xi = xi[nb_values:]
else:
inverse_xi[f] = xi[0]
xi = xi[1:]
inverse_x.append(inverse_xi)
return inverse_x
else:
return x
def transform_inverse_by_index(self, idx):
if (idx in self.extended_feature_names):
return self.extended_feature_names[idx]
else:
print("Warning there is no feature {} in the internal mapping".format(idx))
return None
def transform_by_value(self, feat_value_pair):
if (feat_value_pair in self.extended_feature_names.values()):
keys = (list(self.extended_feature_names.keys())[list( self.extended_feature_names.values()).index(feat_value_pair)])
return keys
else:
print("Warning there is no value {} in the internal mapping".format(feat_value_pair))
return None
def mapping_features(self):
self.extended_feature_names = {}
self.extended_feature_names_as_array_strings = []
counter = 0
if (self.use_categorical):
for i in range(self.nb_features):
if (i in self.categorical_features):
for j, _ in enumerate(self.binarizer[i].categories_[0]):
self.extended_feature_names.update({counter: (self.feature_names[i], j)})
self.extended_feature_names_as_array_strings.append("f{}_{}".format(i,j)) # str(self.feature_names[i]), j))
counter = counter + 1
else:
self.extended_feature_names.update({counter: (self.feature_names[i], None)})
self.extended_feature_names_as_array_strings.append("f{}".format(i)) #(self.feature_names[i])
counter = counter + 1
else:
for i in range(self.nb_features):
self.extended_feature_names.update({counter: (self.feature_names[i], None)})
self.extended_feature_names_as_array_strings.append("f{}".format(i))#(self.feature_names[i])
counter = counter + 1
def readable_sample(self, x):
readable_x = []
for i, v in enumerate(x):
if (i in self.categorical_features):
readable_x.append(self.categorical_names[i][int(v)])
else:
readable_x.append(v)
return np.asarray(readable_x)
def test_encoding_transformes(self, X_train):
# test encoding
X = X_train[[0],:]
print("Sample of length", len(X[0])," : ", X)
enc_X = self.transform(X)
print("Encoded sample of length", len(enc_X[0])," : ", enc_X)
inv_X = self.transform_inverse(enc_X)
print("Back to sample", inv_X)
print("Readable sample", self.readable_sample(inv_X[0]))
assert((inv_X == X).all())
'''
for i in range(len(self.extended_feature_names)):
print(i, self.transform_inverse_by_index(i))
for key, value in self.extended_feature_names.items():
print(value, self.transform_by_value(value))
'''
#
#==============================================================================
class XRF(object):
"""
Class to encode and explain Random Forest classifiers.
"""
def __init__(self, model, feature_names, class_names, verb=0):
self.cls = model
#self.data = dataset
self.verbose = verb
self.feature_names = feature_names
self.class_names = class_names
self.fnames = [f'f{i}' for i in range(len(feature_names))]
self.f = Forest(model, self.fnames)
if self.verbose > 2:
self.f.print_trees()
if self.verbose:
print("c RF sz:", self.f.sz)
print('c max-depth:', self.f.md)
print('c nof DTs:', len(self.f.trees))
def __del__(self):
if 'enc' in dir(self):
del self.enc
if 'x' in dir(self):
if self.x.slv is not None:
self.x.slv.delete()
del self.x
del self.f
self.f = None
del self.cls
self.cls = None
def encode(self, inst):
"""
Encode a tree ensemble trained previously.
"""
if 'f' not in dir(self):
self.f = Forest(self.cls, self.fnames)
#self.f.print_tree()
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime
self.enc = SATEncoder(self.f, self.feature_names, len(self.class_names), self.fnames)
#inst = self.data.transform(np.array(inst))[0]
formula, _, _, _ = self.enc.encode(np.array(inst))
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime - time
if self.verbose:
print('c nof vars:', formula.nv) # number of variables
print('c nof clauses:', len(formula.clauses)) # number of clauses
print('c encoding time: {0:.3f}'.format(time))
def explain(self, inst, xtype='abd'):
"""
Explain a prediction made for a given sample with a previously
trained RF.
"""
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime
if 'enc' not in dir(self):
self.encode(inst)
#inpvals = self.data.readable_sample(inst)
inpvals = np.asarray(inst)
preamble = []
for f, v in zip(self.feature_names, inpvals):
if f not in str(v):
preamble.append('{0} = {1}'.format(f, v))
else:
preamble.append(v)
inps = self.fnames # input (feature value) variables
#print("inps: {0}".format(inps))
self.x = SATExplainer(self.enc, inps, preamble, self.class_names, verb=self.verbose)
#inst = self.data.transform(np.array(inst))[0]
expl = self.x.explain(np.array(inst), xtype)
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime - time
if self.verbose:
print("c Total time: {0:.3f}".format(time))
return expl
def enumerate(self, inst, xtype='con', smallest=True):
"""
List all explanations (XPs) for a given instance.
"""
if 'enc' not in dir(self):
self.encode(inst)
if 'x' not in dir(self):
inpvals = np.asarray(inst)
preamble = []
for f, v in zip(self.feature_names, inpvals):
if f not in str(v):
preamble.append('{0} = {1}'.format(f, v))
else:
preamble.append(v)
inps = self.fnames
self.x = SATExplainer(self.enc, inps, preamble, self.class_names)
for expl in self.x.enumerate(np.array(inst), xtype, smallest):
yield expl
#
#==============================================================================
class SATEncoder(object):
"""
Encoder of a Random Forest classifier into SAT.
"""
def __init__(self, forest, feats, nof_classes, extended_feature_names, from_file=None):
self.forest = forest
#self.feats = {f: i for i, f in enumerate(feats)}
self.num_class = nof_classes
self.vpool = IDPool()
self.extended_feature_names = extended_feature_names
#encoding formula
self.cnf = None
# for interval-based encoding
self.intvs, self.imaps, self.ivars, self.thvars = None, None, None, None
def newVar(self, name):
"""
If a variable named 'name' already exists, return its id;
otherwise, create a new variable.
"""
if name in self.vpool.obj2id: #var has been already created
return self.vpool.obj2id[name]
var = self.vpool.id('{0}'.format(name))
return var
def nameVar(self, vid):
"""
Given a variable id, return the corresponding variable name.
"""
return self.vpool.obj(abs(vid))
def printLits(self, lits):
print(["{0}{1}".format("-" if p<0 else "",self.vpool.obj(abs(p))) for p in lits])
def traverse(self, tree, k, clause):
"""
Traverse a tree and encode each node.
"""
if tree.children:
f = tree.name
v = tree.threshold
pos = neg = []
if f in self.intvs:
d = self.imaps[f][v]
pos, neg = self.thvars[f][d], -self.thvars[f][d]
else:
var = self.newVar(tree.name)
pos, neg = var, -var
#print("{0} => {1}".format(tree.name, var))
assert (pos and neg)
self.traverse(tree.children[0], k, clause + [-neg])
self.traverse(tree.children[1], k, clause + [-pos])
else: # leaf node
cvar = self.newVar('class{0}_tr{1}'.format(tree.values,k))
self.cnf.append(clause + [cvar])
#self.printLits(clause + [cvar])
def compute_intervals(self):
"""
Traverse all trees in the ensemble and extract intervals for each
feature.
At this point, the method only works for numerical datasets!
"""
def traverse_intervals(tree):
"""
Auxiliary function. Recursive tree traversal.
"""
if tree.children:
f = tree.name
v = tree.threshold
if f in self.intvs:
self.intvs[f].add(v)
traverse_intervals(tree.children[0])
traverse_intervals(tree.children[1])
# initializing the intervals
self.intvs = {'{0}'.format(f): set([]) for f in self.extended_feature_names if '_' not in f}
for tree in self.forest.trees:
traverse_intervals(tree)
# OK, we got all intervals; let's sort the values
self.intvs = {f: sorted(self.intvs[f]) + ([math.inf] if len(self.intvs[f]) else []) for f in six.iterkeys(self.intvs)}
self.imaps, self.ivars = {}, {}
self.thvars = {}
for feat, intvs in six.iteritems(self.intvs):
self.imaps[feat] = {}
self.ivars[feat] = []
self.thvars[feat] = []
for i, ub in enumerate(intvs):
self.imaps[feat][ub] = i
ivar = self.newVar('{0}_intv{1}'.format(feat, i))
self.ivars[feat].append(ivar)
#print('{0}_intv{1}'.format(feat, i))
if ub != math.inf:
#assert(i < len(intvs)-1)
thvar = self.newVar('{0}_th{1}'.format(feat, i))
self.thvars[feat].append(thvar)
#print('{0}_th{1}'.format(feat, i))
def encode(self, sample):
"""
Encode the forest and the prediction made for the given sample into CNF.
"""
###print('Encode RF into SAT ...')
self.cnf = CNF()
# getting a tree ensemble
#self.forest = Forest(self.model, self.extended_feature_names)
num_tree = len(self.forest.trees)
self.forest.predict_inst(sample)
#introducing class variables
#cvars = [self.newVar('class{0}'.format(i)) for i in range(self.num_class)]
# define Tautology var
vtaut = self.newVar('Tautology')
self.cnf.append([vtaut])
# introducing class-tree variables
ctvars = [[] for t in range(num_tree)]
for k in range(num_tree):
for j in range(self.num_class):
var = self.newVar('class{0}_tr{1}'.format(j,k))
ctvars[k].append(var)
# traverse all trees and extract all possible intervals
# for each feature
###print("compute intervarls ...")
self.compute_intervals()
#print(self.intvs)
#print([len(self.intvs[f]) for f in self.intvs])
#print(self.imaps)
#print(self.ivars)
#print(self.thvars)
#print(ctvars)
##print("encode trees ...")
# traversing and encoding each tree
for k, tree in enumerate(self.forest.trees):
#print("Encode tree#{0}".format(k))
# encoding the tree
self.traverse(tree, k, [])
# at most one class-tree var per tree is true (the leaf clauses force at least one)
#self.printLits(ctvars[k])
card = CardEnc.atmost(lits=ctvars[k], vpool=self.vpool,encoding=EncType.cardnetwrk)
self.cnf.extend(card.clauses)
# calculate the majority class
self.cmaj = self.forest.predict_inst(sample)
##print("encode majority class ...")
#Cardinality constraint AtMostK to capture a j_th class
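# The constraints below assert that the vote count allows some class
# other than self.cmaj to win, i.e. a model of the CNF corresponds to
# an input on which the forest's prediction differs from the majority
# class computed above. Fixing all features of the sample (via the
# selector assumptions added later) must then yield UNSAT, which is
# what the abductive-explanation extraction relies on.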
if(self.num_class == 2):
rhs = math.floor(num_tree / 2) + 1
if(self.cmaj==1 and not num_tree%2):
rhs = math.floor(num_tree / 2)
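# with an even number of trees and cmaj == 1, a tie already suffices
# for class 0 to win (ties appear to be broken toward the lower class
# index), hence the weaker bound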
lhs = [ctvars[k][1 - self.cmaj] for k in range(num_tree)]
atls = CardEnc.atleast(lits = lhs, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(atls)
else:
zvars = []
zvars.append([self.newVar('z_0_{0}'.format(k)) for k in range (num_tree) ])
zvars.append([self.newVar('z_1_{0}'.format(k)) for k in range (num_tree) ])
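# zvars[0][k] / zvars[1][k] act as per-tree "padding" literals: the two
# at-least constraints below enforce
#   #True(zvars[0]) >= #votes(cmaj)   and   #True(zvars[1]) >= #votes(cmaj) + 1,
# i.e. the number of padding literals set to True bounds the majority
# class's vote count from above (non-strictly resp. strictly).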
##
rhs = num_tree
lhs0 = zvars[0] + [ - ctvars[k][self.cmaj] for k in range(num_tree)]
##self.printLits(lhs0)
atls = CardEnc.atleast(lits = lhs0, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(atls)
##
#rhs = num_tree - 1
rhs = num_tree + 1
###########
lhs1 = zvars[1] + [ - ctvars[k][self.cmaj] for k in range(num_tree)]
##self.printLits(lhs1)
atls = CardEnc.atleast(lits = lhs1, bound = rhs, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(atls)
#
pvars = [self.newVar('p_{0}'.format(k)) for k in range(self.num_class + 1)]
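# pvars[k] selects which class plays the challenger: for k >= 1 and
# k != cmaj+1, asserting pvars[k] makes zvars[0] (classes below cmaj)
# or zvars[1] (classes above cmaj) coincide with the votes of class
# k-1, so the constraints above force class k-1 to collect at least as
# many votes as cmaj (strictly more for higher-indexed classes);
# pvars[0] and pvars[cmaj+1] are the "no challenger on this side"
# options, and the clause below forbids picking both of them.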
##self.printLits(pvars)
for k,p in enumerate(pvars):
for i in range(num_tree):
if k == 0:
z = zvars[0][i]
#self.cnf.append([-p, -z, vtaut])
self.cnf.append([-p, z, -vtaut])
#self.printLits([-p, z, -vtaut])
#print()
elif k == self.cmaj+1:
z = zvars[1][i]
self.cnf.append([-p, z, -vtaut])
#self.printLits([-p, z, -vtaut])
#print()
else:
z = zvars[0][i] if (k<self.cmaj+1) else zvars[1][i]
self.cnf.append([-p, -z, ctvars[i][k-1] ])
self.cnf.append([-p, z, -ctvars[i][k-1] ])
#self.printLits([-p, -z, ctvars[i][k-1] ])
#self.printLits([-p, z, -ctvars[i][k-1] ])
#print()
#
self.cnf.append([-pvars[0], -pvars[self.cmaj+1]])
##
lhs1 = pvars[:(self.cmaj+1)]
##self.printLits(lhs1)
eqls = CardEnc.equals(lits = lhs1, bound = 1, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(eqls)
lhs2 = pvars[(self.cmaj + 1):]
##self.printLits(lhs2)
eqls = CardEnc.equals(lits = lhs2, bound = 1, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(eqls)
##print("exactly-one feat const ...")
# enforce exactly one of the feature values to be chosen
# (for categorical features)
categories = collections.defaultdict(lambda: [])
for f in self.extended_feature_names:
if '_' in f:
categories[f.split('_')[0]].append(self.newVar(f))
for c, feats in six.iteritems(categories):
# exactly-one feat is True
self.cnf.append(feats)
card = CardEnc.atmost(lits=feats, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(card.clauses)
# lits of intervals
for f, intvs in six.iteritems(self.ivars):
if not len(intvs):
continue
self.cnf.append(intvs)
card = CardEnc.atmost(lits=intvs, vpool=self.vpool, encoding=EncType.cardnetwrk)
self.cnf.extend(card.clauses)
#self.printLits(intvs)
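# link threshold and interval variables: thvar_j holds exactly when the
# selected interval has index greater than j (the value lies above the
# j-th threshold); the clauses below chain consecutive threshold
# variables and tie them to the interval indicators in both directions.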
for f, threshold in six.iteritems(self.thvars):
for j, thvar in enumerate(threshold):
d = j+1
pos, neg = self.ivars[f][d:], self.ivars[f][:d]
if j == 0:
assert(len(neg) == 1)
self.cnf.append([thvar, neg[-1]])
self.cnf.append([-thvar, -neg[-1]])
else:
self.cnf.append([thvar, neg[-1], -threshold[j-1]])
self.cnf.append([-thvar, threshold[j-1]])
self.cnf.append([-thvar, -neg[-1]])
if j == len(threshold) - 1:
assert(len(pos) == 1)
self.cnf.append([-thvar, pos[0]])
self.cnf.append([thvar, -pos[0]])
else:
self.cnf.append([-thvar, pos[0], threshold[j+1]])
self.cnf.append([thvar, -pos[0]])
self.cnf.append([thvar, -threshold[j+1]])
return self.cnf, self.intvs, self.imaps, self.ivars
#
#==============================================================================
class SATExplainer(object):
"""
A SAT-based minimal explanation extractor for Random Forest models.
"""
def __init__(self, sat_enc, inps, preamble, target_name, verb=1):
"""
Constructor.
"""
self.enc = sat_enc
self.inps = inps # input (feature value) variables
self.target_name = target_name
self.preamble = preamble
self.verbose = verb
self.slv = None
def prepare_selectors(self, sample):
# adapt the solver to deal with the current sample
#self.csel = []
self.assums = [] # var selectors to be used as assumptions
self.sel2fid = {} # selectors to original feature ids
self.sel2vid = {} # selectors to categorical feature ids
self.sel2v = {} # selectors to (categorical/interval) values
#for i in range(self.enc.num_class):
# self.csel.append(self.enc.newVar('class{0}'.format(i)))
#self.csel = self.enc.newVar('class{0}'.format(self.enc.cmaj))
# preparing the selectors
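# each original feature gets one selector variable 'selv_<f>'; asserting
# a selector as an assumption fixes that feature to its value in the
# sample, via the implications added below (selector -> the matching
# one-hot literal, or selector -> the matching interval indicator and
# the negation of all other intervals of that feature).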
for i, (inp, val) in enumerate(zip(self.inps, sample), 1):
if '_' in inp:
# binarized (OHE) features
assert (inp not in self.enc.intvs)
feat = inp.split('_')[0]
selv = self.enc.newVar('selv_{0}'.format(feat))
self.assums.append(selv)
if selv not in self.sel2fid:
self.sel2fid[selv] = int(feat[1:])
self.sel2vid[selv] = [i - 1]
else:
self.sel2vid[selv].append(i - 1)
p = self.enc.newVar(inp)
if not val:
p = -p
else:
self.sel2v[selv] = p
self.enc.cnf.append([-selv, p])
#self.enc.printLits([-selv, p])
elif len(self.enc.intvs[inp]):
#v = None
#for intv in self.enc.intvs[inp]:
# if intv > val:
# v = intv
# break
v = next((intv for intv in self.enc.intvs[inp] if intv > val), None)
assert(v is not None)
selv = self.enc.newVar('selv_{0}'.format(inp))
self.assums.append(selv)
assert (selv not in self.sel2fid)
self.sel2fid[selv] = int(inp[1:])
self.sel2vid[selv] = [i - 1]
for j,p in enumerate(self.enc.ivars[inp]):
cl = [-selv]
if j == self.enc.imaps[inp][v]:
cl += [p]
self.sel2v[selv] = p
else:
cl += [-p]
self.enc.cnf.append(cl)
#self.enc.printLits(cl)
def explain(self, sample, xtype='abd', smallest=False):
"""
Hypotheses minimization.
"""
if self.verbose:
print(' explaining: "IF {0} THEN {1}"'.format(' AND '.join(self.preamble), self.target_name[self.enc.cmaj]))
self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime
self.prepare_selectors(sample)
if xtype == 'abd':
# abductive (PI-) explanation
expl = self.compute_axp()
else:
# contrastive explanation
expl = self.compute_cxp()
self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time
# delete sat solver
self.slv.delete()
self.slv = None
if self.verbose:
print(' time: {0:.3f}'.format(self.time))
return expl
def compute_axp(self, smallest=False):
"""
Compute an Abductive eXplanation
"""
self.assums = sorted(set(self.assums))
if self.verbose:
print(' # hypos:', len(self.assums))
#create a SAT solver
self.slv = Solver(name="glucose3")
# pass a CNF formula
self.slv.append_formula(self.enc.cnf)
def minimal():
vtaut = self.enc.newVar('Tautology')
# simple deletion-based linear search
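# for each hypothesis p (a feature fixed to its value), test whether the
# prediction can change once p is dropped and the feature is pushed to a
# different value (-self.sel2v[p]); if the test is UNSAT, the remaining
# hypotheses already entail the prediction, so p is redundant and is
# removed from the explanation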
for i, p in enumerate(self.assums):
to_test = [vtaut] + self.assums[:i] + self.assums[(i + 1):] + [-p, -self.sel2v[p]]
sat = self.slv.solve(assumptions=to_test)
if not sat:
self.assums[i] = -p
return
if not smallest:
minimal()
else:
raise NotImplementedError('Smallest explanation is not yet implemented.')
#self.compute_smallest()
expl = sorted([self.sel2fid[h] for h in self.assums if h>0 ])
assert len(expl), 'Abductive explanation cannot be an empty set; otherwise the RF function is constant, i.e. it predicts a single class'
if self.verbose:
print("expl-selctors: ", expl)
preamble = [self.preamble[i] for i in expl]
print(' explanation: "IF {0} THEN {1}"'.format(' AND '.join(preamble), self.target_name[self.enc.cmaj]))
print(' # hypos left:', len(expl))
return expl
def compute_cxp(self, smallest=True):
"""
Compute a Contrastive eXplanation
"""
self.assums = sorted(set(self.assums))
if self.verbose:
print(' # hypos:', len(self.assums))
wcnf = WCNF()
for cl in self.enc.cnf:
wcnf.append(cl)
for p in self.assums:
wcnf.append([p], weight=1)
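# a contrastive explanation is a minimal set of features whose selectors
# must be relaxed for the prediction to change, i.e. a minimal correction
# subset (MCS) of the soft clauses above: LBX returns one subset-minimal
# MCS, while RC2 solves the MaxSAT problem and the soft clauses falsified
# by its optimum form a smallest MCS.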
if not smallest:
# mcs solver
self.slv = LBX(wcnf, use_cld=True, solver_name='g3')
mcs = self.slv.compute()
expl = sorted([self.sel2fid[self.assums[i-1]] for i in mcs])
else:
# mxsat solver
self.slv = RC2(wcnf)
model = self.slv.compute()
model = [p for p in model if abs(p) in self.assums]
expl = sorted([self.sel2fid[-p] for p in model if p<0 ])
assert len(expl), 'Contrastive explanation cannot be an empty-set!'
if self.verbose:
print("expl-selctors: ", expl)
preamble = [self.preamble[i] for i in expl]
pred = self.target_name[self.enc.cmaj]
print(f' explanation: "IF {" AND ".join([f"!({p})" for p in preamble])} THEN !(class = {pred})"')
return expl
def enumerate(self, sample, xtype='con', smallest=True):
"""
List all CXp's or AXp's.
"""
if xtype == 'abd':
raise NotImplementedError('Enumeration of abductive explanations is not yet implemented.')
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime
if 'assums' not in dir(self):
self.prepare_selectors(sample)
self.assums = sorted(set(self.assums))
#
# compute CXp's/AE's
if self.slv is None:
wcnf = WCNF()
for cl in self.enc.cnf:
wcnf.append(cl)
for p in self.assums:
wcnf.append([p], weight=1)
if smallest:
# incremental maxsat solver
self.slv = RC2(wcnf, adapt=True, exhaust=True, minz=True)
else:
# mcs solver
self.slv = LBX(wcnf, use_cld=True, solver_name='g3')
#self.slv = MCSls(wcnf, use_cld=True, solver_name='g3')
if smallest:
print('smallest')
for model in self.slv.enumerate(block=-1):
#model = [p for p in model if abs(p) in self.assums]
expl = sorted([self.sel2fid[-p] for p in model if (p<0 and (-p in self.assums))])
cxp_feats = [f'f{j}' for j in expl]
advx = []
for f in cxp_feats:
ps = [p for p in model if (p>0 and (p in self.enc.ivars[f]))]
assert(len(ps) == 1)
advx.append(tuple([f,self.enc.nameVar(ps[0])]))
#yield expl
print(cxp_feats, advx)
yield advx
else:
print('LBX')
for mcs in self.slv.enumerate():
expl = sorted([self.sel2fid[self.assums[i-1]] for i in mcs])
assumptions = [-p if(i in mcs) else p for i,p in enumerate(self.assums, 1)]
#for k, model in enumerate(self.slv.oracle.enum_models(assumptions), 1):
assert (self.slv.oracle.solve(assumptions))
model = self.slv.oracle.get_model()
cxp_feats = [f'f{j}' for j in expl]
advx = []
for f in cxp_feats:
ps = [p for p in model if (p>0 and (p in self.enc.ivars[f]))]
assert(len(ps) == 1)
advx.append(tuple([f,self.enc.nameVar(ps[0])]))
yield advx
self.slv.block(mcs)
#yield expl
time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
resource.getrusage(resource.RUSAGE_SELF).ru_utime - time
if self.verbose:
print('c expl time: {0:.3f}'.format(time))
#
self.slv.delete()
self.slv = None
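#
#==============================================================================
# Usage sketch (for orientation only): the Dataset and XRF calls below follow
# the signatures defined in this module; the RF2001 constructor/fit interface
# lives in .rndmforest and is assumed here, not taken from this file.
#
#   data = Dataset(filename='somedata.csv', separator=',')
#   X_train, X_test, y_train, y_test = data.train_test_split(test_size=0.2)
#   cls = RF2001(n_estimators=100, max_depth=6)    # assumed signature
#   cls.fit(X_train, y_train)                      # assumed signature
#   xrf = XRF(cls, data.feature_names, data.class_names, verb=1)
#   axp = xrf.explain(X_test[0], xtype='abd')      # abductive explanation
#   cxp = xrf.explain(X_test[0], xtype='con')      # contrastive explanation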