from abc import ABC
from itertools import combinations
from typing import Any, Protocol

from pysat.formula import WCNF, IDPool
class BaseModel(ABC, WCNF):
    """Base weighted CNF model used by the MaxSAT backend.

    Extends pysat's ``WCNF`` with:

    * a per-instance ``IDPool`` mapping variable names to integer IDs,
    * helpers for adding hard clauses, soft clauses and exactly-one
      constraints,
    * a "solver epoch" counter that callers bump via
      ``invalidate_solver_state`` when the formula changes non-monotonically.
    """

    vpool: IDPool  # name -> variable-ID pool, private to this model instance
    _solver_epoch: int  # incremented by invalidate_solver_state()

    def __init__(self) -> None:
        WCNF.__init__(self)
        # Create a new pool for each instance so IDs never leak between models.
        self.vpool = IDPool()
        self._solver_epoch = 0

    def __setattr__(self, name: str, value: Any) -> None:  # noqa: ANN401
        # Delegate straight to object.__setattr__, bypassing any __setattr__
        # defined further up the MRO.
        object.__setattr__(self, name, value)

    def build_vars(self, *variables: "Var") -> None:
        """Call ``build(model=self)`` on each variable, in the order given."""
        for variable in variables:
            variable.build(model=self)

    def add_var(self, name: str) -> int:
        """Create a new named variable and return its integer ID.

        Raises:
            ValueError: If a variable called ``name`` already exists.
        """
        if name in self.vpool.obj2id:  # var has been already created
            msg = f"Variable with name '{name}' already exists."
            raise ValueError(msg)
        # Register under exactly the key the duplicate check above and
        # get_var() use (previously an f-string copy, which diverged for
        # non-str names).
        return self.vpool.id(name)  # type: ignore[no-any-return]

    def get_var(self, name: str) -> int:
        """Return the integer ID of an existing named variable.

        Raises:
            ValueError: If no variable called ``name`` has been created.
        """
        # pysat's IDPool backs obj2id with a defaultdict that allocates a fresh
        # ID when a missing key is indexed, so membership must be tested
        # first; catching KeyError (EAFP) would never trigger.
        if name not in self.vpool.obj2id:  # var has not been created
            msg = f"Variable with name '{name}' does not exist."
            raise ValueError(msg)
        return self.vpool.obj2id[name]  # type: ignore[no-any-return]

    def add_hard(self, lits: list[int], return_id: bool = False) -> int:  # noqa: FBT001, FBT002
        """
        Add a hard clause (must be satisfied).

        Args:
            lits: The clause as a list of integer literals (a negative value
                denotes the negated variable).
            return_id: When True, return the clause's index in ``self.hard``.

        Returns:
            The clause ID if return_id is True, otherwise -1.
        """
        # weight=None => hard clause in WCNF
        self.append(lits)
        if return_id:
            return len(self.hard) - 1  # pyright: ignore[reportUnknownArgumentType]
        return -1

    def add_soft(self, lits: list[int], weight: int = 1) -> None:
        """Add a soft clause with a given weight.

        Args:
            lits: The clause as a list of integer literals.
            weight: Strictly positive penalty for violating the clause.

        Raises:
            ValueError: If ``weight`` is not positive.
        """
        # WCNF.append treats a falsy weight as "hard", so weight=0 would
        # silently turn this into a hard constraint. A negative weight would
        # lower ``topw``, the weight pysat uses for hard clauses. Reject both.
        if weight <= 0:
            msg = f"Soft clause weight must be positive, got {weight}."
            raise ValueError(msg)
        self.append(lits, weight=weight)

    def add_exactly_one(self, lits: list[int]) -> None:
        """Add constraint that exactly one path is selected.

        Uses the pairwise encoding: one at-least-one clause over all of
        ``lits`` plus a binary at-most-one clause for every pair of literals,
        i.e. O(n^2) hard clauses. With an empty ``lits`` the at-least-one
        clause is empty, which makes the hard constraints unsatisfiable.
        """
        self.add_hard(lits)  # at least one
        for first, second in combinations(lits, 2):
            self.add_hard([-first, -second])  # at most one

    def _clean_soft(self) -> None:
        """Reset the model to only contain hard constraints.

        Empties the soft clauses and their weights in place and resets
        ``topw`` to 1. Hard clauses and the variable pool are left untouched.
        """
        # NOTE(review): dropping soft clauses is a non-monotonic change, but
        # this does not bump the solver epoch; presumably callers invoke
        # invalidate_solver_state() themselves. Confirm.
        self.soft.clear()
        self.wght.clear()
        self.topw = 1

    @property
    def solver_epoch(self) -> int:
        """Counter incremented each time ``invalidate_solver_state`` is called."""
        return self._solver_epoch

    def invalidate_solver_state(self) -> None:
        """Mark the formula as having changed non-monotonically."""
        self._solver_epoch += 1
class Var(Protocol):
    """Structural interface for objects that ``BaseModel.build_vars`` accepts.

    Implementations carry a name and know how to add themselves to a model.
    """

    _name: str  # identifier given at construction time

    def __init__(self, name: str) -> None:
        """Remember ``name`` on the instance as ``_name``."""
        self._name = name

    def build(self, model: BaseModel) -> None:
        """Populate ``model`` with this variable (invoked as ``build(model=...)``)."""