"""Shared type aliases and protocols used throughout OCEAN."""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Annotated, Protocol
import numpy as np
import pandas as pd
import xgboost as xgb
from pydantic import Field
from sklearn.ensemble import (
AdaBoostClassifier,
IsolationForest,
RandomForestClassifier,
)
type BaseExplainableEnsemble = (
RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier
)
type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster
type Number = float
type NonNegativeNumber = Annotated[Number, Field(ge=0.0)]
type PositiveInt = Annotated[int, Field(ge=1)]
type NonNegativeInt = Annotated[int, Field(ge=0)]
type NonNegative = Annotated[np.float64, Field(ge=0.0)]
type Unit = Annotated[float, Field(gt=0.0, lt=1.0)]
type UnitO = Annotated[float, Field(ge=0.0, lt=1.0)]
type NodeId = Annotated[np.int64, Field(ge=-1)]
# Key alias:
# - Represents either the name of a feature
#   or the code of a one-hot encoded feature.
# - A key is therefore an int or a str.
type Key = int | str
# Index alias:
if TYPE_CHECKING:
type Index1L = pd.Index[Key]
type Index = pd.Index[int] | pd.Index[str] | pd.MultiIndex
else:
type Index1L = pd.Index
type Index = pd.Index | pd.MultiIndex
# Array aliases
# Each family defines a dtype alias followed by shape-typed ndarray aliases:
# ``tuple[int]`` for 1D, ``tuple[int, int]`` for 2D, ``tuple[int, ...]`` for nD.
# Int arrays:
# 1D, 2D, and nD arrays of 64-bit signed integers.
IntDtype = np.dtype[np.int64]
IntArray1D = np.ndarray[tuple[int], IntDtype]
IntArray2D = np.ndarray[tuple[int, int], IntDtype]
IntArray = np.ndarray[tuple[int, ...], IntDtype]
# Non-negative Int arrays:
# 1D, 2D, and nD arrays of non-negative integers
# (stored as unsigned 32-bit integers, so zero is included).
NonNegativeIntDtype = np.dtype[np.uint32]
NonNegativeIntArray1D = np.ndarray[tuple[int], NonNegativeIntDtype]
NonNegativeIntArray2D = np.ndarray[tuple[int, int], NonNegativeIntDtype]
NonNegativeIntArray = np.ndarray[tuple[int, ...], NonNegativeIntDtype]
# Float arrays:
# 1D, 2D, and nD arrays of floats (64 bits).
# These are the default array aliases, hence the unprefixed names.
Dtype = np.dtype[np.float64]
Array1D = np.ndarray[tuple[int], Dtype]
Array2D = np.ndarray[tuple[int, int], Dtype]
Array = np.ndarray[tuple[int, ...], Dtype]
# Non-negative float arrays:
# 1D, 2D, and nD arrays of non-negative floats (64 bits).
# The dtype is parametrised with the annotated `NonNegative` scalar alias,
# so the lower bound is carried as annotation metadata.
NonNegativeDtype = np.dtype[NonNegative]
NonNegativeArray1D = np.ndarray[tuple[int], NonNegativeDtype]
NonNegativeArray2D = np.ndarray[tuple[int, int], NonNegativeDtype]
NonNegativeArray = np.ndarray[tuple[int, ...], NonNegativeDtype]
# NodeId arrays:
# 1D arrays whose elements are typed with the annotated `NodeId` scalar alias.
# Only the 1D form is defined.
NodeIdDtype = np.dtype[NodeId]
NodeIdArray1D = np.ndarray[tuple[int], NodeIdDtype]
# Scikit-learn Tree alias:
# This class is only used for type hinting purposes.
class SKLearnTree(Protocol):
    """Protocol capturing the subset of the sklearn tree API OCEAN uses.

    The protocol declares attributes only, so any object exposing them with
    compatible types satisfies it structurally. The per-attribute notes
    below follow scikit-learn's description of its low-level tree structure.
    """

    # Total number of nodes in the tree (at least one).
    node_count: PositiveInt
    # Maximal depth of the tree.
    max_depth: NonNegativeInt
    # Feature used for the split at each node.
    # NOTE(review): scikit-learn stores a negative sentinel for leaves in
    # this array; the unsigned annotation is kept as originally declared.
    feature: NonNegativeIntArray1D
    # Split threshold at each node.
    threshold: Array1D
    # Id of the left / right child of each node; NodeId allows -1, which
    # scikit-learn uses to mark "no child".
    children_left: NodeIdArray1D
    children_right: NodeIdArray1D
    # Number of training samples reaching each node.
    n_node_samples: NonNegativeIntArray1D
    # Per-node values as an n-dimensional float array.
    value: Array
