Source code for ocean.cp._builder.model

from collections.abc import Iterable
from typing import Protocol

import numpy as np
from ortools.sat.python import cp_model as cp

from ...abc import Mapper
from ...tree._node import Node
from .._base import BaseModel
from .._variables import FeatureVar, TreeVar


class ModelBuilder(Protocol):
    """Protocol implemented by objects that encode ensemble constraints."""

    def build(
        self,
        model: BaseModel,
        *,
        trees: Iterable[TreeVar],
        mapper: Mapper[FeatureVar],
    ) -> None:
        """
        Build the model constraints for the given trees and features.

        Parameters
        ----------
        model : BaseModel
            The model to which the constraints will be added.
        trees : tuple[TreeVar, ...]
            The tree variables for which the constraints will be built.
        mapper : Mapper[FeatureVar]
            The feature variables for which the constraints will be built.

        """
        raise NotImplementedError


[docs] class ConstraintProgramBuilder(ModelBuilder): """Build CP constraints linking feature variables to tree paths."""
[docs] def build( self, model: BaseModel, *, trees: Iterable[TreeVar], mapper: Mapper[FeatureVar], ) -> None: for tree in trees: self._build(model, tree=tree, mapper=mapper)
def _build( self, model: BaseModel, *, tree: TreeVar, mapper: Mapper[FeatureVar], ) -> None: for leaf in tree.leaves: self._build_path(model, tree=tree, leaf=leaf, mapper=mapper) def _build_path( self, model: BaseModel, *, tree: TreeVar, leaf: Node, mapper: Mapper[FeatureVar], ) -> None: y = tree[leaf.node_id] self._propagate(model, node=leaf, mapper=mapper, y=y) def _propagate( self, model: BaseModel, *, node: Node, mapper: Mapper[FeatureVar], y: cp.IntVar, ) -> None: parent = node.parent if parent is None: return v = mapper[parent.feature] self._expand(model, node=parent, y=y, v=v, sigma=node.sigma) self._propagate(model, node=parent, mapper=mapper, y=y) def _expand( self, model: BaseModel, *, node: Node, y: cp.IntVar, v: FeatureVar, sigma: bool, ) -> None: if v.is_binary: self._bset(model, y=y, v=v, sigma=sigma) elif v.is_continuous: self._cset(model, node=node, y=y, v=v, sigma=sigma) elif v.is_discrete: self._dset(model, node=node, y=y, v=v, sigma=sigma) elif v.is_one_hot_encoded: self._eset(model, node=node, y=y, v=v, sigma=sigma) @staticmethod def _bset( model: BaseModel, *, y: cp.IntVar, v: FeatureVar, sigma: bool, ) -> None: x = v.xget() if sigma: model.Add(x <= 0).OnlyEnforceIf(y) else: model.Add(x >= 1).OnlyEnforceIf(y) @staticmethod def _cset( model: BaseModel, *, node: Node, y: cp.IntVar, v: FeatureVar, sigma: bool, ) -> None: threshold = node.threshold j = int(np.searchsorted(v.levels, threshold, side="left")) x = v.xget() if sigma: model.Add(x <= j - 1).OnlyEnforceIf(y) else: model.Add(x >= j).OnlyEnforceIf(y) @staticmethod def _dset( model: BaseModel, *, node: Node, y: cp.IntVar, v: FeatureVar, sigma: bool, ) -> None: threshold = node.threshold x = v.xget() if sigma: model.Add(x <= int(threshold)).OnlyEnforceIf(y) else: model.Add(x >= int(threshold) + 1).OnlyEnforceIf(y) @staticmethod def _eset( model: BaseModel, *, node: Node, y: cp.IntVar, v: FeatureVar, sigma: bool, ) -> None: x = v.xget(node.code) if sigma: model.Add(x <= 0).OnlyEnforceIf(y) else: model.Add(x >= 1).OnlyEnforceIf(y)
class ModelBuilderFactory: CP: type[ConstraintProgramBuilder] = ConstraintProgramBuilder