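"""Parse fitted tree models into OCEAN trees (module ``ocean.tree._parse``).

Supports scikit-learn decision trees and tree ensembles as well as XGBoost
boosters and classifiers.
"""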

import operator
from collections.abc import Iterable
from functools import partial
from itertools import chain

import xgboost as xgb
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

from ..abc import Mapper
from ..feature import Feature
from ..typing import NonNegativeInt, ParsableEnsemble, SKLearnTree
from ._node import Node
from ._parse_xgb import parse_xgb_ensemble
from ._protocol import (
    SKLearnTreeProtocol,
    TreeProtocol,
)
from ._tree import Tree

# Single fitted scikit-learn tree estimators accepted by ``parse_tree`` and
# ``parse_trees`` (either a classifier or a regressor).
type SKLearnDecisionTree = DecisionTreeClassifier | DecisionTreeRegressor


def _build_leaf(tree: TreeProtocol, node_id: NonNegativeInt) -> Node:
    """
    Create a leaf ``Node`` for ``node_id``.

    The leaf stores the row ``tree.value[node_id, :]`` as its value together
    with the sample count recorded for that node (converted to ``int``).
    """
    leaf_value = tree.value[node_id, :]
    return Node(
        node_id,
        n_samples=int(tree.n_samples[node_id]),
        value=leaf_value,
    )


def _build_node(
    tree: TreeProtocol,
    node_id: NonNegativeInt,
    *,
    mapper: Mapper[Feature],
) -> Node:
    """
    Create an internal (split) node and recursively parse both subtrees.

    The split feature is resolved by looking up the feature index stored at
    ``node_id`` in ``mapper.names``. What is attached to the node depends on
    the kind of that feature:

    - numeric: the split threshold is stored on the node and also passed to
      the feature's ``add`` method;
    - one-hot encoded: the entry of ``mapper.codes`` at the same index is
      stored as the node's ``code``;
    - anything else: neither a threshold nor a code is stored.

    Returns
    -------
    Node
        The split node with ``left`` and ``right`` subtrees populated.
    """
    feature_index = int(tree.feature[node_id])
    feature_name = mapper.names[feature_index]
    left_child = int(tree.left[node_id])
    right_child = int(tree.right[node_id])
    sample_count = int(tree.n_samples[node_id])

    split_threshold = None
    split_code = None
    feature = mapper[feature_name]
    if feature.is_numeric:
        split_threshold = float(tree.threshold[node_id])
        # Register the threshold on the feature; this happens before the
        # children are parsed, so parent thresholds are added first.
        feature.add(split_threshold)
    elif feature.is_one_hot_encoded:
        split_code = mapper.codes[feature_index]

    split_node = Node(
        node_id,
        feature=feature_name,
        threshold=split_threshold,
        code=split_code,
        n_samples=sample_count,
    )
    # Left subtree first, then right (matches the traversal order callers see).
    split_node.left = _parse_node(tree, left_child, mapper=mapper)
    split_node.right = _parse_node(tree, right_child, mapper=mapper)
    return split_node


def _parse_node(
    tree: TreeProtocol,
    node_id: NonNegativeInt,
    *,
    mapper: Mapper[Feature],
) -> Node:
    """
    Parse the subtree rooted at ``node_id``.

    A node whose left and right child ids are equal is treated as a leaf.
    Every other node is parsed as a split node, which recurses into its
    children.
    """
    # For scikit-learn trees, both child ids of a leaf are the same sentinel
    # value (presumably TREE_LEAF == -1; confirm against the tree protocol).
    is_leaf = int(tree.left[node_id]) == int(tree.right[node_id])
    if not is_leaf:
        return _build_node(tree, node_id, mapper=mapper)
    return _build_leaf(tree, node_id)


def _parse_tree(
    sklearn_tree: SKLearnTree,
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> Tree:
    """
    Wrap a low-level scikit-learn tree and parse it into a ``Tree``.

    Parsing starts from node ``0``. When ``is_adaboost`` is true the result
    gets ``adaboost = True``; otherwise that attribute is left untouched.
    """
    protocol = SKLearnTreeProtocol(sklearn_tree)
    parsed = Tree(root=_parse_node(protocol, 0, mapper=mapper))
    if not is_adaboost:
        return parsed
    parsed.adaboost = True
    return parsed


def parse_tree(
    tree: SKLearnDecisionTree,
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> Tree:
    """
    Convert a fitted scikit-learn tree into an OCEAN tree.

    Parameters
    ----------
    tree : SKLearnDecisionTree
        Fitted decision tree estimator; its ``tree_`` attribute is parsed.
    mapper : Mapper[Feature]
        Feature mapper used to resolve split features. Numeric split
        thresholds are passed to the corresponding feature's ``add`` method.
    is_adaboost : bool, default=False
        When true, the parsed tree is flagged with ``adaboost = True``.

    Returns
    -------
    Tree
        Parsed tree structure used by the optimization backends.
    """
    return _parse_tree(tree.tree_, mapper=mapper, is_adaboost=is_adaboost)
def parse_trees(
    trees: Iterable[SKLearnDecisionTree],
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> tuple[Tree, ...]:
    """
    Parse an iterable of fitted scikit-learn trees.

    Each tree is parsed with ``parse_tree`` using the same ``mapper`` and
    ``is_adaboost`` flag.

    Returns
    -------
    tuple[Tree, ...]
        Parsed tree structures in the same order as the input iterable.
    """
    return tuple(
        parse_tree(estimator, mapper=mapper, is_adaboost=is_adaboost)
        for estimator in trees
    )
def parse_ensemble( ensemble: ParsableEnsemble, *, mapper: Mapper[Feature], ) -> tuple[Tree, ...]: """ Parse a supported tree ensemble model into OCEAN trees. Returns ------- tuple[Tree, ...] Parsed tree structures extracted from the fitted ensemble. """ if isinstance(ensemble, xgb.Booster): return parse_xgb_ensemble(ensemble, mapper=mapper) if isinstance(ensemble, xgb.XGBClassifier): return parse_xgb_ensemble(ensemble.get_booster(), mapper=mapper) if isinstance(ensemble, AdaBoostClassifier): return parse_trees(ensemble, mapper=mapper, is_adaboost=True) return parse_trees(ensemble, mapper=mapper)
def parse_ensembles(
    *ensembles: ParsableEnsemble,
    mapper: Mapper[Feature],
) -> tuple[Tree, ...]:
    """
    Flatten and parse one or more supported tree ensembles.

    Each ensemble is parsed with ``parse_ensemble`` and the resulting trees
    are concatenated in argument order.

    Returns
    -------
    tuple[Tree, ...]
        Parsed trees from every provided ensemble, flattened into one tuple.
    """
    return tuple(
        parsed
        for ensemble in ensembles
        for parsed in parse_ensemble(ensemble, mapper=mapper)
    )