from collections.abc import Mapping
import numpy as np
import pandas as pd
from ..abc import Mapper
from ..typing import Array1D, BaseExplanation, Key, Number
from ._env import ENV
from ._variables import FeatureVar
# [docs]  (Sphinx "viewcode" link residue from the HTML export; not part of the program)
class Explanation(Mapper[FeatureVar], BaseExplanation):
    """Concrete explanation container returned by the MaxSAT backend.

    The object maps feature names to their ``FeatureVar`` and reads the
    assignment of the underlying solver variables through
    ``ENV.solver.model`` to rebuild the explained point, either column by
    column (``to_series`` / ``to_numpy`` / ``x``) or feature by feature
    (``value``).
    """

    # float32 machine epsilon. Not referenced anywhere in this class body;
    # presumably consumed by inherited helpers -- TODO confirm.
    _epsilon: float = float(np.finfo(np.float32).eps)
    # Query point exposed through the ``query`` property; empty until set.
    _x: Array1D = np.zeros((0,), dtype=int)

    def vget(self, i: int) -> int:
        """Return the solver variable that backs column ``i``.

        Args:
            i: Column index into ``self.names`` / ``self.codes``.

        Returns:
            * one-hot feature: the variable of the code stored at column ``i``;
            * numeric feature: the ``mu`` variable whose index is the
              left insertion position of the query value ``self._x[i]``
              within the feature's ``levels``;
            * any other feature: the feature's single variable.

        Note:
            The numeric branch indexes ``self._x``, so the query must have
            been set beforehand (the default query is empty).
        """
        name = self.names[i]
        feature = self[name]
        if feature.is_one_hot_encoded:
            return feature.xget(code=self.codes[i])
        if feature.is_numeric:
            j: int = int(
                np.searchsorted(feature.levels, self._x[i], side="left")  # pyright: ignore[reportUnknownArgumentType]
            )
            return feature.xget(mu=j)
        return feature.xget()

    def _get_active_mu_index(
        self,
        name: Key,
        for_discrete: bool = False,  # noqa: FBT001, FBT002
    ) -> int:
        """Find which mu variable is set to true for a numeric feature.

        Args:
            name: Key of the numeric feature to inspect.
            for_discrete: When true, one mu variable per level is scanned;
                otherwise one per interval (``len(levels) - 1``).

        Returns:
            Index of the first mu variable for which the solver model
            reports a positive value, or 0 if none is found (including
            when there is no mu variable to scan).
        """
        feature = self[name]
        n_levels = len(feature.levels)
        # Discrete: one mu per level.  Continuous: one mu per interval.
        n_vars = n_levels if for_discrete else n_levels - 1
        for mu_idx in range(n_vars):
            if ENV.solver.model(feature.xget(mu=mu_idx)) > 0:
                return mu_idx
        return 0  # Default to first if none found

    def to_series(self) -> "pd.Series[float]":
        """Build one value per column from the current solver model.

        Returns:
            A series indexed by ``self.columns`` where
            * one-hot columns hold the model value of that column's code
              variable;
            * continuous columns hold
              ``format_continuous_value(column, active_mu, levels)``;
            * discrete columns hold
              ``format_discrete_value(column, level_value, levels)`` with
              ``level_value = int(levels[active_mu])``;
            * every other column holds the model value of the feature's
              single variable.
        """
        names = self.names
        values: list[float] = []
        for f in range(self.n_columns):
            name = names[f]
            feature = self[name]
            if feature.is_one_hot_encoded:
                values.append(
                    ENV.solver.model(feature.xget(code=self.codes[f]))
                )
            elif feature.is_continuous:
                mu_idx = self._get_active_mu_index(name, for_discrete=False)
                values.append(
                    self.format_continuous_value(
                        f, mu_idx, list(feature.levels)
                    )
                )
            elif feature.is_discrete:
                # For discrete features, mu[i] means value == levels[i]
                mu_idx = self._get_active_mu_index(name, for_discrete=True)
                levels = feature.levels
                values.append(
                    self.format_discrete_value(
                        f, int(levels[mu_idx]), levels
                    )
                )
            else:
                # Binary and any remaining kind are read the same way:
                # the model value of the feature's single variable.
                values.append(ENV.solver.model(feature.xget()))
        return pd.Series(values, index=self.columns)

    def to_numpy(self) -> Array1D:
        """Return ``to_series()`` as a fresh 1-D ``float64`` array.

        The values are taken positionally, in ``self.columns`` order.
        ``copy=True`` guarantees a new, writable array.
        """
        return self.to_series().to_numpy(dtype=np.float64, copy=True)

    @property
    def x(self) -> Array1D:
        """Alias for ``to_numpy()``."""
        return self.to_numpy()

    @property
    def value(self) -> Mapping[Key, Key | Number]:
        """Per-feature view of the explanation.

        Each feature is passed to ``self.reduce`` and converted as follows:
        * one-hot: the first code whose variable has a positive model
          value;
        * discrete / continuous: the same formatted value ``to_series``
          produces for that feature's column;
        * otherwise: ``int`` of the model value of the single variable.
        """
        # Snapshots taken once instead of being rebuilt for every feature.
        column_names = list(self.names)
        column_features = [self[name] for name in column_names]

        def get(v: FeatureVar) -> Key | Number:
            if v.is_one_hot_encoded:
                for code in v.codes:
                    if ENV.solver.model(v.xget(code)) > 0:
                        return code
                # NOTE(review): with no active code the original control
                # flow falls through to the checks below; kept unchanged.
            if v.is_numeric:
                # Locate the feature by *column*, the same index space
                # ``vget`` and ``to_series`` use for ``names`` and for the
                # first argument of the ``format_*_value`` helpers. The
                # position inside ``self.values()`` only matches it when
                # every earlier feature occupies exactly one column.
                column = column_features.index(v)
                name = column_names[column]
                if v.is_discrete:
                    idx = self._get_active_mu_index(name, for_discrete=True)
                    val = int(v.levels[idx])
                    return self.format_discrete_value(column, val, v.levels)
                idx = self._get_active_mu_index(name, for_discrete=False)
                return self.format_continuous_value(
                    column,
                    idx,
                    list(v.levels),
                )
            return int(ENV.solver.model(v.xget()))

        return self.reduce(get)

    @property
    def query(self) -> Array1D:
        """The query point stored on this explanation."""
        return self._x

    @query.setter
    def query(self, value: Array1D) -> None:
        self._x = value

    def __repr__(self) -> str:
        """Return ``"<ClassName>:\\n"`` followed by ``_repr`` of ``value``."""
        body = self._repr(self.value)
        return f"{self.__class__.__name__}:\n" + body
# Public API of this module: only the Explanation class is exported.
__all__ = ["Explanation"]