Source code for ocean.maxsat._explanation

from collections.abc import Mapping

import numpy as np
import pandas as pd

from ..abc import Mapper
from ..typing import Array1D, BaseExplanation, Key, Number
from ._env import ENV
from ._variables import FeatureVar


class Explanation(Mapper[FeatureVar], BaseExplanation):
    """Explanation object handed back by the MaxSAT backend."""

    # Query vector the explanation is compared against; empty by default
    # and replaced through the ``query`` setter.
    _x: Array1D = np.zeros(0, dtype=int)
def vget(self, i: int) -> int:
    """Return the result of the feature's ``xget`` for column ``i``.

    One-hot features are looked up by the column's code, numeric features
    by a ``mu`` index, and every other feature with a bare ``xget()``.

    Raises:
        ValueError: if a threshold-encoded numeric feature has no
            threshold index available.
    """
    name = self.names[i]
    feature = self[name]

    if feature.is_one_hot_encoded:
        return feature.xget(code=self.codes[i])

    if not feature.is_numeric:
        return feature.xget()

    if feature.has_threshold_encoding:
        threshold_idx = self._get_threshold_index(name)
        if threshold_idx is None:
            msg = "No threshold variable is available for this feature."
            raise ValueError(msg)
        return feature.xget(mu=threshold_idx)

    # Position of the query value among the feature's levels.
    position = np.searchsorted(feature.levels, self._x[i], side="left")  # pyright: ignore[reportUnknownArgumentType]
    return feature.xget(mu=int(position))
def _get_active_mu_index(
    self,
    name: Key,
    for_discrete: bool = False,  # noqa: FBT001, FBT002
) -> int:
    """
    Locate the first mu variable of a numeric feature that the solver set.

    With ``for_discrete`` there is one mu per level; otherwise one mu per
    interval between consecutive levels.

    Returns:
        Index of the first mu whose model value is positive, or 0 when
        none is.
    """
    level_count = len(self[name].levels)
    mu_count = level_count if for_discrete else level_count - 1
    return next(
        (
            k
            for k in range(mu_count)
            if ENV.solver.model(self[name].xget(mu=k)) > 0
        ),
        0,  # Fall back to the first index when nothing is active.
    )
def to_series(self) -> "pd.Series[float]":
    """Build a Series holding one solved value per column.

    The Series is indexed by ``self.columns``. One-hot and remaining
    (e.g. binary) columns take the raw ``ENV.solver.model`` value of
    their variable; continuous and discrete columns go through the
    matching ``format_*`` helper.
    """
    row: list[float] = []
    for col in range(self.n_columns):
        name = self.names[col]
        feature = self[name]

        if feature.is_one_hot_encoded:
            entry = ENV.solver.model(feature.xget(code=self.codes[col]))
        elif feature.is_continuous:
            if feature.has_threshold_encoding:
                entry = self.format_threshold_continuous_value(col, name)
            else:
                active = self._get_active_mu_index(name, for_discrete=False)
                entry = self.format_continuous_value(
                    col,
                    active,
                    list(feature.levels),
                )
        elif feature.is_discrete:
            if feature.has_threshold_encoding:
                entry = self.format_threshold_discrete_value(col, name)
            else:
                # For discrete features, mu[i] means value == levels[i].
                active = self._get_active_mu_index(name, for_discrete=True)
                target = int(list(feature.levels)[active])
                entry = self.format_discrete_value(
                    col,
                    target,
                    feature.levels,
                )
        else:
            # Binary features and anything else are read the same way.
            entry = ENV.solver.model(feature.xget())

        row.append(entry)

    return pd.Series(row, index=self.columns)
def to_numpy(self) -> Array1D:
    """Return the values of ``to_series`` as a flat float64 array.

    Entries are ordered by ``self.columns``.
    """
    single_row = self.to_series().to_frame().T
    ordered = single_row[self.columns]
    flat = ordered.to_numpy().flatten()
    return flat.astype(np.float64)
@property
def x(self) -> Array1D:
    """Alias for ``to_numpy()``."""
    return self.to_numpy()


