Source code for ocean.abc._mapper

from collections.abc import Callable, Iterator, Mapping
from typing import Concatenate, Literal, Protocol, overload

import pandas as pd

from ..typing import Index, Index1L, Key, NonNegativeInt, Number, PositiveInt


class Value(Protocol):
    """
    Structural interface required of every value held by a mapper.

    Only the two read-only attributes below are needed; any object exposing
    them satisfies this protocol for static type checking.
    """

    @property
    def is_one_hot_encoded(self) -> bool:
        """Whether the feature described by this value is one-hot encoded."""

    @property
    def codes(self) -> tuple[Key, ...]:
        """Codes checked against the mapper's code-level column values."""


type Getter[K, V] = Callable[Concatenate[K, ...], V]
type Args[V] = tuple[Mapping[Key, V], Index]


[docs] class Indexer[K, V]: """Memoized dispatcher used to resolve column positions efficiently.""" _getters: tuple[Getter[K, V], ...] _memo: dict[tuple[K, ...], V]
[docs] def __init__(self, *getters: Getter[K, V]) -> None: """Store the lookup functions used for each supported key arity.""" self._getters = getters self._memo = {}
def _get(self, *keys: K) -> V: """ Dispatch to the getter matching the number of provided keys. Returns ------- V Value returned by the matching getter. """ return self._getters[len(keys) - 1](*keys)
[docs] def get(self, *keys: K) -> V: """ Return the cached lookup result for ``keys``. Returns ------- V Memoized lookup result. """ if keys not in self._memo: self._memo[*keys] = self._get(*keys) return self._memo[*keys]
[docs] class Mapper[V: Value](Mapping[Key, V]): r""" Map original feature names to transformed columns. Both the counterfactual point :math:`x` and the query :math:`\hat{x}` are represented in this transformed coordinate system. The mapper is the low-level bridge between original feature semantics and processed columns. """ NAME_LEVEL: Literal[0] = 0 CODE_LEVEL: Literal[1] = 1 _columns: Index _mapping: dict[Key, V] _indexer: Indexer[Key, NonNegativeInt] | None = None _names: tuple[Key, ...] | None = None _codes: tuple[Key, ...] | None = None @overload def __init__(self) -> None: ... @overload def __init__( self, mapping: Mapping[Key, V], *, columns: Index, validate: bool = True, ) -> None: ... @overload def __init__(self, mapping: "Mapper[V]") -> None: ...
[docs] def __init__( self, mapping: "Mapping[Key, V] | None" = None, *, columns: Index | None = None, validate: bool = True, ) -> None: """ Initialize a mapper from feature metadata and transformed columns. Parameters ---------- mapping Mapping from original feature names to metadata objects. columns Processed pandas index describing the transformed coordinates. validate Whether to verify that ``mapping`` and ``columns`` are consistent. """ mapping, columns = self._get_args(mapping, columns) self._validate_args(mapping, columns, validate=validate) self._columns = columns self._mapping = dict(mapping)
@property def n_columns(self) -> NonNegativeInt: return len(self.columns) @property def n_levels(self) -> NonNegativeInt: return self.columns.nlevels @property def is_multi_level(self) -> bool: return self.n_levels > 1 @property def columns(self) -> Index: return self._columns @property def names(self) -> tuple[Key, ...]: if self._names is None: names: Index1L = self.columns.get_level_values(self.NAME_LEVEL) self._names = tuple(names) return self._names @property def codes(self) -> tuple[Key, ...]: if self._codes is None: if not self.is_multi_level: msg = "No one-hot encoded features found" raise ValueError(msg) codes: Index1L = self.columns.get_level_values(self.CODE_LEVEL) self._codes = tuple(codes.map(str)) return self._codes @property def idx(self) -> Indexer[Key, NonNegativeInt]: if self._indexer is None: self._indexer = self._add_indexer() return self._indexer
[docs] def reduce[S](self, reducer: Callable[[V], S]) -> Mapping[Key, S]: """ Apply ``reducer`` to each mapped value and keep the original keys. Returns ------- Mapping[Key, S] Reduced mapping with the same keys. """ return {name: reducer(value) for name, value in self.items()}
[docs] def apply[U: Value](self, func: Callable[[Key, V], U]) -> "Mapper[U]": """ Transform the mapped values while preserving the column structure. Returns ------- Mapper[U] New mapper with transformed values and the same columns. """ mapping = {name: func(name, value) for name, value in self.items()} return Mapper(mapping, columns=self.columns, validate=False)
[docs] def __len__(self) -> NonNegativeInt: """ Return the number of original features tracked by the mapper. Returns ------- NonNegativeInt Number of original features. """ return len(self._mapping)
[docs] def __iter__(self) -> Iterator[Key]: """ Iterate over original feature names. Returns ------- Iterator[Key] Iterator over mapper keys. """ return iter(self._mapping)
[docs] def __getitem__(self, key: Key) -> V: """ Return the mapped value associated with feature name ``key``. Returns ------- V Value associated with ``key``. """ return self._mapping[key]
[docs] @staticmethod def _get_args( mapping: Mapping[Key, V] | None, columns: Index | None, ) -> Args[V]: """ Normalize constructor arguments into a mapping and a column index. Returns ------- Args[V] Pair ``(mapping, columns)`` ready for validation. """ if mapping is None: mapping = {} if isinstance(mapping, Mapper): columns = mapping.columns elif columns is None: columns = pd.Index([], dtype=object) return mapping, columns
[docs] def _validate_args( self, mapping: Mapping[Key, V], columns: Index, *, validate: bool = True, ) -> None: """ Check that the mapping metadata matches the transformed columns. Raises ------ ValueError If feature names or one-hot codes do not match the column index. """ if not validate: return if isinstance(mapping, Mapper): return names: Index1L = columns.get_level_values(self.NAME_LEVEL) if set(mapping.keys()) != set(names): msg = "Mapping keys must match column names" raise ValueError(msg) if columns.nlevels <= 1: return codes: Index1L = columns.get_level_values(self.CODE_LEVEL) for name, value in mapping.items(): if not value.is_one_hot_encoded: continue matched = names == name if set(value.codes) != set(codes[matched].map(str)): msg = "Mapping codes must match column codes" raise ValueError(msg)
[docs] def _add_indexer(self) -> Indexer[Key, NonNegativeInt]: """ Build the memoized index dispatcher used by ``idx``. Returns ------- Indexer[Key, NonNegativeInt] Index dispatcher for transformed columns. """ n = self.columns.nlevels return Indexer(*map(self._add_getter, range(1, n + 1)))
[docs] def _add_getter(self, n: PositiveInt) -> Getter[Key, NonNegativeInt]: """ Return the lookup function matching a given number of keys. Returns ------- Getter[Key, NonNegativeInt] Getter matching the requested key arity. Raises ------ ValueError If the requested key arity is unsupported. """ match n: case 1: return self._get_with_name case 2: return self._get_with_code case _: msg = f"Unsupported number of keys: {n}" raise ValueError(msg)
def __repr__(self) -> str: """ Return a developer-facing representation of the mapper. Returns ------- str String representation of the mapper. """ return f"Mapper({self._mapping!r}, columns={self.columns!r})"
[docs] def _get_with_name(self, name: Key) -> NonNegativeInt: """ Resolve the transformed column index for a non-encoded feature name. Returns ------- NonNegativeInt Column position associated with ``name``. Raises ------ KeyError If ``name`` is not present in the mapper. """ if name not in self.names: msg = f"Name {name} not found" raise KeyError(msg) return self.names.index(name)
[docs] def _get_with_code(self, name: Key, code: Key) -> NonNegativeInt: """ Resolve the transformed column index for a one-hot encoded feature code. Returns ------- NonNegativeInt Column position associated with ``(name, code)``. Raises ------ KeyError If ``name`` or ``code`` is not present in the mapper. """ if name not in self.names: msg = f"Name {name} not found in names" raise KeyError(msg) codes = self[name].codes if code not in codes: msg = f"Code {code} not found in codes associatedwith {name}" raise KeyError(msg) indices = [j for j, n in enumerate(self.names) if n == name] codes = tuple(self.codes[j] for j in indices) i = codes.index(code) return indices[i]
[docs] @staticmethod def _repr(mapping: Mapping[Key, Key | Number]) -> str: """ Return a column-aligned string representation of a small mapping. Returns ------- str Multi-line aligned representation of ``mapping``. """ length = max(len(str(k)) for k in mapping) lines = [ f"{str(k).ljust(length + 1)} : {v}" for k, v in mapping.items() ] return "\n".join(lines)