from collections.abc import Mapping
import numpy as np
from ortools.sat.python import cp_model as cp
from ...feature import Feature
from ...feature._keeper import FeatureKeeper
from ...typing import Key
from .._base import BaseModel, Var
[docs]
class FeatureVar(Var, FeatureKeeper):
"""CP variable bundle associated with a single parsed feature."""
X_VAR_NAME_FMT: str = "x[{name}]"
_x: cp.IntVar
_u: Mapping[Key, cp.IntVar]
_objvar: cp.IntVar
def __init__(self, feature: Feature, name: str) -> None:
Var.__init__(self, name=name)
FeatureKeeper.__init__(self, feature=feature)
[docs]
def build(self, model: BaseModel) -> None:
if not self.is_one_hot_encoded:
self._x = self._add_x(model)
if self.is_one_hot_encoded:
self._u = self._add_u(model)
[docs]
def xget(self, code: Key | None = None) -> cp.IntVar:
if self.is_one_hot_encoded:
return self._xget_one_hot_encoded(code)
if code is not None:
msg = "Get by code is only supported for one-hot encoded features"
raise ValueError(msg)
return self._x
[docs]
def objvarget(self) -> cp.IntVar:
if not self.is_numeric:
msg = "The 'objvarget' method is only supported"
msg += " for continuous and discrete features"
raise ValueError(msg)
return self._objvar
def _add_x(self, model: BaseModel) -> cp.IntVar:
name = self.X_VAR_NAME_FMT.format(name=self._name)
# Case when the feature is one-hot encoded.
if self.is_one_hot_encoded:
msg = "One-hot encoded features are not for x"
raise ValueError(msg)
# Case when the feature is binary.
if self.is_binary:
return self._add_binary(model, name)
# Case when the feature is continuous or discrete
if self.is_continuous:
return self._add_continuous(model, name)
return self._add_discrete(model, name)
def _add_u(self, model: BaseModel) -> Mapping[Key, cp.IntVar]:
name = self._name.format(name=self._name)
u = self._add_one_hot_encoded(model=model, name=name)
model.AddExactlyOne(u.values())
return u
def _add_one_hot_encoded(
self,
model: BaseModel,
name: str,
) -> Mapping[Key, cp.IntVar]:
return {
code: model.NewBoolVar(f"{name}[{code}]") for code in self.codes
}
@staticmethod
def _add_binary(model: BaseModel, name: str) -> cp.IntVar:
return model.NewBoolVar(name)
def _add_continuous(self, model: BaseModel, name: str) -> cp.IntVar:
self._objvar = model.NewIntVar(
0, 42, f"u_{name}"
) # arbitrary, will be adapted for each query
m = len(self.levels)
return model.NewIntVar(0, m - 2, name)
def _add_discrete(self, model: BaseModel, name: str) -> cp.IntVar:
if len(self.thresholds) != 0:
values = [
val for v in self.thresholds for val in [int(v), int(v) + 1]
]
else:
values = np.asarray(self.levels).astype(int).tolist()
val = cp.Domain.FromValues(values)
self._objvar = model.NewIntVar(0, int(self.levels[-1]), f"u_{name}")
return model.NewIntVarFromDomain(val, name)
def _xget_one_hot_encoded(self, code: Key | None) -> cp.IntVar:
if code is None:
msg = "Code is required for one-hot encoded features get"
raise ValueError(msg)
if code not in self.codes:
msg = f"Code '{code}' not found in the feature codes"
raise ValueError(msg)
return self._u[code]