Source code for ocean.mip._explainer

import time
import warnings
from typing import cast

import gurobipy as gp
from sklearn.ensemble import AdaBoostClassifier, IsolationForest

from ..abc import Mapper
from ..feature import Feature
from ..tree import parse_ensembles
from ..typing import (
    Array1D,
    BaseExplainableEnsemble,
    BaseExplainer,
    NonNegativeInt,
    PositiveInt,
)
from ._explanation import Explanation
from ._model import Model
from ._variables import TreeVar


class Explainer(Model, BaseExplainer):
    """
    Mixed-integer programming explainer for tree ensemble classifiers.

    The underlying model is built once in the constructor. Every call to
    :meth:`explain` installs a query-specific objective and target-class
    constraints, runs the solver, and (when ``clean_up`` is true) calls
    ``cleanup()`` so the same instance can serve the next query.
    """

    def __init__(
        self,
        ensemble: BaseExplainableEnsemble,
        *,
        mapper: Mapper[Feature],
        weights: Array1D | None = None,
        isolation: IsolationForest | None = None,
        name: str = "OCEAN",
        env: gp.Env | None = None,
        epsilon: float = Model.DEFAULT_EPSILON,
        num_epsilon: float = Model.DEFAULT_NUM_EPSILON,
        model_type: "Model.Type" = Model.Type.MIP,
        flow_type: "TreeVar.FlowType" = TreeVar.FlowType.CONTINUOUS,
    ) -> None:
        """
        Parse the ensemble(s) and build the optimization model.

        Parameters
        ----------
        ensemble
            Tree ensemble classifier to explain.
        mapper
            Feature mapper used to parse the trees and index the features.
        weights
            Optional per-tree weights. Ignored for ``AdaBoostClassifier``
            ensembles, whose ``estimator_weights_`` are always used.
        isolation
            Optional isolation forest parsed together with ``ensemble``.
        name, env, epsilon, num_epsilon, model_type, flow_type
            Forwarded unchanged to :class:`Model`.
        """
        ensembles = (ensemble,) if isolation is None else (ensemble, isolation)
        n_isolators, max_samples = self._get_isolation_params(isolation)
        trees = parse_ensembles(*ensembles, mapper=mapper)
        if isinstance(ensemble, AdaBoostClassifier):
            # AdaBoost carries its own tree weights; they take precedence
            # over any user-supplied ``weights``.
            weights = ensemble.estimator_weights_
        Model.__init__(
            self,
            trees,
            mapper=mapper,
            weights=weights,
            n_isolators=n_isolators,
            max_samples=max_samples,
            name=name,
            env=env,
            epsilon=epsilon,
            num_epsilon=num_epsilon,
            model_type=model_type,
            flow_type=flow_type,
        )
        self.build()

    def get_objective_value(self) -> float:
        """
        Return the solver objective value of the last optimization run.

        Returns
        -------
        float
            Objective value reported by Gurobi for the latest solve.
        """
        return self.ObjVal

    def get_distance(self) -> float:
        """
        Return the post-processed distance of the last CF.

        The distance is recomputed from the stored query and counterfactual
        instead of being read from the solver objective.

        Returns
        -------
        float
            Post-processed :math:`L_p` distance for the last successful solve.

        Raises
        ------
        RuntimeError
            If no explanation has been computed yet.
        """
        query = self.explanation.query
        norm = getattr(self, "_distance_norm", None)
        if query.size == 0 or norm is None:
            msg = "No explanation has been computed yet."
            raise RuntimeError(msg)

        counterfactual = self.explanation.x
        distance = 0.0
        for name, feature in self.mapper.items():
            if feature.is_one_hot_encoded:
                feature_distance = 0.0
                for code in feature.codes:
                    idx = self.mapper.idx.get(name, code)
                    delta = float(counterfactual[idx]) - float(query[idx])
                    feature_distance += abs(delta) ** norm
                # Presumably a category switch changes two one-hot entries,
                # so the contribution is halved to count the switch once.
                distance += feature_distance / 2.0
            else:
                idx = self.mapper.idx.get(name)
                delta = float(counterfactual[idx]) - float(query[idx])
                distance += abs(delta) ** norm

        if norm != 1:
            distance **= 1.0 / norm
        return float(distance)

    def get_solving_status(self) -> str:
        """
        Return the latest Gurobi solve status as a readable string.

        Returns
        -------
        str
            Current model status such as ``"OPTIMAL"`` or ``"TIME_LIMIT"``.
            A status code missing from the table is reported as
            ``"UNKNOWN_STATUS_<code>"`` instead of raising ``KeyError``.
        """
        gurobi_statuses = {
            1: "LOADED",
            2: "OPTIMAL",
            3: "INFEASIBLE",
            4: "INF_OR_UNBD",
            5: "UNBOUNDED",
            6: "CUTOFF",
            7: "ITERATION_LIMIT",
            8: "NODE_LIMIT",
            9: "TIME_LIMIT",
            10: "SOLUTION_LIMIT",
            11: "INTERRUPTED",
            12: "NUMERIC",
            13: "SUBOPTIMAL",
            14: "INPROGRESS",
            15: "USER_OBJ_LIMIT",
            16: "WORK_LIMIT",
            # Needed for the "MEM_LIMIT" branch of explain() to be reachable.
            17: "MEM_LIMIT",
        }
        status = self.Status
        return gurobi_statuses.get(status, f"UNKNOWN_STATUS_{status}")

    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
        """
        Return incumbent solutions collected during the last solve.

        Returns
        -------
        list[dict[str, float]] | None
            Time-stamped incumbent objective values when ``return_callback``
            was enabled in :meth:`explain`, otherwise ``None``.
        """
        callback = cast(
            "SolutionCallback | None",
            getattr(self, "callback", None),
        )
        if callback is None:
            return None
        return callback.sollist

    def explain(
        self,
        x: Array1D,
        *,
        y: NonNegativeInt,
        norm: PositiveInt,
        return_callback: bool = False,
        verbose: bool = False,
        max_time: int = 60,
        num_workers: int | None = None,
        random_seed: int = 42,
        clean_up: bool = True,
    ) -> Explanation | None:
        """
        Solve one counterfactual query with the MIP backend.

        Parameters
        ----------
        x
            Query instance in the processed feature space.
        y
            Target class enforced by the counterfactual.
        norm
            Distance norm. The MIP backend supports ``1`` and ``2``.
        return_callback
            Whether to collect incumbent solutions through a Gurobi callback.
        verbose
            Whether to print Gurobi logs.
        max_time
            Time limit in seconds.
        num_workers
            Optional Gurobi thread count.
        random_seed
            Random seed passed to Gurobi.
        clean_up
            Whether to remove query-specific constraints after the solve.
            This is honoured on failed solves as well, so that a failed
            query does not leave its constraints on the model.

        Returns
        -------
        Explanation | None
            The decoded counterfactual, or ``None`` when no feasible
            counterfactual is found within the given limits.

        Raises
        ------
        RuntimeError
            If the solver stops for an unexpected status that is not handled
            by the explainer.
        """
        self.setParam("LogToConsole", int(verbose))
        self.setParam("TimeLimit", max_time)
        self.setParam("Seed", random_seed)
        if num_workers is not None:
            self.setParam("Threads", num_workers)

        self.add_objective(x, norm=norm)
        self.set_majority_class(y=y)

        if return_callback:
            self.callback = SolutionCallback(starttime=time.time())
            self.optimize(self.callback)
        else:
            # Drop any callback kept from an earlier solve so that
            # get_anytime_solutions() does not report stale incumbents.
            self.callback = None
            self.optimize()

        status = self.get_solving_status()
        failure: str | None = None
        if status == "INFEASIBLE":
            failure = (
                "There are no feasible counterfactuals for this query."
                " If there should be one, please check the model "
                "constraints or report this issue to the developers."
            )
        elif status != "OPTIMAL":
            if self.SolCount > 0:
                # A usable incumbent exists: warn, then return it below.
                msg = (
                    "A valid CF was found, but it might be "
                    "suboptimal as the MILP "
                    "solver could not prove optimality within "
                    "the given time frame. \n It can however certify"
                    " that no counterfactual can be closer than"
                    f" {self.ObjBound}."
                )
                warnings.warn(msg, category=UserWarning, stacklevel=2)
            elif status == "TIME_LIMIT":
                failure = (
                    "The MILP solver could not find any"
                    " valid CF within the given time frame."
                    " Try increasing the time limit."
                )
            elif status == "MEM_LIMIT":
                failure = (
                    "The MILP solver could not find any"
                    " valid CF within the given max memory."
                    " Try increasing the memory limit."
                )
            else:
                msg = (
                    "The MILP solver could not find any"
                    " valid CF for an un-handled reason."
                    " Unexpected solver status: " + status
                )
                if clean_up:
                    self.cleanup()
                raise RuntimeError(msg)

        if failure is not None:
            # Warn from this frame so stacklevel=2 points at the caller.
            warnings.warn(failure, category=UserWarning, stacklevel=2)
            if clean_up:
                self.cleanup()
            return None

        # Record what get_distance() needs before the model is cleaned up.
        self.explanation.query = x
        self._distance_norm = norm
        if clean_up:
            self.cleanup()
        return self.explanation

    @staticmethod
    def _get_isolation_params(
        isolation: IsolationForest | None,
    ) -> tuple[NonNegativeInt, NonNegativeInt]:
        """
        Return ``(n_isolators, max_samples)`` for an optional forest.

        Both values are ``0`` when no isolation forest is given; otherwise
        they are ``len(isolation)`` and ``int(isolation.max_samples_)``.
        """
        if isolation is None:
            return 0, 0
        return len(isolation), int(isolation.max_samples_)  # pyright: ignore[reportUnknownArgumentType]
class SolutionCallback:
    """
    Record every incumbent objective value Gurobi reports while solving.

    Each entry of ``sollist`` is a dict with the incumbent's
    ``"objective_value"`` and the ``"time"`` elapsed (in seconds, by
    ``time.time()``) since ``starttime``.
    """

    def __init__(self, starttime: float) -> None:
        self.sollist: list[dict[str, float]] = []
        self.starttime = starttime

    def __call__(self, model: gp.Model, where: int) -> None:
        # Only new-incumbent events are of interest; ignore everything else.
        if where != gp.GRB.Callback.MIPSOL:
            return

        incumbent_objective = model.cbGet(gp.GRB.Callback.MIPSOL_OBJ)
        elapsed = time.time() - self.starttime
        record = {
            "objective_value": incumbent_objective,
            "time": elapsed,
        }
        self.sollist.append(record)