from __future__ import annotations
from enum import Enum
from math import ceil, gcd
from typing import TYPE_CHECKING
import numpy as np
from pydantic import validate_call
try:
from pysat.card import CardEnc
from pysat.card import EncType as CardEncType
except ImportError:
CardEnc = None
CardEncType = None
try:
from pysat.pb import EncType as PBEncType
from pysat.pb import PBEnc
except (AssertionError, ImportError):
PBEncType = None
PBEnc = None
from ..typing import NonNegativeInt
from ._base import BaseModel
from ._builder.model import ModelBuilder, ModelBuilderFactory
from ._managers import FeatureManager, GarbageManager, TreeManager
if TYPE_CHECKING:
from collections.abc import Iterable
from ..abc import Mapper
from ..feature import Feature
from ..tree import Tree
from ..typing import Array1D, Key, NonNegativeArray1D, NonNegativeInt
from ._variables import FeatureVar
[docs]
class Model(BaseModel, FeatureManager, GarbageManager, TreeManager):
    """
    Weighted MaxSAT formulation for tree ensemble explanations.

    The Boolean feature variables encode the counterfactual point ``x`` and
    the Boolean tree variables encode the active leaf decisions ``p_(t,l)``.
    """

    # Model builder for the ensemble; assigned by ``_set_builder``.
    _builder: ModelBuilder
    # Default value of the ``epsilon`` constructor argument.
    DEFAULT_EPSILON: int = 1
    # Integer scale applied to distances by the integer-weight soft-clause
    # helpers (``_add_soft_l1_binary``, ``_add_soft_l1_ohe``,
    # ``_add_soft_l1_discrete`` and ``_get_intervals_cost``).
    _obj_scale: int = int(1e8)
    # Class-level default; ``__init__`` overwrites it on every instance.
    _hard_voting: bool = False
[docs]
class Type(Enum):
    # Backend builder variants selectable through ``model_type``; MaxSAT is
    # the only variant defined here.
    MAXSAT = "MAXSAT"
def __init__(
    self,
    trees: Iterable[Tree],
    mapper: Mapper[Feature],
    *,
    weights: NonNegativeArray1D | None = None,
    hard_voting: bool = False,
    max_samples: NonNegativeInt = 0,
    epsilon: int = DEFAULT_EPSILON,
    model_type: Model.Type = Type.MAXSAT,
) -> None:
    """
    Initialize an empty weighted MaxSAT model for a parsed ensemble.

    Parameters
    ----------
    trees
        Parsed trees whose leaf activations define the ensemble class
        scores.
    mapper
        Feature mapper aligning processed query coordinates with the
        Boolean feature encoding.
    weights
        Optional tree weights. Forwarded both to ``TreeManager.__init__``
        and to ``_set_weights``.
    hard_voting
        Whether to build the hard-voting MaxSAT encoding. When true,
        ``_set_majority_class`` delegates to
        ``_set_hard_voting_majority_class``.
    max_samples
        Reserved parameter used by compatible higher-level explainers.
        Stored as ``_max_samples``; it is not read by the methods visible
        in this class.
    epsilon
        Integer tie-breaking margin used in the score constraints. It is
        applied only against competing classes whose index is smaller than
        the target class.
    model_type
        Backend builder variant.
    """
    # Every base class is initialised explicitly rather than through
    # ``super()``.
    BaseModel.__init__(self)
    TreeManager.__init__(
        self,
        trees=trees,
        weights=weights,
    )
    FeatureManager.__init__(self, mapper=mapper)
    GarbageManager.__init__(self)
    # NOTE(review): ``_hard_voting`` is assigned before ``_set_weights`` is
    # called; presumably the weight setup may depend on it — confirm before
    # reordering.
    self._hard_voting = hard_voting
    self._set_weights(weights=weights)
    self._max_samples = max_samples
    self._epsilon = epsilon
    self._set_builder(model_type=model_type)
