Source code for ocean.tree._tree

from collections.abc import Iterator

from pydantic import validate_call

from ..typing import Array1D, NonNegativeInt, PositiveInt
from ._node import Node


[docs] class Tree: """In-memory representation of a parsed tree ensemble member.""" root: Node _shape: tuple[NonNegativeInt, ...] _xgboost: bool = False _adaboost: bool = False def __init__(self, root: Node) -> None: self.root = root self._shape = root.leaves[0].value.shape @property def n_nodes(self) -> PositiveInt: return self.root.size @property def max_depth(self) -> NonNegativeInt: return self.root.height @property def leaves(self) -> tuple[Node, *tuple[Node, ...]]: return self.root.leaves @property def shape(self) -> tuple[NonNegativeInt, ...]: return self._shape @property def logit(self) -> Array1D: return self._base_score_prob @logit.setter def logit(self, value: Array1D) -> None: self._base_score_prob = value @property def xgboost(self) -> bool: return self._xgboost @xgboost.setter def xgboost(self, value: bool) -> None: self._xgboost = value @property def adaboost(self) -> bool: return self._adaboost @adaboost.setter def adaboost(self, value: bool) -> None: self._adaboost = value
[docs] @validate_call def nodes_at(self, depth: NonNegativeInt) -> Iterator[Node]: return self._nodes_at(self.root, depth=depth)
def _nodes_at(self, node: Node, *, depth: NonNegativeInt) -> Iterator[Node]: if depth == 0: yield node for child in node.children: yield from self._nodes_at(child, depth=depth - 1)