import time
import warnings
from typing import cast
import gurobipy as gp
import numpy as np
from sklearn.ensemble import AdaBoostClassifier, IsolationForest
from ..abc import Mapper
from ..feature import Feature
from ..tree import parse_ensembles
from ..typing import (
Array1D,
BaseExplainableEnsemble,
BaseExplainer,
NonNegativeInt,
)
from ._explanation import Explanation
from ._model import Model
from ._variables import TreeVar
[docs]
class Explainer(Model, BaseExplainer):
    """
    Mixed-integer programming explainer for tree ensemble classifiers.

    The optimization model is built once at construction time. Each call to
    :meth:`explain` adds a query-specific objective and target-class
    constraints, solves the model, and (optionally) cleans those up again.
    """

    # Solution vector cached by the last successful ``explain`` call; ``vget``
    # serves values from it instead of the live solver variables.
    # NOTE(review): presumably needed because solver values are no longer
    # available once ``cleanup`` has modified the model -- confirm.
    _output_values: Array1D | None

    def __init__(
        self,
        ensemble: BaseExplainableEnsemble,
        *,
        mapper: Mapper[Feature],
        weights: Array1D | None = None,
        isolation: IsolationForest | None = None,
        isolation_threshold: float | None = None,
        name: str = "OCEAN",
        env: gp.Env | None = None,
        epsilon: float = Model.DEFAULT_EPSILON,
        num_epsilon: float = Model.DEFAULT_NUM_EPSILON,
        model_type: "Model.Type" = Model.Type.MIP,
        flow_type: "TreeVar.FlowType" = TreeVar.FlowType.CONTINUOUS,
    ) -> None:
        """
        Parse the ensemble(s) and build the optimization model.

        Parameters
        ----------
        ensemble
            Tree ensemble to explain.
        mapper
            Feature mapper used to parse the trees and forwarded to the model.
        weights
            Optional per-tree weights. Ignored for ``AdaBoostClassifier``
            ensembles, whose ``estimator_weights_`` are used instead.
        isolation
            Optional isolation forest; when given, its trees are parsed
            together with those of ``ensemble``.
        isolation_threshold
            Forwarded unchanged to ``Model.__init__``.
        name, env, epsilon, num_epsilon, model_type, flow_type
            Forwarded unchanged to ``Model.__init__``.
        """
        ensembles = (ensemble,) if isolation is None else (ensemble, isolation)
        n_isolators, max_samples = self._get_isolation_params(isolation)
        trees = parse_ensembles(*ensembles, mapper=mapper)
        if isinstance(ensemble, AdaBoostClassifier):
            # AdaBoost carries its own fitted tree weights; they take
            # precedence over any user-supplied ``weights``.
            weights = ensemble.estimator_weights_
        Model.__init__(
            self,
            trees,
            mapper=mapper,
            weights=weights,
            n_isolators=n_isolators,
            max_samples=max_samples,
            isolation_threshold=isolation_threshold,
            name=name,
            env=env,
            epsilon=epsilon,
            num_epsilon=num_epsilon,
            model_type=model_type,
            flow_type=flow_type,
        )
        self._output_values = None
        self.build()

    def vget(self, i: int) -> gp.Var:
        """
        Return variable ``i``, reporting the cached solution value if any.

        Parameters
        ----------
        i
            Index of the variable, as understood by the parent ``vget``.

        Returns
        -------
        gp.Var
            The parent's variable when no solution is cached; otherwise a
            proxy around it whose ``X`` attribute returns the value cached by
            the last successful :meth:`explain` call.
        """
        var = super().vget(i)
        if self._output_values is None:
            return var
        return cast("gp.Var", _ValueProxy(var, float(self._output_values[i])))

    def get_objective_value(self) -> float:
        """
        Return the solver objective value of the last optimization run.

        Returns
        -------
        float
            Objective value reported by Gurobi for the latest solve.
        """
        return self.ObjVal

    def get_distance(self) -> float:
        """
        Return the post-processed distance of the last CF.

        The distance is recomputed from the stored query and counterfactual
        rather than read from the solver. One-hot encoded features contribute
        half the sum of their per-code terms.

        Returns
        -------
        float
            Post-processed :math:`L_p` distance for the last successful solve.

        Raises
        ------
        RuntimeError
            If no explanation has been computed yet.
        """
        query = self.explanation.query
        if query.size == 0:
            msg = "No explanation has been computed yet."
            raise RuntimeError(msg)
        norm = getattr(self, "_distance_norm", None)
        if norm is None:
            msg = "No explanation has been computed yet."
            raise RuntimeError(msg)
        counterfactual = self.explanation.x

        def term(idx: int) -> float:
            # Contribution |delta|^norm of one coordinate. For norm 0 an
            # (almost) unchanged coordinate must count as 0, whereas
            # ``0.0 ** 0`` would evaluate to 1.
            delta = float(counterfactual[idx]) - float(query[idx])
            if norm == 0 and np.isclose(delta, 0.0):
                return 0.0
            return abs(delta) ** norm

        distance = 0.0
        for name, feature in self.mapper.items():
            if feature.is_one_hot_encoded:
                feature_distance = 0.0
                for code in feature.codes:
                    feature_distance += term(self.mapper.idx.get(name, code))
                distance += feature_distance / 2.0
            else:
                distance += term(self.mapper.idx.get(name))
        if norm not in {0, 1}:
            # Take the p-th root for p >= 2; norms 0 and 1 are plain sums.
            distance **= 1.0 / norm
        return float(distance)

    def get_solving_status(self) -> str:
        """
        Return the latest Gurobi solve status as a readable string.

        Returns
        -------
        str
            Current model status such as ``"OPTIMAL"`` or ``"TIME_LIMIT"``.
            A status code without a known name is rendered as
            ``"UNKNOWN(<code>)"`` instead of raising ``KeyError``.
        """
        gurobi_statuses = {
            1: "LOADED",
            2: "OPTIMAL",
            3: "INFEASIBLE",
            4: "INF_OR_UNBD",
            5: "UNBOUNDED",
            6: "CUTOFF",
            7: "ITERATION_LIMIT",
            8: "NODE_LIMIT",
            9: "TIME_LIMIT",
            10: "SOLUTION_LIMIT",
            11: "INTERRUPTED",
            12: "NUMERIC",
            13: "SUBOPTIMAL",
            14: "INPROGRESS",
            15: "USER_OBJ_LIMIT",
            16: "WORK_LIMIT",
            # Required for the "MEM_LIMIT" branch of ``explain`` to be
            # reachable.
            17: "MEM_LIMIT",
        }
        status = self.Status
        return gurobi_statuses.get(status, f"UNKNOWN({status})")

    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
        """
        Return incumbent solutions collected during the last solve.

        Returns
        -------
        list[dict[str, float]] | None
            Time-stamped incumbent objective values when ``return_callback``
            was enabled in :meth:`explain`, otherwise ``None``.
        """
        callback = cast(
            "SolutionCallback | None",
            getattr(self, "callback", None),
        )
        if callback is None:
            return None
        return callback.sollist

    def explain(
        self,
        x: Array1D,
        *,
        y: NonNegativeInt,
        norm: NonNegativeInt,
        return_callback: bool = False,
        verbose: bool = False,
        max_time: int = 60,
        num_workers: int | None = None,
        random_seed: int = 42,
        clean_up: bool = True,
    ) -> Explanation | None:
        """
        Solve one counterfactual query with the MIP backend.

        Parameters
        ----------
        x
            Query instance in the processed feature space.
        y
            Target class enforced by the counterfactual.
        norm
            Distance norm. The MIP backend supports ``0``, ``1``, and ``2``.
        return_callback
            Whether to collect incumbent solutions through a Gurobi callback.
            When disabled, incumbents recorded by a previous call are
            discarded.
        verbose
            Whether to print Gurobi logs.
        max_time
            Time limit in seconds.
        num_workers
            Optional Gurobi thread count.
        random_seed
            Random seed passed to Gurobi.
        clean_up
            Whether to remove query-specific constraints after the solve.
            This applies to every outcome, including the cases where ``None``
            is returned or an exception is raised.

        Returns
        -------
        Explanation | None
            The decoded counterfactual, or ``None`` when no feasible
            counterfactual is found within the given limits.

        Raises
        ------
        RuntimeError
            If the solver stops for an unexpected status that is not handled
            by the explainer.
        """
        # Drop the cached solution so ``vget`` reads the live variables.
        self._output_values = None
        self.setParam("LogToConsole", int(verbose))
        self.setParam("TimeLimit", max_time)
        self.setParam("Seed", random_seed)
        if num_workers is not None:
            self.setParam("Threads", num_workers)
        self.add_objective(x, norm=norm)
        self.set_majority_class(y=y)
        try:
            if return_callback:
                self.callback = SolutionCallback(starttime=time.time())
                self.optimize(self.callback)
            else:
                # Forget incumbents of an earlier run so that
                # ``get_anytime_solutions`` does not report stale data.
                self.callback = None
                self.optimize()
            status = self.get_solving_status()
            if status == "INFEASIBLE":
                msg = "There are no feasible counterfactuals for this query."
                msg += " If there should be one, please check the model "
                msg += "constraints or report this issue to the developers."
                warnings.warn(msg, category=UserWarning, stacklevel=2)
                return None
            if status != "OPTIMAL":
                if self.SolCount > 0:
                    # A feasible incumbent exists: warn and fall through so
                    # that it is returned as a (possibly suboptimal) result.
                    msg = "A valid CF was found, but it might be "
                    msg += "suboptimal as the MILP "
                    msg += "solver could not prove optimality within "
                    msg += "the given time frame. \n It can however certify"
                    msg += " that no counterfactual can be closer than"
                    msg += f" {self.ObjBound}."
                    warnings.warn(msg, category=UserWarning, stacklevel=2)
                elif status == "TIME_LIMIT":
                    msg = "The MILP solver could not find any"
                    msg += " valid CF within the given time frame."
                    msg += " Try increasing the time limit."
                    warnings.warn(msg, category=UserWarning, stacklevel=2)
                    return None
                elif status == "MEM_LIMIT":
                    msg = "The MILP solver could not find any"
                    msg += " valid CF within the given max memory."
                    msg += " Try increasing the memory limit."
                    warnings.warn(msg, category=UserWarning, stacklevel=2)
                    return None
                else:
                    msg = "The MILP solver could not find any"
                    msg += " valid CF for an un-handled reason."
                    msg += "Unexpected solver status: " + status
                    raise RuntimeError(msg)
            self.explanation.query = x
            self._distance_norm = norm
            # Cache the solution before cleanup so ``vget`` can still serve it.
            self._output_values = np.asarray(
                self.explanation.x, dtype=np.float64
            )
        finally:
            # Runs on success, on every ``return None`` above and on errors,
            # so a failed query never leaves its constraints in the model.
            if clean_up:
                self.cleanup()
        return self.explanation

    @staticmethod
    def _get_isolation_params(
        isolation: IsolationForest | None,
    ) -> tuple[NonNegativeInt, NonNegativeInt]:
        """
        Return ``(len(isolation), max_samples_)``, or ``(0, 0)`` without forest.

        Parameters
        ----------
        isolation
            Fitted isolation forest, or ``None``.

        Returns
        -------
        tuple[NonNegativeInt, NonNegativeInt]
            ``len(isolation)`` and ``int(isolation.max_samples_)`` when a
            forest is given, otherwise ``(0, 0)``.
        """
        if isolation is not None:
            return len(isolation), int(isolation.max_samples_)  # pyright: ignore[reportUnknownArgumentType]
        return 0, 0
