# Source code for ocean.typing

"""Shared type aliases and protocols used throughout OCEAN."""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Annotated, Protocol

import numpy as np
import pandas as pd
import xgboost as xgb
from pydantic import Field
from sklearn.ensemble import (
    AdaBoostClassifier,
    IsolationForest,
    RandomForestClassifier,
)

# Ensemble aliases:
# - BaseExplainableEnsemble is the union of the three classifier types
#   grouped as "explainable" ensembles.
type BaseExplainableEnsemble = (
    RandomForestClassifier | xgb.XGBClassifier | AdaBoostClassifier
)
# - ParsableEnsemble is a superset of BaseExplainableEnsemble that also
#   admits IsolationForest and raw xgb.Booster objects.
type ParsableEnsemble = BaseExplainableEnsemble | IsolationForest | xgb.Booster

# Scalar aliases:
# - The Annotated variants attach pydantic `Field` bounds as metadata.
#   Plain Python does not enforce them; they only take effect where the
#   annotation is processed by pydantic validation.
type Number = float
type NonNegativeNumber = Annotated[Number, Field(ge=0.0)]  # x >= 0
type PositiveInt = Annotated[int, Field(ge=1)]  # n >= 1
type NonNegativeInt = Annotated[int, Field(ge=0)]  # n >= 0
type NonNegative = Annotated[np.float64, Field(ge=0.0)]  # float64, x >= 0
type Unit = Annotated[float, Field(gt=0.0, lt=1.0)]  # open interval (0, 1)
type UnitO = Annotated[float, Field(ge=0.0, lt=1.0)]  # half-open [0, 1)
type NodeId = Annotated[np.int64, Field(ge=-1)]  # int64, id >= -1

# Key alias:
# - Represents either the name of a feature or the code of a
#   one-hot encoded feature, hence both `int` and `str` are allowed.
type Key = int | str

# Index aliases:
# - Index1L is a single-level index whose labels are `Key` values.
# - Index additionally admits a `pd.MultiIndex`.
# Static type checkers see the parametrized pandas forms, while the runtime
# branch falls back to the bare classes.
# NOTE(review): presumably because `pd.Index` cannot be subscripted at
# runtime -- confirm against the supported pandas versions.
if TYPE_CHECKING:
    type Index1L = pd.Index[Key]
    type Index = pd.Index[int] | pd.Index[str] | pd.MultiIndex
else:
    type Index1L = pd.Index
    type Index = pd.Index | pd.MultiIndex

# Array aliases

# Int arrays:
# 1D, 2D, and nD arrays of signed 64-bit integers.
# The first type argument of np.ndarray is the shape tuple, the second
# is the dtype.
IntDtype = np.dtype[np.int64]
IntArray1D = np.ndarray[tuple[int], IntDtype]
IntArray2D = np.ndarray[tuple[int, int], IntDtype]
IntArray = np.ndarray[tuple[int, ...], IntDtype]

# Non-negative int arrays:
# 1D, 2D, and nD arrays of non-negative integers, expressed through an
# unsigned 32-bit dtype (zero is a valid element).
NonNegativeIntDtype = np.dtype[np.uint32]
NonNegativeIntArray1D = np.ndarray[tuple[int], NonNegativeIntDtype]
NonNegativeIntArray2D = np.ndarray[tuple[int, int], NonNegativeIntDtype]
NonNegativeIntArray = np.ndarray[tuple[int, ...], NonNegativeIntDtype]

# Float arrays:
# 1D, 2D, and nD arrays of floats (64 bits).
# These are the unqualified "default" array aliases (no name prefix).
Dtype = np.dtype[np.float64]
Array1D = np.ndarray[tuple[int], Dtype]
Array2D = np.ndarray[tuple[int, int], Dtype]
Array = np.ndarray[tuple[int, ...], Dtype]

# Non-negative float arrays:
# 1D, 2D, and nD arrays of non-negative floats (64 bits).
# NOTE(review): the dtype is parametrized with the Annotated alias
# `NonNegative` rather than a NumPy scalar type, so the ">= 0" bound is
# descriptive only -- confirm the type checker in use accepts this form.
NonNegativeDtype = np.dtype[NonNegative]
NonNegativeArray1D = np.ndarray[tuple[int], NonNegativeDtype]
NonNegativeArray2D = np.ndarray[tuple[int, int], NonNegativeDtype]
NonNegativeArray = np.ndarray[tuple[int, ...], NonNegativeDtype]

# NodeId arrays:
# 1D arrays whose elements are `NodeId` values (int64, bounded below by -1).
# Only the 1D variant is defined.
NodeIdDtype = np.dtype[NodeId]
NodeIdArray1D = np.ndarray[tuple[int], NodeIdDtype]