# XGBoost tree alias:
# An XGBoost tree is represented as a plain pandas DataFrame.
# NOTE(review): presumably the tabular dump produced by
# `Booster.trees_to_dataframe()` -- confirm against the parser.
type XGBTree = pd.DataFrame
class BaseExplanation(Protocol):
    """Protocol implemented by explanation containers returned by explainers.

    An explanation can be exported as a NumPy vector or a pandas Series and
    exposes three read-only properties (``x``, ``value`` and ``query``).
    Two private static helpers step a value to the neighbouring ``float32``.
    """

    def to_numpy(self) -> Array1D:
        """Return the explanation as a 1-D ``float64`` array."""
        ...

    def to_series(self) -> pd.Series:
        """Return the explanation as a pandas Series."""
        ...

    @property
    def x(self) -> Array1D:
        """1-D ``float64`` vector held by the explanation.

        NOTE(review): presumably the explained (counterfactual) point in
        encoded feature space -- confirm with the implementations.
        """
        ...

    @property
    def value(self) -> Mapping[Key, Key | Number]:
        """Mapping from a key to either a key or a number.

        NOTE(review): presumably feature name -> category code or numeric
        value -- confirm with the implementations.
        """
        ...

    @property
    def query(self) -> Array1D:
        """1-D ``float64`` vector of the query.

        NOTE(review): presumably the original instance being explained --
        confirm with the implementations.
        """
        ...

    @staticmethod
    def _next_float32_up(value: float) -> float:
        """Step *value* one ``float32`` ULP towards +infinity.

        *value* is first rounded to ``float32``; the next representable
        ``float32`` above that is returned as a Python float.
        """
        as_f32 = np.float32(value)
        stepped = np.nextafter(as_f32, np.float32(np.inf), dtype=np.float32)
        return float(stepped)

    @staticmethod
    def _next_float32_down(value: float) -> float:
        """Step *value* one ``float32`` ULP towards -infinity.

        *value* is first rounded to ``float32``; the next representable
        ``float32`` below that is returned as a Python float.
        """
        as_f32 = np.float32(value)
        stepped = np.nextafter(as_f32, np.float32(-np.inf), dtype=np.float32)
        return float(stepped)
class BaseExplainer(Protocol):
    """Protocol implemented by all public OCEAN explainers.

    Besides the main ``explain`` entry point, an explainer exposes a few
    read-only accessors about the last solve and a ``cleanup`` hook.
    """

    def get_objective_value(self) -> float:
        """Return the objective value as a float."""
        ...

    def get_distance(self) -> float:
        """Return the distance as a float.

        NOTE(review): presumably the distance between the query and the
        explanation found -- confirm with the implementations.
        """
        ...

    def get_solving_status(self) -> str:
        """Return the solving status as a string."""
        ...

    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
        """Return the recorded anytime solutions, or ``None``.

        Each solution is a ``dict`` mapping string keys to floats.
        """
        ...

    # NOTE(review): the method name ``explain`` matches the public OCEAN
    # explainer API -- confirm against upstream.
    def explain(
        self,
        x: Array1D,
        *,
        y: NonNegativeInt,
        norm: NonNegativeInt,
        return_callback: bool = False,
        verbose: bool = False,
        max_time: int = 60,
        num_workers: int | None = None,
        random_seed: int = 42,
        clean_up: bool = True,
    ) -> BaseExplanation | None:
        """Compute an explanation for the query ``x``.

        Every argument after ``x`` is keyword-only.

        Args:
            x: Query as a 1-D ``float64`` array.
            y: Non-negative integer (presumably the target class -- confirm).
            norm: Non-negative integer (presumably the order of the
                distance norm -- confirm).
            return_callback: Defaults to ``False`` (presumably toggles
                recording of intermediate solutions -- confirm).
            verbose: Defaults to ``False``.
            max_time: Defaults to ``60`` (presumably a time limit in
                seconds -- confirm).
            num_workers: Defaults to ``None``.
            random_seed: Defaults to ``42``.
            clean_up: Defaults to ``True`` (presumably triggers ``cleanup``
                after solving -- confirm).

        Returns:
            A ``BaseExplanation``, or ``None`` when no explanation is
            returned.
        """
        ...

    def cleanup(self) -> None:
        """Release whatever the explainer holds; returns nothing."""
        ...
# Public API of this module: every alias and protocol defined above.
# The entries are kept in sorted order.
__all__ = [
"Array",
"Array1D",
"Array2D",
"BaseExplainableEnsemble",
"BaseExplainer",
"BaseExplanation",
"Dtype",
"Index",
"Index1L",
"IntArray",
"IntArray1D",
"IntArray2D",
"IntDtype",
"Key",
"NodeId",
"NodeIdArray1D",
"NodeIdDtype",
"NonNegative",
"NonNegativeArray",
"NonNegativeArray1D",
"NonNegativeArray2D",
"NonNegativeDtype",
"NonNegativeInt",
"NonNegativeIntArray",
"NonNegativeIntArray1D",
"NonNegativeIntArray2D",
"NonNegativeIntDtype",
"NonNegativeNumber",
"Number",
"ParsableEnsemble",
"PositiveInt",
"SKLearnTree",
"Unit",
"UnitO",
"XGBTree",
]