# Source code for chemfit.abstract_objective_function
from __future__ import annotations
import abc
from typing import Any, Callable, Protocol, runtime_checkable
class ObjectiveFunctor(abc.ABC):
    """Abstract callable that scores a set of parameters with a single number.

    Subclasses must implement :meth:`__call__`; because that method is
    abstract, this class cannot be instantiated directly.
    """

    @abc.abstractmethod
    def __call__(self, parameters: dict[str, Any]) -> float:
        """
        Evaluate the objective for ``parameters``.

        Args:
            parameters: Mapping from parameter name to parameter value
                (typically floats).

        Returns:
            float: The objective value, e.g. an error metric to be minimised.
        """
class QuantityComputer(abc.ABC):
    """Abstract base for objects that turn parameters into named quantities.

    Calling an instance delegates to :meth:`_compute` and remembers the
    returned dictionary in ``_last_quantities`` (``None`` until the first
    successful call).
    """

    def __init__(self):
        """Start with no cached quantities."""
        # Nothing has been computed yet.
        self._last_quantities: dict[str, Any] | None = None

    def __call__(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """Compute quantities for ``parameters``, cache them, and return them.

        Args:
            parameters: Mapping from parameter name to value, forwarded
                unchanged to :meth:`_compute`.

        Returns:
            The dictionary produced by :meth:`_compute`; the very same object
            is stored in ``_last_quantities``.
        """
        computed = self._compute(parameters)
        self._last_quantities = computed
        return computed

    @abc.abstractmethod
    def _compute(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """Return a dictionary of quantities derived from ``parameters``.

        Must be provided by subclasses.
        """
class QuantityComputerObjectiveFunction(ObjectiveFunctor):
    """Objective that chains a quantity computer into a loss function.

    Evaluation first maps the parameters to a dictionary of quantities with
    ``quantity_computer`` and then reduces that dictionary to a scalar with
    ``loss_function``. The most recent loss is kept in ``_last_loss``.
    """

    def __init__(
        self,
        loss_function: Callable[[dict[str, Any]], float] | ObjectiveFunctor,
        quantity_computer: QuantityComputer,
    ) -> None:
        """Store the loss function and the quantity computer.

        Args:
            loss_function: Callable (or ``ObjectiveFunctor``) that receives
                the quantities dictionary and returns the loss.
            quantity_computer: Producer of the quantities dictionary for a
                given parameter set.
        """
        super().__init__()
        self.loss_function = loss_function
        self.quantity_computer = quantity_computer
        # ``None`` until the objective has been evaluated at least once.
        self._last_loss: float | None = None

    def __call__(self, parameters: dict[str, Any]) -> float:
        """Evaluate the loss of the quantities computed for ``parameters``.

        Args:
            parameters: Mapping from parameter name to value, passed to
                ``quantity_computer``.

        Returns:
            float: Whatever ``loss_function`` returns for the computed
            quantities; the same value is stored in ``_last_loss``.
        """
        loss = self.loss_function(self.quantity_computer(parameters))
        self._last_loss = loss
        return loss