from __future__ import annotations
from typing import Any, Callable
from chemfit.abstract_objective_function import (
EvaluateContext,
ObjectiveFunctor,
QuantityComputer,
)
# Type of callables accepted by ``WrappedObjectiveFunctor``. The wrapper calls
# ``func(parameters, *func_args, **func_kwargs)`` and, when ``pass_ctx`` is
# True, additionally passes ``ctx=<EvaluateContext>`` by keyword. Because
# ``bind()`` can attach arbitrary extra arguments, the parameter list cannot be
# spelled out precisely; only the scalar (float) return type is fixed. This
# mirrors ``WrappableQuantFunction`` below.
WrappableObjFunction = Callable[..., float]
[docs]
class WrappedObjectiveFunctor(ObjectiveFunctor):
    """
    Adapt a plain callable into an ``ObjectiveFunctor``.

    The wrapped callable is invoked as
    ``func(parameters, *func_args, **func_kwargs)`` or, when ``pass_ctx`` is
    ``True``, as ``func(parameters, *func_args, **func_kwargs, ctx=ctx)``.
    Because ``ctx`` is passed by keyword, the callable must name that
    parameter ``ctx``.
    """

    def __init__(
        self,
        func: WrappableObjFunction,
        pass_ctx: bool = False,
        func_args: tuple[Any, ...] | None = None,
        func_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initialize a wrapped objective functor.

        Args:
            func: Callable to wrap as an ``ObjectiveFunctor``. Its first
                positional parameter receives the parameter dictionary.
            pass_ctx: If ``True``, also pass the evaluation context to
                ``func`` as the keyword argument ``ctx``. If ``False``, no
                context is passed.
            func_args: Extra positional arguments passed to ``func`` right
                after ``parameters``. ``None`` means no extra arguments.
            func_kwargs: Extra keyword arguments passed to ``func``. ``None``
                means no extra keyword arguments.
        """
        super().__init__()
        self.func = func
        self.pass_ctx = pass_ctx
        # Store private copies so later mutation of the caller's containers
        # cannot change what this functor passes to ``func``.
        self.func_args = () if func_args is None else tuple(func_args)
        self.func_kwargs = {} if func_kwargs is None else dict(func_kwargs)

    def bind(self, /, *args: Any, **kwargs: Any) -> WrappedObjectiveFunctor:
        """
        Return a new objective functor with extra arguments bound.

        The new functor wraps the same callable with the same ``pass_ctx``
        setting. Its positional arguments are those already bound on
        ``self`` followed by ``args``. Keyword arguments are merged, with
        ``kwargs`` overriding existing entries of the same name. ``self`` is
        not modified.

        Args:
            *args: Positional arguments appended after any already bound
                positional arguments (all of which follow ``parameters``).
                When ``pass_ctx=True``, ``ctx`` is still passed by keyword,
                not positionally.
            **kwargs: Keyword arguments to bind.

        Returns:
            A new instance of ``type(self)`` with the requested arguments
            pre-applied.
        """
        # Accumulate instead of replacing so that ``f.bind(a).bind(b)`` keeps
        # ``a`` (same semantics as nesting ``functools.partial``).
        return type(self)(
            func=self.func,
            pass_ctx=self.pass_ctx,
            func_args=(*self.func_args, *args),
            func_kwargs={**self.func_kwargs, **kwargs},
        )

    def __call__(
        self, parameters: dict[str, Any], ctx: EvaluateContext | None = None
    ) -> float:
        """
        Evaluate the wrapped callable as an objective functor.

        If ``pass_ctx`` is ``True``, the wrapped callable additionally
        receives the evaluation context as the keyword argument ``ctx``.
        Otherwise it receives only the parameter dictionary (plus any bound
        arguments).

        Args:
            parameters: Parameter dictionary for the current evaluation.
            ctx: Optional evaluation context. If ``None``, a new
                ``EvaluateContext`` is created.

        Returns:
            Scalar loss value returned by the wrapped callable.

        Side Effects:
            - Stores ``parameters`` in ``ctx.parameters`` before calling
              ``func``, so the callable can read it from ``ctx``.
            - Stores the returned loss in ``ctx.loss``.
        """
        if ctx is None:
            ctx = EvaluateContext()
        ctx.parameters = parameters
        if self.pass_ctx:
            ctx.loss = self.func(
                parameters, *self.func_args, **self.func_kwargs, ctx=ctx
            )
        else:
            ctx.loss = self.func(parameters, *self.func_args, **self.func_kwargs)
        return ctx.loss
[docs]
def to_objective_functor(pass_ctx: bool = False):
    """
    Build a decorator that turns a function into an objective functor.

    The decorator must be applied with parentheses, e.g.
    ``@to_objective_functor()`` or ``@to_objective_functor(pass_ctx=True)``.

    Args:
        pass_ctx: Forwarded unchanged to ``WrappedObjectiveFunctor``. When
            ``True``, the decorated function is also handed the evaluation
            context as the keyword argument ``ctx``. When ``False``, it is
            called with the parameter dictionary alone.

    Returns:
        A one-argument decorator that returns a ``WrappedObjectiveFunctor``
        around the function it is applied to.
    """

    def _decorate(func: WrappableObjFunction):
        functor = WrappedObjectiveFunctor(func, pass_ctx=pass_ctx)
        return functor

    return _decorate
# Mapping of quantity names to computed values returned by a quantity function.
_QuantityDict = dict[str, Any]
# Any callable returning such a mapping. The exact argument list is left open
# because bound arguments and the optional ``ctx`` keyword vary per use.
WrappableQuantFunction = Callable[..., _QuantityDict]
[docs]
class WrappedQuantityComputer(QuantityComputer):
    """
    Adapt a plain callable into a ``QuantityComputer``.

    The wrapped callable is invoked as
    ``func(parameters, *func_args, **func_kwargs)`` or, when ``pass_ctx`` is
    ``True``, as ``func(parameters, *func_args, **func_kwargs, ctx=ctx)``.
    Because ``ctx`` is passed by keyword, the callable must name that
    parameter ``ctx``. It is expected to return a dictionary of quantities.
    """

    def __init__(
        self,
        func: WrappableQuantFunction,
        pass_ctx: bool = False,
        func_args: tuple[Any, ...] | None = None,
        func_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initialize a wrapped quantity computer.

        Args:
            func: Callable to wrap as a ``QuantityComputer``. Its first
                positional parameter receives the parameter dictionary.
            pass_ctx: If ``True``, also pass the evaluation context to
                ``func`` as the keyword argument ``ctx``. If ``False``, no
                context is passed.
            func_args: Extra positional arguments passed to ``func`` right
                after ``parameters``. ``None`` means no extra arguments.
            func_kwargs: Extra keyword arguments passed to ``func``. ``None``
                means no extra keyword arguments.
        """
        super().__init__()
        self.func = func
        self.pass_ctx = pass_ctx
        # Store private copies so later mutation of the caller's containers
        # cannot change what this computer passes to ``func``.
        self.func_args = () if func_args is None else tuple(func_args)
        self.func_kwargs = {} if func_kwargs is None else dict(func_kwargs)

    def bind(self, /, *args: Any, **kwargs: Any) -> WrappedQuantityComputer:
        """
        Return a new quantity computer with extra arguments bound.

        The new computer wraps the same callable with the same ``pass_ctx``
        setting. Its positional arguments are those already bound on
        ``self`` followed by ``args``. Keyword arguments are merged, with
        ``kwargs`` overriding existing entries of the same name. ``self`` is
        not modified.

        Args:
            *args: Positional arguments appended after any already bound
                positional arguments (all of which follow ``parameters``).
                When ``pass_ctx=True``, ``ctx`` is still passed by keyword,
                not positionally.
            **kwargs: Keyword arguments to bind.

        Returns:
            A new instance of ``type(self)`` with the requested arguments
            pre-applied.
        """
        # Accumulate instead of replacing so that ``q.bind(a).bind(b)`` keeps
        # ``a`` (same semantics as nesting ``functools.partial``).
        return type(self)(
            func=self.func,
            pass_ctx=self.pass_ctx,
            func_args=(*self.func_args, *args),
            func_kwargs={**self.func_kwargs, **kwargs},
        )

    def _compute(
        self,
        parameters: dict[str, Any],
        ctx: EvaluateContext,
    ) -> dict[str, Any]:
        """
        Compute quantities using the wrapped callable.

        If ``pass_ctx`` is ``True``, the wrapped callable additionally
        receives the evaluation context as the keyword argument ``ctx``.
        Otherwise it receives only the parameter dictionary (plus any bound
        arguments).

        NOTE(review): presumably this is the hook invoked by
        ``QuantityComputer``; confirm against the base class.

        Args:
            parameters: Parameter dictionary for the current evaluation.
            ctx: Evaluation context for the current call.

        Returns:
            Quantity dictionary returned by the wrapped callable.
        """
        if self.pass_ctx:
            return self.func(parameters, *self.func_args, **self.func_kwargs, ctx=ctx)
        return self.func(parameters, *self.func_args, **self.func_kwargs)
[docs]
def to_quantity_computer(pass_ctx: bool = False):
    """
    Build a decorator that turns a function into a quantity computer.

    The decorator must be applied with parentheses, e.g.
    ``@to_quantity_computer()`` or ``@to_quantity_computer(pass_ctx=True)``.

    Args:
        pass_ctx: Forwarded unchanged to ``WrappedQuantityComputer``. When
            ``True``, the decorated function is also handed the evaluation
            context as the keyword argument ``ctx``. When ``False``, it is
            called with the parameter dictionary alone.

    Returns:
        A one-argument decorator that returns a ``WrappedQuantityComputer``
        around the function it is applied to.
    """

    def _decorate(func: WrappableQuantFunction):
        computer = WrappedQuantityComputer(func, pass_ctx=pass_ctx)
        return computer

    return _decorate