class SolutionCallback:
    """
    Record the incumbent solutions Gurobi reports while optimizing.

    Instances are callable with the ``(model, where)`` signature and append
    one entry to ``sollist`` for every ``MIPSOL`` callback event.
    """

    def __init__(self, starttime: float) -> None:
        """
        Create an empty recorder.

        Parameters
        ----------
        starttime
            Reference timestamp (as returned by ``time.time()``) from which
            the elapsed time of each incumbent is measured.
        """
        self.sollist: list[dict[str, float]] = []
        self.starttime = starttime

    def __call__(self, model: gp.Model, where: int) -> None:
        """
        Store the objective and elapsed time of a newly found solution.

        Parameters
        ----------
        model
            Model the callback is invoked on.
        where
            Callback location code; every code except ``MIPSOL`` is ignored.
        """
        if where != gp.GRB.Callback.MIPSOL:
            return
        # Objective value of the solution that triggered this event.
        incumbent = model.cbGet(gp.GRB.Callback.MIPSOL_OBJ)
        elapsed = time.time() - self.starttime
        record = {"objective_value": incumbent, "time": elapsed}
        self.sollist.append(record)
class _ValueProxy:
    """
    Wrap a Gurobi variable so that ``X`` reports a fixed, cached value.

    Every attribute other than ``X`` is delegated to the wrapped variable.
    """

    def __init__(self, var: gp.Var, value: float) -> None:
        """
        Parameters
        ----------
        var
            Variable that all attribute lookups other than ``X`` fall back to.
        value
            Value returned by the ``X`` property.
        """
        self._var = var
        self._value = value

    @property
    def X(self) -> float:
        """Return the cached value instead of querying the wrapped variable."""
        return self._value

    def __getattr__(self, name: str) -> object:
        """
        Delegate attributes not found on the proxy to the wrapped variable.

        Raises
        ------
        AttributeError
            If ``_var`` itself is requested but not set yet, or if the wrapped
            variable does not provide ``name``.
        """
        # ``__getattr__`` only runs when normal lookup fails. If ``_var`` is
        # what is missing (e.g. ``copy``/``pickle`` probing an instance whose
        # ``__init__`` has not run), reading ``self._var`` below would
        # re-enter this method forever and end in ``RecursionError``.
        if name == "_var":
            raise AttributeError(name)
        return getattr(self._var, name)