Source code for ocean.mip._variables._tree

from collections.abc import Iterator, Mapping
from enum import Enum

import gurobipy as gp
import numpy as np
from pydantic import validate_call

from ...tree._keeper import TreeKeeper, TreeLike
from ...tree._node import Node
from ...typing import NonNegativeInt
from .._base import BaseModel, Var
from .._builders.flow import FlowBuilder, FlowBuilderFactory


[docs] class TreeVar(Var, TreeKeeper, Mapping[NonNegativeInt, gp.Var]): """MIP variables encoding the flow through one parsed tree.""" FLOW_VAR_NAME_FMT: str = "{name}_flow" _flow: gp.MVar _value: gp.MLinExpr _length: gp.LinExpr _builder: FlowBuilder
[docs] class FlowType(Enum): """Available formulations for the tree flow variables.""" CONTINUOUS = "CONTINUOUS" BINARY = "BINARY"
def __init__( self, tree: TreeLike, name: str, *, flow_type: FlowType = FlowType.CONTINUOUS, _adaboost: bool = False, ) -> None: Var.__init__(self, name=name) TreeKeeper.__init__(self, tree=tree) self._set_builder(flow_type=flow_type) self._adaboost = _adaboost @property def value(self) -> gp.MLinExpr: return self._value @property def length(self) -> gp.LinExpr: return self._length
[docs] def build(self, model: BaseModel) -> None: name = self.FLOW_VAR_NAME_FMT.format(name=self._name) self._flow = self._builder.get(model=model, tree=self, name=name) # Propagate Flow model.addConstr(self[self.root.node_id] == 1) self._propagate(model, node=self.root) # Set Value self._value = self._get_value() # Set Average Path Length self._length = self._get_length()
def __len__(self) -> int: return self.n_nodes def __iter__(self) -> Iterator[NonNegativeInt]: return iter(range(self.n_nodes)) @validate_call def __getitem__(self, node_id: NonNegativeInt) -> gp.Var: return self._flow[node_id].item() def _set_builder(self, *, flow_type: FlowType) -> None: match flow_type: case self.FlowType.BINARY: self._builder = FlowBuilderFactory.Binary() case self.FlowType.CONTINUOUS: self._builder = FlowBuilderFactory.Continuous() def _propagate(self, model: BaseModel, node: Node) -> None: if node.is_leaf: return left, right = node.left, node.right nid, lid, rid = node.node_id, left.node_id, right.node_id model.addConstr(self[nid] == self[lid] + self[rid]) self._propagate(model, node=left) self._propagate(model, node=right) def _get_value(self) -> gp.MLinExpr: value = gp.MLinExpr.zeros(self.shape) for leaf in self.leaves: if self._adaboost: # one-hot encode the confidence vector val: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = ( np.zeros(leaf.value.shape[1]) ) val[np.argmax(leaf.value)] = 1 else: val = leaf.value value += self._flow[leaf.node_id] * val return value def _get_length(self) -> gp.LinExpr: def length(node: Node) -> gp.LinExpr: return node.length * self[node.node_id] return sum(map(length, self.leaves), gp.LinExpr())