[docs]
def build(self) -> None:
    r"""
    Create the Boolean encoding of features, leaves, and path consistency.

    After ``build()``, the model contains the Boolean variables encoding
    the counterfactual point :math:`x`, the active leaves
    :math:`p_{t,\ell}`, and the hard clauses linking both.
    """
    # The model passes itself to the manager build steps.
    # NOTE(review): features and trees are built before the builder runs;
    # presumably the builder needs both sets of variables — confirm.
    self.build_features(self)
    self.build_trees(self)
    self._builder.build(self, trees=self.trees, mapper=self.mapper)
[docs]
def add_objective(
self,
x: Array1D,
*,
norm: NonNegativeInt = 1,
) -> None:
r"""
Encode the :math:`L_1` distance :math:`d_1(x, \hat{x})` as soft clauses.
Parameters
----------
x
Query point :math:`\hat{x}` in the processed feature space. The
code parameter is named ``x``, but mathematically it represents the
query.
norm
Distance norm. The MaxSAT backend currently supports only
:math:`L_1`.
Raises
------
ValueError
If ``x`` does not have the expected size or ``norm`` is
unsupported.
"""
if x.size != self.mapper.n_columns:
msg = f"Expected {self.mapper.n_columns} values, got {x.size}"
raise ValueError(msg)
if norm != 1:
msg = f"Unsupported norm: {norm}"
raise ValueError(msg)
x_arr = np.asarray(x, dtype=float).ravel()
self._add_objective_maxorc(x_arr)
def _add_objective_maxorc(self, x_arr: Array1D) -> None:
variables = self.mapper.values()
names = list(self.mapper.keys())
k = 0
indexer = self.mapper.idx
for v, name in zip(variables, names, strict=True):
if v.is_one_hot_encoded:
for code in v.codes:
idx = indexer.get(name, code)
self._add_soft_l1_ohe_maxorc(x_arr[idx], v, code=code)
k += 1
elif v.is_continuous:
self._add_soft_l1_threshold_numeric(x_arr[k], v)
k += 1
elif v.is_discrete:
if v.has_threshold_encoding:
self._add_soft_l1_threshold_numeric(x_arr[k], v)
else:
self._add_soft_l1_discrete_maxorc(x_arr[k], v)
k += 1
elif v.is_binary:
self._add_soft_l1_binary_maxorc(x_arr[k], v)
k += 1
else:
k += 1
[docs]
def _add_soft_l1_binary(self, x_val: float, v: FeatureVar) -> None:
r"""
Add the soft clause corresponding to a binary coordinate of :math:`x`.
The clause penalizes a deviation from the query coordinate
:math:`\hat{x}_j` with weight proportional to
:math:`|x_j - \hat{x}_j|`.
"""
weight = int(self._obj_scale)
x_var = v.xget()
binary_threshold = 0.5
if x_val > binary_threshold:
# If x=1, penalize flipping to 0
self.add_soft([x_var], weight=weight)
else:
# If x=0, penalize flipping to 1
self.add_soft([-x_var], weight=weight)
def _add_soft_l1_binary_maxorc(self, x_val: float, v: FeatureVar) -> None:
x_var = v.xget()
binary_threshold = 0.5
if x_val > binary_threshold:
self.add_soft([x_var], weight=1.0)
else:
self.add_soft([-x_var], weight=1.0)
[docs]
def _add_soft_l1_ohe(
self,
x_val: float,
v: FeatureVar,
code: Key,
) -> None:
"""
Add the soft clause for one coordinate of a one-hot block.
The weight is halved so that changing category in an unordered nominal
feature contributes one unit of :math:`L_1` distance instead of two.
"""
weight = int(self._obj_scale / 2) # OHE uses half weight
x_var = v.xget(code=code)
binary_threshold = 0.5
if x_val > binary_threshold:
self.add_soft([x_var], weight=weight)
else:
self.add_soft([-x_var], weight=weight)
def _add_soft_l1_ohe_maxorc(
self,
x_val: float,
v: FeatureVar,
code: Key,
) -> None:
x_var = v.xget(code=code)
binary_threshold = 0.5
if x_val > binary_threshold:
self.add_soft([x_var], weight=1.0)
else:
self.add_soft([-x_var], weight=1.0)
[docs]
def _add_soft_l1_continuous(self, x_val: float, v: FeatureVar) -> None:
r"""
Add soft clauses for an interval-encoded continuous coordinate.
Each Boolean interval selector receives a weight equal to the scaled
distance between the query coordinate :math:`\hat{x}_j` and that
interval.
"""
levels = v.levels
intervals_cost = self._get_intervals_cost(levels, x_val)
for i in range(len(levels) - 1):
cost = intervals_cost[i]
if cost > 0:
mu_var = v.xget(mu=i)
self.add_soft([-mu_var], weight=cost)
[docs]
def _add_soft_l1_discrete(self, x_val: float, v: FeatureVar) -> None:
r"""
Add soft clauses for an ordinal discrete feature.
If ``mu[i]`` denotes :math:`x_j = d_{j,i}`, this method assigns the
scaled cost :math:`|d_{j,i} - \hat{x}_j|` to every admissible level
different from the query value.
"""
levels = v.levels
for i in range(len(levels)):
level_val = levels[i]
if level_val == x_val:
# No cost if this is the same value
continue
cost = int(abs(x_val - level_val) * self._obj_scale)
if cost > 0:
mu_var = v.xget(mu=i)
self.add_soft([-mu_var], weight=cost)
def _add_soft_l1_threshold_numeric(
self,
x_val: float,
v: FeatureVar,
) -> None:
thresholds = v.split_threshold_values
n_thresholds = len(thresholds)
if n_thresholds == 0:
return
for idx, threshold in enumerate(thresholds):
if x_val <= threshold:
if idx == 0:
cost = threshold - x_val
else:
cost = min(
threshold - x_val,
threshold - thresholds[idx - 1],
)
lit = v.xget(mu=idx)
else:
if idx == n_thresholds - 1:
cost = x_val - threshold
else:
cost = min(
x_val - threshold,
thresholds[idx + 1] - threshold,
)
lit = -v.xget(mu=idx)
if cost > 0:
self.add_soft([lit], weight=float(cost))
def _add_soft_l1_discrete_maxorc(self, x_val: float, v: FeatureVar) -> None:
levels = v.levels
for i in range(len(levels)):
level_val = levels[i]
if level_val == x_val:
continue
cost = float(abs(x_val - level_val))
if cost > 0:
mu_var = v.xget(mu=i)
self.add_soft([-mu_var], weight=cost)
[docs]
def _get_intervals_cost(self, levels: Array1D, x: float) -> list[int]:
r"""
Compute interval costs relative to :math:`\hat{x}_j`.
This helper is used for continuous features whose Boolean encoding
selects one interval between consecutive threshold levels.
Returns
-------
list[int]
Scaled costs for the interval selectors attached to that
continuous feature.
"""
intervals_cost = np.zeros(len(levels) - 1, dtype=int)
for i in range(len(intervals_cost)):
if levels[i] < x <= levels[i + 1]:
continue
if levels[i] > x:
intervals_cost[i] = int(abs(x - levels[i]) * self._obj_scale)
elif levels[i + 1] < x:
intervals_cost[i] = int(
abs(x - levels[i + 1]) * self._obj_scale
) # Distance to nearest endpoint of the interval
return intervals_cost.tolist()
[docs]
@validate_call
def set_majority_class(
    self,
    y: NonNegativeInt,
    *,
    op: NonNegativeInt = 0,
) -> None:
    r"""
    Enforce the target class through hard score constraints.

    For every competing class, the backend encodes the pairwise dominance
    condition

    .. math::

        f_y(x) \ge f_c(x) + \varepsilon_c

    Raises
    ------
    ValueError
        If ``y`` is not a valid class index.
    """
    if y < self.n_classes:
        self._set_majority_class(y, op=op)
        return
    msg = f"Expected class < {self.n_classes}, got {y}"
    raise ValueError(msg)
