Source code for ocean.maxsat._explainer

from __future__ import annotations

import signal
import warnings
from typing import TYPE_CHECKING, Any, NoReturn

from sklearn.ensemble import AdaBoostClassifier

from ..tree import parse_ensembles
from ..typing import (
    Array1D,
    BaseExplainableEnsemble,
    BaseExplainer,
    NonNegativeInt,
    PositiveInt,
)
from ._env import ENV
from ._model import Model

if TYPE_CHECKING:
    from ..abc import Mapper
    from ..feature import Feature
    from ._explanation import Explanation


def handler(signum: Any, frame: Any) -> NoReturn:  # noqa: ANN401, ARG001
    """
    Abort the running MaxSAT solve when ``SIGALRM`` fires.

    Installed as the ``SIGALRM`` handler by :class:`Explainer` while the
    solver runs. Both arguments are required by the :func:`signal.signal`
    handler protocol and are ignored.

    Raises
    ------
    TimeoutError
        Always; this function never returns normally.
    """
    msg = "Timeout for maxsat!"
    raise TimeoutError(msg)


class Explainer(Model, BaseExplainer):
    """MaxSAT-based explainer for tree ensemble classifiers."""

    # Status of the most recent solve. This class-level value is the default
    # until ``explain`` assigns a per-instance status.
    Status: str = "UNKNOWN"

    def __init__(
        self,
        ensemble: BaseExplainableEnsemble,
        *,
        mapper: Mapper[Feature],
        weights: Array1D | None = None,
        epsilon: int = Model.DEFAULT_EPSILON,
        model_type: Model.Type = Model.Type.MAXSAT,
    ) -> None:
        """
        Build the MaxSAT model for a single tree ensemble.

        Parameters
        ----------
        ensemble
            Fitted ensemble whose trees are parsed into the model.
        mapper
            Feature mapper used when parsing the ensemble's trees.
        weights
            Optional per-tree weights forwarded to :class:`Model`. Ignored
            for ``AdaBoostClassifier``, whose ``estimator_weights_`` are
            always used instead.
        epsilon
            Forwarded to :class:`Model`.
        model_type
            Forwarded to :class:`Model`.
        """
        trees = parse_ensembles(ensemble, mapper=mapper)
        if isinstance(ensemble, AdaBoostClassifier):
            # The fitted AdaBoost estimator weights override any
            # user-supplied ``weights``.
            weights = ensemble.estimator_weights_
        Model.__init__(
            self,
            trees,
            mapper=mapper,
            weights=weights,
            epsilon=epsilon,
            model_type=model_type,
        )
        self.build()
        self.solver = ENV.solver

    def get_objective_value(self) -> float:
        """
        Return the weighted MaxSAT objective value of the last solve.

        Returns
        -------
        float
            Objective value rescaled back to the user-facing distance units.
        """
        return self.solver.cost / self._obj_scale

    def get_distance(self) -> float:
        """
        Return the post-processed distance of the last CF.

        Each non-one-hot feature contributes ``|x_cf - x_query| ** norm``.
        A one-hot encoded feature sums that term over all of its codes and
        halves the result. Assuming 0/1 codes, switching from one category
        to another therefore counts as a single unit change.

        Returns
        -------
        float
            Post-processed :math:`L_p` distance for the last successful solve.

        Raises
        ------
        RuntimeError
            If no explanation has been computed yet.
        """
        query = self.explanation.query
        norm = getattr(self, "_distance_norm", None)
        if query.size == 0 or norm is None:
            msg = "No explanation has been computed yet."
            raise RuntimeError(msg)

        counterfactual = self.explanation.x
        distance = 0.0
        for name, feature in self.mapper.items():
            if feature.is_one_hot_encoded:
                feature_distance = 0.0
                for code in feature.codes:
                    idx = self.mapper.idx.get(name, code)
                    delta = float(counterfactual[idx]) - float(query[idx])
                    feature_distance += abs(delta) ** norm
                distance += feature_distance / 2.0
            else:
                idx = self.mapper.idx.get(name)
                delta = float(counterfactual[idx]) - float(query[idx])
                distance += abs(delta) ** norm
        if norm != 1:
            distance **= 1.0 / norm
        return float(distance)

    def get_solving_status(self) -> str:
        """
        Return the status of the latest MaxSAT solve.

        Returns
        -------
        str
            ``"UNKNOWN"`` before any solve, or if the last solve raised an
            unexpected error. Otherwise ``"OPTIMAL"``, ``"INFEASIBLE"`` or
            ``"TIME_LIMIT"``.
        """
        return self.Status

    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
        """
        Return the intermediate solution trace for the last MaxSAT solve.

        Returns
        -------
        None
            The MaxSAT backend currently exposes only the final solution.
        """
        # Touches instance state, presumably so linters keep this a regular
        # method matching the other backends' interface.
        _ = self.Status
        return None

    def _solve_with_time_limit(self, max_time: int) -> None:
        """
        Run the solver, interrupting it via ``SIGALRM`` after ``max_time`` s.

        On every exit path, the alarm is cancelled as soon as the solver
        returns or raises. The previously installed ``SIGALRM`` handler is
        then put back, so the process-wide signal setup is left as found.

        Raises
        ------
        TimeoutError
            If the alarm fires before the solver returns.
        """
        previous_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(max_time)
        try:
            self.solver.solve(self)
        finally:
            signal.alarm(0)
            # ``signal.signal`` returns None when the old handler was not
            # installed from Python; such a handler cannot be reinstated.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    def explain(
        self,
        x: Array1D,
        *,
        y: NonNegativeInt,
        norm: PositiveInt,
        return_callback: bool = False,
        verbose: bool = False,
        max_time: int = 60,
        num_workers: int | None = None,
        random_seed: int = 42,
        clean_up: bool = True,
    ) -> Explanation | None:
        """
        Solve one counterfactual query with the weighted MaxSAT backend.

        Parameters
        ----------
        x
            Query instance in the processed feature space.
        y
            Target class enforced by the counterfactual.
        norm
            Distance norm. The MaxSAT backend currently supports only ``1``.
        return_callback
            Accepted for API compatibility but ignored by this backend; a
            ``UserWarning`` is emitted when it is set.
        verbose
            Whether to enable RC2 logging.
        max_time
            Time limit in seconds, enforced with ``SIGALRM``.
        num_workers
            Optional thread count forwarded to the MaxSAT solver
            (defaults to ``1``).
        random_seed
            Accepted for API compatibility but currently ignored.
        clean_up
            Whether to remove query-specific clauses after the solve. This
            also happens when the solve fails, including on an unexpected
            solver error.

        Returns
        -------
        Explanation | None
            The decoded counterfactual, or ``None`` when no feasible
            counterfactual is found within the given limits. The outcome is
            recorded in :attr:`Status` (see :meth:`get_solving_status`).

        Raises
        ------
        RuntimeError
            If the MaxSAT solver raises an error that is not UNSAT or timeout.
        """
        if return_callback:
            default_seed = 42
            msg = "There are no callbacks for maxsat."
            if random_seed != default_seed:
                msg = "There are no callbacks/random_seed for maxsat."
            warnings.warn(msg, category=UserWarning, stacklevel=2)

        self.solver.TimeLimit = max_time
        self.solver.n_threads = num_workers if num_workers is not None else 1
        self.solver.verbose = verbose

        # Add objective soft clauses
        self.add_objective(x, norm=norm)
        # Add hard constraints for target class
        self.set_majority_class(y=y)

        # Reset first so a failed solve never reports the previous status.
        self.Status = "UNKNOWN"
        try:
            self._solve_with_time_limit(max_time)
        except TimeoutError as exc:
            self.Status = "TIME_LIMIT"
            warnings.warn(str(exc), category=UserWarning, stacklevel=2)
        except RuntimeError as exc:
            if "UNSAT" not in str(exc):
                # Unexpected solver failure: still remove the query-specific
                # clauses before propagating the error.
                if clean_up:
                    self.cleanup()
                raise
            self.Status = "INFEASIBLE"
            msg = "There are no feasible counterfactuals for this query."
            msg += " If there should be one, please check the model "
            msg += "constraints or report this issue to the developers."
            warnings.warn(msg, category=UserWarning, stacklevel=2)
        else:
            self.Status = "OPTIMAL"
            # Store the query in the explanation
            self.explanation.query = x
            self._distance_norm = norm
            # Clean up for next solve
            if clean_up:
                self.cleanup()
            return self.explanation

        # Infeasible or timed out: there is no counterfactual to return.
        if clean_up:
            self.cleanup()
        return None