@property
def value(self) -> Mapping[Key, Key | Number]:
    """Per-feature view of the solution, built with ``self.reduce``.

    One-hot features yield the first code whose variable has a positive
    model value, numeric features yield a formatted value, and the rest
    yield the model value of their variable cast to ``int``.
    """

    def get(v: FeatureVar) -> Key | Number:
        if v.is_one_hot_encoded:
            for code in v.codes:
                if ENV.solver.model(v.xget(code)) > 0:
                    return code
            # No active code: fall through to the generic handling below.

        if not v.is_numeric:
            return int(ENV.solver.model(v.xget()))

        # NOTE(review): ``f`` is the feature's position among
        # ``self.values()`` but is then used to index ``self.names`` and
        # the query vector - confirm these agree when one-hot features
        # span several columns.
        f = list(self.values()).index(v)
        name = self.names[f]

        if v.has_threshold_encoding:
            if v.is_discrete:
                return self.format_threshold_discrete_value(f, name)
            return self.format_threshold_continuous_value(f, name)

        if v.is_discrete:
            active = self._get_active_mu_index(name, for_discrete=True)
            return self.format_discrete_value(
                f,
                int(v.levels[active]),
                v.levels,
            )

        active = self._get_active_mu_index(name, for_discrete=False)
        return self.format_continuous_value(f, active, list(v.levels))

    return self.reduce(get)
def format_continuous_value(
    self,
    f: int,
    idx: int,
    levels: list[float],
) -> float:
    """Pick a concrete value inside interval ``idx`` of ``levels``.

    Args:
        f: Position of the feature in the flattened query vector.
        idx: Index of the selected interval ``[levels[idx], levels[idx + 1]]``.
        levels: Interval boundaries of the feature.

    Returns:
        The interval midpoint when no query is set. Otherwise the query
        value itself if it already falls in interval ``idx``, else the
        result of ``_next_float32_up`` / ``_next_float32_down`` at the
        boundary of the interval facing the query.
    """
    if self.query.shape[0] == 0:
        return float(levels[idx] + levels[idx + 1]) / 2

    query_val = np.asarray(self.query, dtype=float).ravel()[f]

    # Find the interval holding the query value. The scan is bounded so a
    # query above the last level stops just past the final interval
    # instead of indexing beyond ``levels``.
    last = len(levels) - 1
    j = 0
    while j < last and query_val > levels[j + 1]:
        j += 1

    if j == idx:
        return float(query_val)
    if j < idx:
        # Query lies below the target interval: enter from its lower edge.
        return self._next_float32_up(levels[idx])
    # Query lies above the target interval: enter from its upper edge.
    return self._next_float32_down(levels[idx + 1])
def format_discrete_value(
    self,
    f: int,
    val: int,
    thresholds: Array1D,
) -> float:
    """Choose between the query value and ``val`` for a discrete feature.

    Without a query, ``val`` is returned unchanged. Otherwise the query
    value at position ``f`` is kept when it has the same left insertion
    point in ``thresholds`` as ``val``; if the two differ, ``val`` is
    returned as a float.
    """
    if self.query.shape[0] == 0:
        return val

    current = np.asarray(self.query, dtype=float).ravel()[f]
    slot_of_query = np.searchsorted(thresholds, current, side="left")
    slot_of_val = np.searchsorted(thresholds, val, side="left")

    if slot_of_query == slot_of_val:
        return float(current)
    return float(val)
def _get_threshold_states(self, name: Key) -> list[bool]:
    """Return, per split threshold of ``name``, whether its mu is set.

    A mu counts as set when ``ENV.solver.model`` gives a positive value
    for it.
    """
    feature = self[name]
    threshold_count = len(feature.split_threshold_values)

    def is_set(k: int) -> bool:
        return bool(ENV.solver.model(feature.xget(mu=k)) > 0)

    return [is_set(k) for k in range(threshold_count)]


def _get_threshold_index(self, name: Key) -> int | None:
    """Return the index of the first set threshold variable of ``name``.

    Falls back to the last index when none is set, and to ``None`` when
    the feature has no thresholds at all.
    """
    states = self._get_threshold_states(name)
    if not states:
        return None
    return next(
        (k for k, active in enumerate(states) if active),
        len(states) - 1,
    )