[docs]
def _set_majority_class(
self,
y: NonNegativeInt,
*,
op: NonNegativeInt = 0,
) -> None:
r"""
Encode the target-class dominance constraints in MaxSAT form.
The intended inequality is
.. math::
f_y(x) - f_c(x) \ge \varepsilon_c,
\qquad \forall c \neq y.
Since MaxSAT works on Boolean clauses, the leaf contributions of each
tree are converted into an integer pseudo-Boolean constraint.
"""
if self._hard_voting:
self._set_hard_voting_majority_class(y, op=op)
return
scale = 10000 # Scale factor for probabilities
for class_ in range(self.n_classes):
if class_ == y:
continue
# Compute the score difference for each leaf in each tree
# We need: sum over trees of (prob_y - prob_c) >= epsilon
# For each tree, compute min and max possible contributions
tree_contributions: list[list[tuple[int, int]]] = []
for tree in self.trees:
contribs: list[tuple[int, int]] = []
for leaf in tree.leaves:
prob_y = leaf.value[op, y]
prob_c = leaf.value[op, class_]
diff = int((prob_y - prob_c) * scale)
leaf_var = tree[leaf.node_id]
contribs.append((leaf_var, diff))
tree_contributions.append(contribs)
# Threshold for comparison
epsilon = self._epsilon if class_ < y else 0
# Use iterative bounds propagation to encode the constraint
# For each tree, we track the range of possible partial sums
self._encode_weighted_sum_constraint(tree_contributions, epsilon)
[docs]
def _set_hard_voting_majority_class(
self,
y: NonNegativeInt,
*,
op: NonNegativeInt = 0,
) -> None:
r"""
Encode hard-voting dominance using cardinality constraints.
For hard voting, each tree contributes one unit vote to the class of
its active leaf. The target class is enforced through
.. math::
\sum_t \mathbb{1}[\hat{y}_t(x)=y]
\ge
\sum_t \mathbb{1}[\hat{y}_t(x)=c] + \varepsilon_c,
\qquad \forall c \neq y.
"""
function = self.function
for class_ in range(self.n_classes):
if class_ == y:
continue
epsilon = self._epsilon if class_ < y else 0
self._encode_cardinality_difference_constraint(
function[op, y],
function[op, class_],
epsilon,
)
[docs]
def _encode_cardinality_difference_constraint(
self,
positive_lits: list[int],
negative_lits: list[int],
threshold: int,
) -> None:
r"""
Encode ``sum(pos) - sum(neg) >= threshold`` as a cardinality bound.
The difference is rewritten as
.. math::
\sum pos + \sum \lnot neg \ge |neg| + threshold.
This method is used for hard-voting score comparisons, where the
contributions are Boolean and the threshold is an integer margin.
Parameters
----------
positive_lits
List of positive literals (e.g., votes for the target class).
negative_lits
List of negative literals (e.g., votes for the competing class).
threshold
Integer margin for the difference.
Raises
------
ImportError
If pysat.card is not installed.
"""
lits = [*positive_lits, *[-lit for lit in negative_lits]]
bound = len(negative_lits) + threshold
if bound <= 0:
return
if not lits or bound > len(lits):
self.add_garbage(self.add_hard([], return_id=True))
return
if CardEnc is None or CardEncType is None:
msg = "pysat.card is required for hard-voting cardinality encoding."
raise ImportError(msg)
card = CardEnc.atleast(
lits=lits,
bound=bound,
vpool=self.vpool,
encoding=CardEncType.totalizer,
)
for clause in card.clauses:
self.add_garbage(
self.add_hard(clause, return_id=True) # pyright: ignore[reportUnknownArgumentType]
)
[docs]
def _encode_weighted_sum_constraint(
self,
tree_contributions: list[list[tuple[int, int]]],
threshold: int,
) -> None:
r"""
Encode a hard pseudo-Boolean constraint on tree contributions.
This method encodes
.. math::
\sum_i a_i \ell_i \ge \tau,
where ``threshold`` provides the integer bound :math:`\tau`. This is
the MaxSAT realization of the score comparison
.. math::
f_y(x) - f_c(x) \ge \varepsilon_c
Parameters
----------
tree_contributions : list[list[tuple[int, int]]]
List of contributions for each tree.
threshold : int
Threshold for the sum of contributions.
Raises
------
ImportError
If pysat.pb is not installed.
"""
normalized_ctbs, effective_bound = self._normalize_tree_contributions(
tree_contributions,
threshold,
)
lits: list[int] = []
weights: list[int] = []
for contribs in normalized_ctbs:
for leaf_var, diff in contribs:
if diff == 0:
continue # contributes nothing, can be ignored
lits.append(leaf_var)
weights.append(diff)
if not lits: # degenerate case
if effective_bound > 0:
self.add_garbage(self.add_hard([], return_id=True))
return
# Encode sum(weights_i * lits_i) >= effective_bound
if PBEnc is None or PBEncType is None:
msg = "pysat.pb is required for this operation."
msg += " The pysat[pblib] extra dependency is required."
msg += " It does not work properly on Windows."
raise ImportError(msg)
pb = PBEnc.atleast(
lits=lits,
weights=weights,
bound=effective_bound,
vpool=self.vpool,
encoding=PBEncType.adder,
)
for clause in pb.clauses:
self.add_garbage(
self.add_hard(clause, return_id=True) # pyright: ignore[reportUnknownArgumentType]
)
[docs]
@staticmethod
def _normalize_tree_contributions(
tree_contributions: list[list[tuple[int, int]]],
threshold: int,
) -> tuple[list[list[tuple[int, int]]], int]:
r"""
Normalize contributions using the one-active-leaf property per tree.
Subtracting the minimum contribution of each tree preserves the score
inequality and removes avoidable negative coefficients. A final GCD
reduction keeps the PB coefficients smaller before calling PySAT.
Parameters
----------
tree_contributions : list[list[tuple[int, int]]]
List of contributions for each tree.
threshold : int
Threshold for the sum of contributions.
Returns
-------
tuple[list[list[tuple[int, int]]], int]
Normalized contributions and the adjusted threshold.
"""
normalized: list[list[tuple[int, int]]] = []
effective_bound = threshold
for contribs in tree_contributions:
if not contribs:
continue
min_diff = min(diff for _, diff in contribs)
effective_bound -= min_diff
normalized.append(
[
(leaf_var, diff - min_diff)
for leaf_var, diff in contribs
if diff != min_diff
]
)
flat_weights = [
diff for contribs in normalized for _, diff in contribs if diff > 0
]
if not flat_weights:
return normalized, effective_bound
common_divisor = flat_weights[0]
for weight in flat_weights[1:]:
common_divisor = gcd(common_divisor, weight)
if common_divisor == 1:
break
if common_divisor <= 1:
return normalized, effective_bound
reduced = [
[
(leaf_var, diff // common_divisor)
for leaf_var, diff in contribs
if diff > 0
]
for contribs in normalized
]
reduced_bound = ceil(effective_bound / common_divisor)
return reduced, reduced_bound
[docs]
def cleanup(self) -> None:
r"""
Remove query-specific soft and hard clauses from the MaxSAT model.
The structural Boolean encoding of :math:`x` and :math:`p_{t,\ell}`
stays in place, while the objective clauses and target-class clauses
for the previous query :math:`\hat{x}` are removed.
"""
self._clean_soft()
for idx in sorted(self.garbage_list(), reverse=True):
self.hard.pop(idx)
self.remove_garbage()
self.invalidate_solver_state()
def _set_builder(self, model_type: Type) -> None:
    """
    Instantiate the builder matching ``model_type``.

    Only ``Model.Type.MAXSAT`` is handled; any other value leaves
    ``_builder`` untouched.
    """
    if model_type == Model.Type.MAXSAT:
        self._builder = ModelBuilderFactory.MAXSAT()