Source code for ocean.cp._explainer

import time
import traceback
import warnings

from ortools.sat.python import cp_model as cp
from sklearn.ensemble import AdaBoostClassifier

from ..abc import Mapper
from ..feature import Feature
from ..tree import parse_ensembles
from ..typing import (
    Array1D,
    BaseExplainableEnsemble,
    BaseExplainer,
    NonNegativeArray1D,
    NonNegativeInt,
    PositiveInt,
)
from ._env import ENV
from ._explanation import Explanation
from ._model import Model


[docs] class Explainer(Model, BaseExplainer): """Constraint programming explainer for tree ensemble classifiers.""" def __init__( self, ensemble: BaseExplainableEnsemble, *, mapper: Mapper[Feature], weights: NonNegativeArray1D | None = None, epsilon: int = Model.DEFAULT_EPSILON, model_type: "Model.Type" = Model.Type.CP, ) -> None: ensembles = (ensemble,) trees = parse_ensembles(*ensembles, mapper=mapper) if isinstance(ensemble, AdaBoostClassifier): weights = ensemble.estimator_weights_ Model.__init__( self, trees, mapper=mapper, weights=weights, epsilon=epsilon, model_type=model_type, ) self.build() self.solver = ENV.solver
[docs] def get_objective_value(self) -> float: """ Return the scaled objective value of the last CP-SAT solve. Returns ------- float Objective value rescaled back to the user-facing distance units. """ return self.solver.ObjectiveValue() / self._obj_scale
[docs] def get_distance(self) -> float: """ Return the post-processed distance of the last CF. Returns ------- float Post-processed :math:`L_p` distance for the last successful solve. Raises ------ RuntimeError If no explanation has been computed yet. """ query = self.explanation.query if query.size == 0: msg = "No explanation has been computed yet." raise RuntimeError(msg) norm = getattr(self, "_distance_norm", None) if norm is None: msg = "No explanation has been computed yet." raise RuntimeError(msg) counterfactual = self.explanation.x distance = 0.0 for name, feature in self.mapper.items(): if feature.is_one_hot_encoded: feature_distance = 0.0 for code in feature.codes: idx = self.mapper.idx.get(name, code) delta = float(counterfactual[idx]) - float(query[idx]) feature_distance += abs(delta) ** norm distance += feature_distance / 2.0 else: idx = self.mapper.idx.get(name) delta = float(counterfactual[idx]) - float(query[idx]) distance += abs(delta) ** norm if norm != 1: distance **= 1.0 / norm return float(distance)
[docs] def get_solving_status(self) -> str: """ Return the status string from the latest CP-SAT solve. Returns ------- str Solver status such as ``"OPTIMAL"``, ``"FEASIBLE"``, or ``"INFEASIBLE"``. """ return self.Status
[docs] def get_anytime_solutions(self) -> list[dict[str, float]] | None: """ Return intermediate solutions collected during the last solve. Returns ------- list[dict[str, float]] | None Time-stamped incumbent objective values when ``return_callback`` was enabled in :meth:`explain`, otherwise ``None``. """ if self.callback is not None: return self.callback.sollist return None
[docs] def explain( self, x: Array1D, *, y: NonNegativeInt, norm: PositiveInt, return_callback: bool = False, verbose: bool = False, max_time: int = 60, num_workers: int | None = None, random_seed: int = 42, clean_up: bool = True, ) -> Explanation | None: """ Solve one counterfactual query with the CP-SAT backend. Parameters ---------- x Query instance in the processed feature space. y Target class enforced by the counterfactual. norm Integer distance norm used by the CP objective. return_callback Whether to record incumbent solutions during the search. verbose Whether to enable CP-SAT search logging. max_time Time limit in seconds. num_workers Optional number of CP-SAT workers. random_seed Random seed passed to CP-SAT. clean_up Whether to remove query-specific constraints after the solve. Returns ------- Explanation | None The decoded counterfactual, or ``None`` when no feasible counterfactual is found within the given limits. Raises ------ RuntimeError If CP-SAT reports an invalid model or an unexpected status. """ self.solver.parameters.log_search_progress = verbose self.solver.parameters.max_time_in_seconds = max_time self.solver.parameters.random_seed = random_seed if num_workers is not None: self.solver.parameters.num_workers = num_workers self.add_objective(x, norm=norm) self.set_majority_class(y=y) self.callback: MySolCallback | None = ( MySolCallback(starttime=time.time(), _obj_scale=self._obj_scale) if return_callback else None ) _ = self.solver.Solve(self, solution_callback=self.callback) status = self.solver.status_name() self.Status = status cf_status_ok = True match status: case "OPTIMAL": pass case "FEASIBLE": msg = "A valid CF was found, but it might be " msg += "suboptimal as the constraint programming " msg += "solver could not prove optimality within " msg += "the given time frame. \n It can however certify" msg += " that no counterfactual can be closer than" msg += f" {self.solver.BestObjectiveBound()}." 
warnings.warn(msg, category=UserWarning, stacklevel=2) case "INFEASIBLE": msg = "There are no feasible counterfactuals for this query." msg += " If there should be one, please check the model " msg += "constraints or report this issue to the developers." warnings.warn(msg, category=UserWarning, stacklevel=2) cf_status_ok = False case "MODEL_INVALID": msg = "The constraint programming model is invalid. " msg += "Please check the model constraints or report" msg += " this issue to the developers." raise RuntimeError(msg) case "UNKNOWN": msg = "The constraint programming solver could " msg += "not find any valid CF within the given time frame." msg += " Try increasing the time limit." warnings.warn(msg, category=UserWarning, stacklevel=2) cf_status_ok = False case _: msg = "Unexpected solver status: " + status raise RuntimeError(msg) if not cf_status_ok: self.cleanup() return None self.explanation.query = x self._distance_norm = norm if clean_up: self.cleanup() return self.explanation
class MySolCallback(cp.CpSolverSolutionCallback):
    """Save intermediate solutions.

    Every incumbent reported by the solver is stored in ``sollist`` as a
    dictionary with the rescaled objective value and the wall-clock time
    elapsed since ``starttime``.
    """

    def __init__(self, starttime: float, _obj_scale: float) -> None:
        """
        Set up an empty solution log.

        Parameters
        ----------
        starttime
            Reference timestamp (as returned by ``time.time()``) from which
            elapsed times are measured.
        _obj_scale
            Divisor applied to the solver's objective value before it is
            recorded.
        """
        super().__init__()
        self.starttime = starttime
        self._obj_scale = _obj_scale
        self.__solution_count = 0
        self.sollist: list[dict[str, float]] = []

    def on_solution_callback(self) -> None:
        """Record the current incumbent; print and re-raise any failure."""
        try:
            self.__solution_count += 1
            elapsed = time.time() - self.starttime
            rescaled_objective = self.ObjectiveValue() / self._obj_scale
            self.addSol(rescaled_objective, elapsed)
        except Exception:
            # Print the traceback before propagating the error.
            traceback.print_exc()
            raise

    def solution_count(self) -> NonNegativeInt:
        """Return how many times the solution callback has been invoked."""
        return self.__solution_count

    def addSol(self, objval: float, time: float) -> None:
        """Append one ``{"objective_value", "time"}`` record to ``sollist``."""
        record = {"objective_value": objval, "time": time}
        self.sollist.append(record)