def format_threshold_continuous_value(self, f: int, name: Key) -> float:
    """Pick a value for a threshold-encoded continuous feature.

    The first set threshold variable selects the range
    ``(thresholds[k - 1], thresholds[k]]``; when none is set the range is
    everything above the last threshold. The query value at position ``f``
    is kept when it lies in that range. Otherwise the value comes from
    ``_next_float32_up`` / ``_next_float32_down`` at the nearest bound.
    With no thresholds, the query value (or the midpoint of the first and
    last level when there is no query) is returned.
    """
    cuts = list(self[name].split_threshold_values)
    has_query = self.query.shape[0] != 0

    if not cuts:
        if has_query:
            return float(np.asarray(self.query, dtype=float).ravel()[f])
        return float(self[name].levels[0] + self[name].levels[-1]) / 2

    flat_query = np.asarray(self.query, dtype=float).ravel()
    query_val = float(flat_query[f]) if has_query else None

    states = self._get_threshold_states(name)
    first_true = next((k for k, active in enumerate(states) if active), None)

    if first_true is None:
        # Nothing set: the value lies above the last threshold.
        lower = cuts[-1]
        if query_val is not None and query_val > lower:
            return query_val
        return self._next_float32_up(lower)

    if first_true == 0:
        # First variable set: the value lies at or below the first threshold.
        upper = cuts[0]
        if query_val is not None and query_val <= upper:
            return query_val
        return self._next_float32_down(upper)

    lower = cuts[first_true - 1]
    upper = cuts[first_true]
    if query_val is not None:
        if lower < query_val <= upper:
            return query_val
        if query_val <= lower:
            return self._next_float32_up(lower)
    return self._next_float32_down(upper)
def format_threshold_discrete_value(self, f: int, name: Key) -> float:
    """Pick a level for a threshold-encoded discrete feature.

    The first set threshold variable selects the range
    ``(thresholds[k - 1], thresholds[k]]`` (unbounded below for ``k == 0``,
    and everything above the last threshold when none is set). Levels in
    that range are the candidates.

    Args:
        f: Position of the feature in the flattened query vector.
        name: Feature key.

    Returns:
        The query value when it is (numerically close to) a candidate,
        otherwise the first candidate. With no thresholds, the query
        value, or the first level when there is no query.
    """
    levels = np.asarray(self[name].levels, dtype=np.float64)
    thresholds = list(self[name].split_threshold_values)
    has_query = self.query.shape[0] != 0

    if not thresholds:
        if has_query:
            return float(np.asarray(self.query, dtype=float).ravel()[f])
        return float(levels[0])

    query_arr = np.asarray(self.query, dtype=float).ravel()
    query_val = float(query_arr[f]) if has_query else None

    first_true = next(
        (
            idx
            for idx, state in enumerate(self._get_threshold_states(name))
            if state
        ),
        None,
    )

    if first_true == 0:
        lower = float("-inf")
    elif first_true is None:
        lower = thresholds[-1]
    else:
        lower = thresholds[first_true - 1]
    upper = float("inf") if first_true is None else thresholds[first_true]

    candidates = levels[(levels > lower) & (levels <= upper)]
    if candidates.size == 0:
        # No level falls in the selected range: fall back to a nearby one.
        if first_true == 0:
            candidates = np.array([levels[0]])
        elif first_true is None:
            candidates = np.array([levels[-1]])
        else:
            # Clamp with the built-in ``min``; ``np.min(a, b)`` would
            # treat ``b`` as an axis rather than a second operand.
            insertion = min(
                int(np.searchsorted(levels, upper, side="left")),
                len(levels) - 1,
            )
            candidates = np.array([levels[insertion]])

    if query_val is not None and np.any(np.isclose(candidates, query_val)):
        return query_val
    return float(candidates[0])
@property
def query(self) -> Array1D:
    """Query vector stored on the explanation (backed by ``_x``)."""
    return self._x


@query.setter
def query(self, value: Array1D) -> None:
    """Replace the stored query vector."""
    self._x = value


def __repr__(self) -> str:
    """Render the class name followed by ``_repr`` of ``self.value``."""
    mapping = self.value
    header = f"{self.__class__.__name__}:\n"
    body = self._repr(mapping)
    return "".join((header, body))
# Names exported by ``from <module> import *``.
__all__ = ["Explanation"]