# Source code for ocean.maxsat._env

from __future__ import annotations

from typing import TYPE_CHECKING, cast

from pysat.examples.rc2 import RC2

if TYPE_CHECKING:
    from pysat.formula import WCNF

# Alias used when casting the hard/soft clause lists of a WCNF formula:
# each clause is a list of ints.
Clause = list[int]


class Env:
    """Holder for the MaxSAT solver instance shared through this module.

    A fresh ``MaxSATSolver`` is created on construction; it can be read or
    swapped out afterwards through the ``solver`` property.
    """

    _solver: MaxSATSolver

    def __init__(self) -> None:
        default_solver = MaxSATSolver()
        self._solver = default_solver

    @property
    def solver(self) -> MaxSATSolver:
        """Return the solver currently held by this environment."""
        return self._solver

    @solver.setter
    def solver(self, solver: MaxSATSolver) -> None:
        """Replace the held solver with ``solver``."""
        self._solver = solver


[docs] class MaxSATSolver: """Thin RC2 wrapper to keep a stable interface.""" _model: list[int] | None = None _cost: float = float("inf") _rc2: RC2 | None = None _formula: WCNF | None = None _formula_epoch: int | None = None _synced_hard: int = 0 _synced_soft: int = 0 _active_solver_name: str | None = None _state_version: int = 0 def __init__( self, solver_name: str = "cadical195", TimeLimit: int = 60, n_threads: int = 1, ) -> None: self.solver_name = solver_name self.TimeLimit = TimeLimit self.n_threads = n_threads self.verbose = False
[docs] def solve(self, w: WCNF) -> list[int]: if self._needs_rebuild(w): self._rebuild(w) else: self._sync(w) if self._rc2 is None: # pragma: no cover msg = "Internal RC2 solver was not initialized." raise RuntimeError(msg) self._rc2.verbose = int(self.verbose) model = cast("list[int] | None", self._rc2.compute()) if model is None: msg = "UNSAT: no counterfactual found." raise RuntimeError(msg) self._model = model self._cost = self._rc2.cost return model
[docs] def delete(self) -> None: """Release any persistent RC2 state held by this wrapper.""" if self._rc2 is not None: self._rc2.delete() self._rc2 = None self._formula = None self._formula_epoch = None self._synced_hard = 0 self._synced_soft = 0 self._active_solver_name = None self._model = None self._cost = float("inf")
@property def state_token(self) -> int | None: """Opaque token identifying the currently cached RC2 instance.""" if self._rc2 is None: return None return id(self._rc2) @property def synced_counts(self) -> tuple[int, int]: """Number of hard and soft clauses synchronized into RC2.""" return self._synced_hard, self._synced_soft @property def state_version(self) -> int: """Monotonic version incremented whenever RC2 is rebuilt.""" return self._state_version def _needs_rebuild(self, w: WCNF) -> bool: if self._rc2 is None: return True if self._formula is not w: return True if self._formula_epoch != getattr(w, "solver_epoch", 0): return True return self._active_solver_name != self.solver_name def _rebuild(self, w: WCNF) -> None: self.delete() self._state_version += 1 hard = cast("list[Clause]", w.hard) soft = cast("list[Clause]", w.soft) self._rc2 = RC2( w, solver=self.solver_name, adapt=True, exhaust=False, minz=True, verbose=int(self.verbose), ) self._formula = w self._formula_epoch = getattr(w, "solver_epoch", 0) self._synced_hard = len(hard) self._synced_soft = len(soft) self._active_solver_name = self.solver_name def _sync(self, w: WCNF) -> None: if self._rc2 is None: # pragma: no cover return hard = cast("list[Clause]", w.hard) soft = cast("list[Clause]", w.soft) weights = cast("list[int]", w.wght) for clause in hard[self._synced_hard :]: self._rc2.add_clause(clause) new_soft = zip( soft[self._synced_soft :], weights[self._synced_soft :], strict=True, ) for clause, weight in new_soft: self._rc2.add_clause(clause, weight=weight) self._synced_hard = len(hard) self._synced_soft = len(soft)
[docs] def model(self, v: int) -> float: """ Return 1.0 if variable v is true in the model, 0.0 otherwise. Args: v: Variable to check in the model. Returns: 1.0 if variable v is true in the model, 0.0 otherwise. Raises: ValueError: If no model has been found solve() must be called first. """ if self._model is None: msg = "No model found, please run 'solve' first." raise ValueError(msg) # The model is a list of signed literals. # Variable v is true if v is in the model, false if -v is in the model. if v in self._model: return 1.0 return 0.0
@property def cost(self) -> float: return self._cost
# Module-level Env instance, created once when this module is imported.
ENV = Env()