# Scikit-learn Tree alias:
# This class is only used for type hinting purposes.
class SKLearnTree(Protocol):
    """Protocol capturing the subset of the sklearn tree API OCEAN uses.

    The protocol is structural: any object exposing the attributes below
    with compatible types satisfies it, no inheritance required.

    NOTE(review): the attribute names presumably mirror the per-node
    arrays of scikit-learn's low-level ``Tree`` object -- confirm against
    the code that consumes this protocol.
    """

    # Number of nodes in the tree (at least 1).
    node_count: PositiveInt
    # Depth of the tree (0 or more).
    max_depth: NonNegativeInt
    # One feature entry per node.
    # NOTE(review): scikit-learn reportedly stores a negative sentinel here
    # for leaf nodes; confirm the unsigned dtype alias is intended.
    feature: NonNegativeIntArray1D
    # One float64 threshold per node.
    threshold: Array1D
    # Left / right child ids per node; NodeId allows -1 as the lowest value.
    children_left: NodeIdArray1D
    children_right: NodeIdArray1D
    # One sample count per node.
    n_node_samples: NonNegativeIntArray1D
    # Node values as an n-dimensional float64 array.
    value: Array
# XGBoost Tree alias:
# A single XGBoost tree is represented as a pandas DataFrame.
# NOTE(review): presumably rows of the table produced by
# `Booster.trees_to_dataframe()` -- confirm against the parser.
type XGBTree = pd.DataFrame
class BaseExplanation(Protocol):
    """Protocol implemented by explanation containers returned by explainers.

    Implementations expose the explanation as a NumPy array or a pandas
    Series and provide three read-only views (`x`, `value`, `query`).
    Two float32 rounding helpers are shared as static methods.
    """

    def to_numpy(self) -> Array1D:
        """Return the explanation as a 1D float64 array."""
        ...

    def to_series(self) -> pd.Series:
        """Return the explanation as a pandas Series."""
        ...

    @property
    def x(self) -> Array1D:
        """1D float64 array view of the explanation."""
        ...

    @property
    def value(self) -> Mapping[Key, Key | Number]:
        """Mapping from a feature key to either a key or a number."""
        ...

    @property
    def query(self) -> Array1D:
        """1D float64 array holding the query."""
        ...

    @staticmethod
    def _next_float32_up(value: float) -> float:
        """Return the next float32 above ``np.float32(value)``.

        The input is first rounded to float32, then stepped one
        representable float32 towards +inf; the result is returned as a
        Python float.
        """
        return float(
            np.nextafter(
                np.float32(value),
                np.float32(np.inf),
                dtype=np.float32,
            )
        )

    @staticmethod
    def _next_float32_down(value: float) -> float:
        """Return the next float32 below ``np.float32(value)``.

        The input is first rounded to float32, then stepped one
        representable float32 towards -inf; the result is returned as a
        Python float.
        """
        return float(
            np.nextafter(
                np.float32(value),
                np.float32(-np.inf),
                dtype=np.float32,
            )
        )
class BaseExplainer(Protocol):
    """Protocol implemented by all public OCEAN explainers.

    The protocol only fixes method names, parameter kinds, defaults and
    return types; the meaning of each value is defined by the concrete
    explainer implementations.
    """

    def get_objective_value(self) -> float:
        """Return the objective value as a float."""
        ...

    def get_distance(self) -> float:
        """Return the distance as a float."""
        ...

    def get_solving_status(self) -> str:
        """Return the solving status as a string."""
        ...

    def get_anytime_solutions(self) -> list[dict[str, float]] | None:
        """Return the anytime solutions, or None when there are none.

        Each solution is a mapping from string keys to floats.
        """
        ...

    def explain(
        self,
        x: Array1D,
        *,
        y: NonNegativeInt,
        norm: NonNegativeInt,
        return_callback: bool = False,
        verbose: bool = False,
        max_time: int = 60,
        num_workers: int | None = None,
        random_seed: int = 42,
        clean_up: bool = True,
    ) -> BaseExplanation | None:
        """Explain the 1D input ``x`` and return an explanation or None.

        Every argument after ``x`` is keyword-only. ``y`` and ``norm`` are
        required non-negative integers; the remaining options default to
        ``return_callback=False``, ``verbose=False``, ``max_time=60``,
        ``num_workers=None``, ``random_seed=42`` and ``clean_up=True``.

        NOTE(review): ``max_time`` is presumably a limit in seconds and
        ``y`` the target class -- confirm against the implementations.
        """
        ...

    def cleanup(self) -> None:
        """Release whatever the explainer holds; returns nothing."""
        ...
# Public API of this module, kept in alphabetical order.
__all__ = [
    "Array",
    "Array1D",
    "Array2D",
    "BaseExplainableEnsemble",
    "BaseExplainer",
    "BaseExplanation",
    "Dtype",
    "Index",
    "Index1L",
    "IntArray",
    "IntArray1D",
    "IntArray2D",
    "IntDtype",
    "Key",
    "NodeId",
    "NodeIdArray1D",
    "NodeIdDtype",
    "NonNegative",
    "NonNegativeArray",
    "NonNegativeArray1D",
    "NonNegativeArray2D",
    "NonNegativeDtype",
    "NonNegativeInt",
    "NonNegativeIntArray",
    "NonNegativeIntArray1D",
    "NonNegativeIntArray2D",
    "NonNegativeIntDtype",
    "NonNegativeNumber",
    "Number",
    "ParsableEnsemble",
    "PositiveInt",
    "SKLearnTree",
    "Unit",
    "UnitO",
    "XGBTree",
]