Source code for ocean.mip._model

from collections.abc import Iterable
from enum import Enum

import gurobipy as gp
import numpy as np
from pydantic import validate_call

from ..abc import Mapper
from ..feature import Feature
from ..tree import Tree
from ..typing import (
    Array1D,
    Key,
    NonNegativeArray1D,
    NonNegativeInt,
    Unit,
)
from ._base import BaseModel
from ._builders.model import ModelBuilder, ModelBuilderFactory
from ._managers import FeatureManager, GarbageManager, TreeManager
from ._typing import Objective
from ._variables import FeatureVar, TreeVar


[docs] class Model(BaseModel, FeatureManager, TreeManager, GarbageManager): """ Mixed-integer programming formulation for tree ensemble explanations. The feature variables encode the counterfactual point ``x`` and the tree variables encode the active leaf or path decisions ``p_(t,l)``. """ DEFAULT_EPSILON: Unit = 1.0 / (2.0**16) DEFAULT_NUM_EPSILON: Unit = 1e-6 # 1.0 / (2.0**20) MIN_NUMERIC_TOL: float = 1e-9
[docs] class Type(Enum): MIP = "MIP"
# Constraints for the majority class. _scores: gp.tupledict[tuple[NonNegativeInt, NonNegativeInt], gp.Constr] # Model builder for the ensemble. _builder: ModelBuilder # Numerical parameters for the model. # - epsilon: the minimum difference between two scores. # - num_epsilon: the minimum difference between two numerical values. _epsilon: Unit _num_epsilon: Unit def __init__( self, trees: Iterable[Tree], mapper: Mapper[Feature], *, weights: NonNegativeArray1D | None = None, n_isolators: NonNegativeInt = 0, max_samples: NonNegativeInt = 0, isolation_threshold: float | None = None, name: str = "OCEAN", env: gp.Env | None = None, epsilon: Unit = DEFAULT_EPSILON, num_epsilon: Unit = DEFAULT_NUM_EPSILON, model_type: "Model.Type" = Type.MIP, flow_type: "TreeVar.FlowType" = TreeVar.FlowType.CONTINUOUS, ) -> None: """ Initialize an empty MIP model for a parsed ensemble. Parameters ---------- trees Parsed trees whose leaves define the ensemble class scores. mapper Feature mapper describing how processed query columns map to decision variables. weights Optional non-negative tree weights. n_isolators Number of isolation trees contributing to the auxiliary isolation constraint on the path length. max_samples Reference sample count used by the isolation-forest extension. isolation_threshold Optional isolation-score cutoff in ``(0, 1]``. When omitted, the classic average-path-length threshold is used. name Name of the underlying Gurobi model. env Optional Gurobi environment. epsilon Classification margin used in the pairwise score constraints. num_epsilon Strictness constant used in numerical split implications. model_type Backend builder variant. flow_type Encoding used for the tree flow variables. """ # Initialize the super models. BaseModel.__init__(self, name=name, env=env) TreeManager.__init__( self, trees=trees, weights=weights, n_isolators=n_isolators, max_samples=max_samples, isolation_threshold=isolation_threshold, flow_type=flow_type, ) FeatureManager.__init__(self, mapper=mapper) GarbageManager.__init__(self) self._set_weights(weights=weights) self._max_samples = max_samples self._epsilon = epsilon self._num_epsilon = num_epsilon self._scores = gp.tupledict() self._set_builder(model_type=model_type)
[docs] def build(self) -> None: r""" Create the decision variables and structural constraints of the model. This step introduces the feature variables encoding :math:`x`, the tree-path variables :math:`p_{t,\ell}`, and the constraints linking both so that exactly one leaf is active in each tree and the selected leaves are consistent with the feature values. """ self.build_features(self) self.build_trees(self) self._builder.build(self, trees=self.trees, mapper=self.mapper) self._set_isolation() self._stabilize_tolerances()
@property def epsilon(self) -> Unit: return self._epsilon @property def num_epsilon(self) -> Unit: return self._num_epsilon
[docs] def add_objective( self, x: Array1D, *, norm: int = 1, sense: int = gp.GRB.MINIMIZE, ) -> None: r""" Attach the distance objective :math:`d(x, \hat{x})` to the MIP model. Parameters ---------- x Query point :math:`\hat{x}` in the processed feature space. The code parameter is named ``x``, but mathematically it represents the query. norm Distance norm used for :math:`d(x, \hat{x})`. The MIP backend supports :math:`L_0`, :math:`L_1`, and :math:`L_2`. sense Optimization sense passed to Gurobi. Counterfactual search uses minimization. """ objective = self._add_objective(x=x, norm=norm) self.explanation.query = np.asarray(x, dtype=np.float64).ravel().copy() self.setObjective(objective, sense=sense)
[docs] @validate_call def set_majority_class( self, y: NonNegativeInt, *, op: NonNegativeInt = 0, ) -> None: r""" Enforce the target class through pairwise score constraints. For every competing class, this adds .. math:: f_y(x) \ge f_c(x) + \varepsilon_c Raises ------ ValueError If ``y`` is not a valid class index. """ if y >= self.n_classes: msg = f"Expected class < {self.n_classes}, got {y}" raise ValueError(msg) self._set_majority_class(y, op=op)
[docs] def clear_majority_class(self) -> None: r""" Remove the current target-class constraints. This deletes stored inequalities of the form .. math:: f_y(x) \ge f_c(x) + \varepsilon_c """ self.remove(self._scores) self._scores.clear()
[docs] def cleanup(self) -> None: r""" Remove query-specific objective auxiliaries and class constraints. After ``cleanup()``, the structural encoding of :math:`x` and :math:`p_{t,\ell}` remains, but temporary constraints created for a specific query :math:`\hat{x}` are removed. """ self.clear_majority_class() self.remove_garbage(self)
def _set_builder(self, model_type: Type) -> None: match model_type: case Model.Type.MIP: epsilon = self._num_epsilon self._builder = ModelBuilderFactory.MIP(epsilon=epsilon)
[docs] def _set_majority_class( self, y: NonNegativeInt, *, op: NonNegativeInt, ) -> None: r""" Encode the pairwise score dominance constraints. The stored constraints are .. math:: f_y(x) - f_c(x) \ge \varepsilon_c, \qquad \forall c \neq y. """ function = self.function for class_ in range(self.n_classes): if class_ == y: continue rhs = self._epsilon if class_ < y else 0.0 lhs = (function[op, y] - function[op, class_]).item() self._scores[op, class_] = self.addConstr(lhs >= rhs)
[docs] def _set_isolation(self) -> None: """ Add the optional isolation-forest length constraint. When isolation trees are present, this constrains the aggregate path length variable to remain above the minimum admissible value. """ if self.n_isolators == 0: return self.addConstr(self.length >= self.min_length)
def _stabilize_tolerances(self) -> None: """ Tighten solver tolerances to preserve exact split semantics. The formulation relies on strict score margins and split implications. With Gurobi's default feasibility tolerances, a continuous feature can end up slightly on the wrong side of a split threshold while the path variables still select the opposite branch, which then disagrees with exact sklearn tree traversal. Using the minimum supported tolerance and asking the MIP solver to prioritize integral solutions reduces those numerically loose incumbents without an explicit re-optimization pass. """ safe_tol = self.MIN_NUMERIC_TOL feasibility_tol = float(self.getParamInfo("FeasibilityTol")[2]) if safe_tol < feasibility_tol: self.setParam("FeasibilityTol", safe_tol) int_feas_tol = float(self.getParamInfo("IntFeasTol")[2]) if safe_tol < int_feas_tol: self.setParam("IntFeasTol", safe_tol) integrality_focus = int(self.getParamInfo("IntegralityFocus")[2]) if integrality_focus < 1: self.setParam("IntegralityFocus", 1)
[docs] def _add_objective(self, x: Array1D, norm: int) -> Objective: r""" Build the symbolic objective expression for :math:`d(x, \hat{x})`. Each feature contributes an :math:`L_0`, :math:`L_1`, or :math:`L_2` term. One-hot encoded coordinates use a factor :math:`1/2` for :math:`L_1` and :math:`L_2` so that switching a category is not double counted. Returns ------- Objective Linear or quadratic expression representing :math:`d(x, \hat{x})`. Raises ------ ValueError If ``x`` does not have the expected size or ``norm`` is unsupported. """ if x.size != self.mapper.n_columns: msg = f"Expected {self.mapper.n_columns} values, got {x.size}" raise ValueError(msg) if norm not in {0, 1, 2}: msg = f"Unsupported norm: {norm}" raise ValueError(msg) if norm == 0: x_arr = np.asarray(x, dtype=np.float64).ravel() return self._add_l0_objective(x_arr) names = self.mapper.names is_ohe = [self.mapper[name].is_one_hot_encoded for name in names] variables = map(self.vget, range(self.n_columns)) if norm == 1: return sum( ( self.L1(x, v, is_ohe=ohe) for x, v, ohe in zip(x, variables, is_ohe, strict=True) ), start=gp.LinExpr(), ) return sum( ( self.L2(x, v, is_ohe=ohe) for x, v, ohe in zip(x, variables, is_ohe, strict=True) ), start=gp.QuadExpr(), )
def _add_l0_objective(self, x: Array1D) -> gp.LinExpr: objective = gp.LinExpr() indexer = self.mapper.idx for name, feature in self.mapper.items(): if feature.is_one_hot_encoded: objective += self._l0_one_hot(x, name=name, feature=feature) continue idx = indexer.get(name) x_j = float(x[idx]) if feature.is_binary: objective += self._l0_binary(x_j, feature=feature) elif feature.is_discrete: objective += self._l0_discrete(x_j, feature=feature) else: objective += self._l0_continuous(x_j, feature=feature) return objective def _l0_one_hot( self, x: Array1D, *, name: Key, feature: FeatureVar, ) -> gp.LinExpr: for code in feature.codes: if np.isclose(x[self.mapper.idx.get(name, code)], 1.0): return 1.0 - feature.xget(code) msg = f"Could not determine the active query code for feature {name!r}." raise ValueError(msg) @staticmethod def _l0_binary( x: float, *, feature: FeatureVar, ) -> gp.LinExpr: return gp.LinExpr(feature.xget()) if np.isclose( x, 0.0 ) else 1.0 - feature.xget() @staticmethod def _continuous_interval_index(levels: Array1D, x: float) -> int: upper = np.asarray(levels, dtype=float)[1:] idx = int(np.searchsorted(upper, x, side="left")) return max(0, min(idx, len(upper) - 1)) def _l0_continuous( self, x: float, *, feature: FeatureVar, ) -> gp.LinExpr: n_intervals = len(feature.levels) - 1 if n_intervals <= 1: return gp.LinExpr() query_interval = self._continuous_interval_index(feature.levels, x) changed = self.addVar(vtype=gp.GRB.BINARY) garbage: list[gp.Var | gp.Constr] = [changed] if query_interval > 0: garbage.append( self.addConstr(feature.mget(query_interval - 1) >= 1 - changed) ) if query_interval < n_intervals - 1: garbage.append( self.addConstr(feature.mget(query_interval + 1) <= changed) ) self.add_garbage(*garbage) return gp.LinExpr(changed) @staticmethod def _l0_discrete( x: float, *, feature: FeatureVar, ) -> gp.LinExpr: levels = np.asarray(feature.levels, dtype=float) if len(levels) <= 1: return gp.LinExpr() thresholds = np.asarray(feature.thresholds, dtype=float) buckets = np.searchsorted(thresholds, levels, side="left") query_bucket = int(np.searchsorted(thresholds, x, side="left")) start = int(np.searchsorted(buckets, query_bucket, side="left")) end = int(np.searchsorted(buckets, query_bucket, side="right")) - 1 last = len(levels) - 1 if start == 0: if end == last: return gp.LinExpr() return gp.LinExpr(feature.mget(end)) if end == last: return 1.0 - feature.mget(start - 1) return 1.0 - feature.mget(start - 1) + feature.mget(end)
[docs] def L1(self, x: np.float64, v: gp.Var, *, is_ohe: bool) -> gp.LinExpr: r""" Return the MIP :math:`L_1` contribution of one coordinate of :math:`x`. This creates an auxiliary variable :math:`u` such that :math:`u \ge |x_j - \hat{x}_j|`, where ``v`` encodes the counterfactual coordinate :math:`x_j` and the code parameter ``x`` stores the query coordinate :math:`\hat{x}_j`. Parameters ---------- x Query coordinate :math:`\hat{x}_j`. v Model variable encoding the corresponding coordinate of :math:`x`. is_ohe Whether this coordinate belongs to a one-hot block :math:`u_{j,k}`. If so, the returned term is halved. Returns ------- gp.LinExpr Linear expression equal to the contribution of that coordinate to :math:`d_1(x, \hat{x})`. """ u = self.addVar() neg = self.addConstr(u >= v - x) pos = self.addConstr(u >= x - v) self.add_garbage(u, pos, neg) return gp.LinExpr(u) / 2.0 if is_ohe else gp.LinExpr(u)
[docs] @staticmethod def L2(x: np.float64, v: gp.Var, *, is_ohe: bool) -> gp.QuadExpr: r""" Return the MIP :math:`L_2` contribution of one coordinate of :math:`x`. The returned expression is :math:`(x_j - \hat{x}_j)^2`, halved for one-hot encoded coordinates to preserve the same category-switch semantics as the :math:`L_1` objective. Returns ------- gp.QuadExpr Quadratic expression equal to the contribution of that coordinate to :math:`d_2(x, \hat{x})`. """ return ((v - x) ** 2) / 2.0 if is_ohe else (v - x) ** 2