Source code for ocean.maxsat._variables._tree

from collections.abc import Iterator, Mapping

from pydantic import validate_call

from ...tree._keeper import TreeKeeper, TreeLike
from ...typing import NonNegativeInt
from .._base import BaseModel, Var


class TreeVar(Var, TreeKeeper, Mapping[NonNegativeInt, object]):
    """MaxSAT variables encoding the active part of one parsed tree.

    The encoding is selected in :meth:`build` from the model's
    ``_hard_voting`` attribute:

    * path mode (default): one variable per leaf plus an exactly-one
      constraint over them. Leaf variables are read through the mapping
      interface, ``var[leaf_node_id]``.
    * hard-voting mode: one variable per class id plus an exactly-one
      constraint over them. Class variables are read through :meth:`cget`.

    Mapping keys are the node ids of the tree's leaves in path mode and are
    empty in hard-voting mode, so ``len()``, iteration and ``__getitem__``
    always agree with each other.
    """

    PATH_VAR_NAME_FMT: str = "{name}_path"
    CLASS_VAR_NAME_FMT: str = "{name}_class"

    # Leaf node id -> value returned by ``model.add_var`` (path mode only).
    _path: Mapping[NonNegativeInt, int]
    # Class id -> value returned by ``model.add_var`` (hard-voting mode only).
    _class: Mapping[NonNegativeInt, int]
    # Encoding chosen by the most recent ``build`` call; False until then.
    _hard_voting: bool = False

    def __init__(
        self,
        tree: TreeLike,
        name: str,
    ) -> None:
        Var.__init__(self, name=name)
        TreeKeeper.__init__(self, tree=tree)

    def build(self, model: BaseModel) -> None:
        """Create this tree's variables in *model* and require exactly one.

        ``model._hard_voting`` (treated as False when the attribute is
        missing) selects between per-class variables and per-leaf path
        variables. Either way, the new variables are passed to
        ``model.add_exactly_one``.
        """
        self._hard_voting = bool(getattr(model, "_hard_voting", False))
        if self._hard_voting:
            name = self.CLASS_VAR_NAME_FMT.format(name=self._name)
            self._class = self._add_class(model=model, name=name)
            model.add_exactly_one(list(self._class.values()))
        else:
            name = self.PATH_VAR_NAME_FMT.format(name=self._name)
            self._path = self._add_path(model=model, name=name)
            model.add_exactly_one(list(self._path.values()))

    def __len__(self) -> int:
        # Count exactly what __iter__ yields so len() matches the real keys.
        return sum(1 for _ in self)

    def __iter__(self) -> Iterator[NonNegativeInt]:
        # Only leaves carry path variables, so internal node ids are not
        # keys; yielding them would make items()/values()/dict() raise
        # KeyError. In hard-voting mode there are no leaf variables at all.
        if self._hard_voting:
            return iter(())
        return (leaf.node_id for leaf in self.leaves)

    @validate_call
    def __getitem__(self, node_id: NonNegativeInt) -> int:
        """Return the path variable of leaf *node_id*.

        Requires :meth:`build` to have been called in path mode.

        Raises:
            ValueError: in hard-voting mode, where no leaf variables exist.
            KeyError: if *node_id* is not a leaf of this tree.
        """
        if self._hard_voting:
            msg = "Leaf variables are not available in hard-voting mode."
            raise ValueError(msg)
        return self._path[node_id]

    @validate_call
    def cget(self, class_id: NonNegativeInt) -> int:
        """Return the variable of class *class_id* (hard-voting mode only).

        Raises:
            ValueError: if the tree was not built in hard-voting mode.
            KeyError: if *class_id* has no variable (not below the class
                count used in :meth:`_add_class`).
        """
        if not self._hard_voting:
            msg = "Class variables are only available in hard-voting mode."
            raise ValueError(msg)
        return self._class[class_id]

    def _add_path(
        self,
        model: BaseModel,
        name: str,
    ) -> Mapping[NonNegativeInt, int]:
        """Add one variable named ``{name}[node_id]`` per leaf of the tree."""
        return {
            leaf.node_id: model.add_var(name=f"{name}[{leaf.node_id}]")
            for leaf in self.leaves
        }

    def _add_class(
        self,
        model: BaseModel,
        name: str,
    ) -> Mapping[NonNegativeInt, int]:
        """Add one variable named ``{name}[class_id]`` per class id."""
        # NOTE(review): assumes the last axis of ``self.shape`` is the number
        # of classes -- confirm against TreeKeeper.
        n_classes = self.shape[-1]
        return {
            class_id: model.add_var(name=f"{name}[{class_id}]")
            for class_id in range(n_classes)
        }