import operator
from collections.abc import Iterable
from functools import partial
from itertools import chain
import xgboost as xgb
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from ..abc import Mapper
from ..feature import Feature
from ..typing import NonNegativeInt, ParsableEnsemble, SKLearnTree
from ._node import Node
from ._parse_xgb import parse_xgb_ensemble
from ._protocol import (
SKLearnTreeProtocol,
TreeProtocol,
)
from ._tree import Tree
# Single fitted scikit-learn tree estimators accepted by `parse_tree` and
# `parse_trees` (their `tree_` attribute is what gets parsed).
type SKLearnDecisionTree = DecisionTreeClassifier | DecisionTreeRegressor
def _build_leaf(tree: TreeProtocol, node_id: NonNegativeInt) -> Node:
    """Create a leaf ``Node`` from row ``node_id`` of the tree arrays.

    The leaf keeps the full ``value`` row for this node and the node's
    sample count, converted to a Python ``int``.
    """
    leaf_value = tree.value[node_id, :]
    return Node(
        node_id,
        n_samples=int(tree.n_samples[node_id]),
        value=leaf_value,
    )
def _build_node(
    tree: TreeProtocol,
    node_id: NonNegativeInt,
    *,
    mapper: Mapper[Feature],
) -> Node:
    """Create an internal (split) ``Node`` and recursively parse its children.

    The split feature is resolved through ``mapper.names``. For a numeric
    feature, the split threshold is stored on the node and also registered
    on the feature via ``add``. For a one-hot-encoded feature, the
    category code from ``mapper.codes`` is stored instead. Any other
    feature kind gets neither a threshold nor a code.

    NOTE: this function and ``_parse_node`` are mutually recursive. Each
    tree level costs two Python frames, so extremely deep trees may hit
    the interpreter's recursion limit.
    """
    feature_index = int(tree.feature[node_id])
    name = mapper.names[feature_index]
    left_id = int(tree.left[node_id])
    right_id = int(tree.right[node_id])
    n_samples = int(tree.n_samples[node_id])

    feature = mapper[name]
    threshold = None
    code = None
    if feature.is_numeric:
        threshold = float(tree.threshold[node_id])
        # Record the split point on the feature itself (side effect on mapper).
        feature.add(threshold)
    elif feature.is_one_hot_encoded:
        code = mapper.codes[feature_index]

    node = Node(
        node_id,
        feature=name,
        threshold=threshold,
        code=code,
        n_samples=n_samples,
    )
    # Children are parsed left first, then right.
    node.left = _parse_node(tree, left_id, mapper=mapper)
    node.right = _parse_node(tree, right_id, mapper=mapper)
    return node
def _parse_node(
    tree: TreeProtocol,
    node_id: NonNegativeInt,
    *,
    mapper: Mapper[Feature],
) -> Node:
    """Dispatch ``node_id`` to the leaf or internal-node builder.

    A node whose left and right child ids are equal is treated as a leaf.
    (In scikit-learn's tree arrays, both children of a leaf are -1.)
    """
    left_child = int(tree.left[node_id])
    right_child = int(tree.right[node_id])
    if left_child != right_child:
        return _build_node(tree, node_id, mapper=mapper)
    return _build_leaf(tree, node_id)
def _parse_tree(
    sklearn_tree: SKLearnTree,
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> Tree:
    """Wrap a raw scikit-learn tree structure and parse it from root node 0.

    ``adaboost`` is only assigned on the resulting ``Tree`` when
    ``is_adaboost`` is true. Otherwise the ``Tree`` keeps whatever value
    its own construction gives it.
    """
    root = _parse_node(SKLearnTreeProtocol(sklearn_tree), 0, mapper=mapper)
    parsed = Tree(root=root)
    if is_adaboost:
        parsed.adaboost = True
    return parsed
[docs]
def parse_tree(
    tree: SKLearnDecisionTree,
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> Tree:
    """
    Convert a fitted scikit-learn tree into an OCEAN tree.

    Parameters
    ----------
    tree : SKLearnDecisionTree
        Fitted ``DecisionTreeClassifier`` or ``DecisionTreeRegressor``.
        Its ``tree_`` attribute is the structure that gets parsed.
    mapper : Mapper[Feature]
        Resolves split feature indices to features. Numeric split
        thresholds encountered while parsing are added to their feature.
    is_adaboost : bool, default False
        When true, the returned tree is flagged with ``adaboost = True``.

    Returns
    -------
    Tree
        Parsed tree structure used by the optimization backends.
    """
    return _parse_tree(tree.tree_, mapper=mapper, is_adaboost=is_adaboost)
[docs]
def parse_trees(
    trees: Iterable[SKLearnDecisionTree],
    *,
    mapper: Mapper[Feature],
    is_adaboost: bool = False,
) -> tuple[Tree, ...]:
    """
    Parse an iterable of fitted scikit-learn trees.

    Parameters
    ----------
    trees : Iterable[SKLearnDecisionTree]
        Fitted trees, consumed once, in iteration order.
    mapper : Mapper[Feature]
        Shared feature mapper passed to every ``parse_tree`` call.
    is_adaboost : bool, default False
        Forwarded unchanged to ``parse_tree`` for each tree.

    Returns
    -------
    tuple[Tree, ...]
        Parsed tree structures in the same order as the input iterable.
    """
    return tuple(
        parse_tree(estimator, mapper=mapper, is_adaboost=is_adaboost)
        for estimator in trees
    )
def parse_ensemble(
    ensemble: ParsableEnsemble,
    *,
    mapper: Mapper[Feature],
) -> tuple[Tree, ...]:
    """
    Parse a supported tree ensemble model into OCEAN trees.

    An ``xgb.Booster`` is parsed directly. An ``xgb.XGBClassifier`` is
    parsed through its underlying booster. Anything else is iterated as
    a collection of scikit-learn trees, and those trees are flagged as
    AdaBoost trees only for an ``AdaBoostClassifier``.

    Returns
    -------
    tuple[Tree, ...]
        Parsed tree structures extracted from the fitted ensemble.
    """
    if isinstance(ensemble, (xgb.Booster, xgb.XGBClassifier)):
        booster = (
            ensemble
            if isinstance(ensemble, xgb.Booster)
            else ensemble.get_booster()
        )
        return parse_xgb_ensemble(booster, mapper=mapper)
    from_adaboost = isinstance(ensemble, AdaBoostClassifier)
    return parse_trees(ensemble, mapper=mapper, is_adaboost=from_adaboost)
[docs]
def parse_ensembles(
    *ensembles: ParsableEnsemble,
    mapper: Mapper[Feature],
) -> tuple[Tree, ...]:
    """
    Flatten and parse one or more supported tree ensembles.

    Ensembles are processed in argument order. Within each ensemble, the
    trees keep the order produced by ``parse_ensemble``.

    Returns
    -------
    tuple[Tree, ...]
        Parsed trees from every provided ensemble, flattened into one tuple.
    """
    return tuple(
        parsed
        for model in ensembles
        for parsed in parse_ensemble(model, mapper=mapper)
    )