from __future__ import annotations

import contextlib
import copy
import inspect
from functools import partial
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar

if TYPE_CHECKING:
    from collections.abc import Iterable
# Value type produced by a future-like object. It is declared covariant so that
# ``FutureLike[Sub]`` is usable where ``FutureLike[Base]`` is expected; the
# ``noqa`` silences the lint rule asking for a ``_co`` suffix on the name.
T = TypeVar("T", covariant=True)  # noqa: PLC0105


class FutureLike(Protocol, Generic[T]):
    """
    Structural type for future-like objects used by the evaluation framework.

    Only a small slice of :class:`concurrent.futures.Future` is required,
    namely ``result()`` and ``cancel()``, so any object offering those two
    methods with compatible signatures satisfies this protocol.
    """

    def result(self, timeout: float | None = None) -> T:
        """Return the computed value; mirrors ``Future.result(timeout)``."""
        ...

    def cancel(self) -> bool:
        """Request cancellation; mirrors ``Future.cancel()``."""
        ...
class ExecutorLike(Protocol):
    """
    Minimal executor protocol used for parallel evaluation.

    This interface is modeled after :class:`concurrent.futures.Executor`:
    any object offering compatible ``submit`` and ``map`` methods (for
    example a ``ThreadPoolExecutor`` or ``ProcessPoolExecutor``) can be
    used wherever an ``ExecutorLike`` is expected.
    """

    def submit(self, fn: Callable[..., T], /, *args, **kwargs) -> FutureLike[T]:
        """Schedule ``fn(*args, **kwargs)``; mirrors ``Executor.submit``.

        Returns:
            A future-like handle for the eventual result.
        """
        ...

    def map(
        self,
        fn: Callable[..., T],
        *iterables: Iterable[Any],
        timeout: float | None = None,
        chunksize: int = 1,
    ) -> Iterable[T]:
        """Apply ``fn`` across ``iterables``; mirrors ``Executor.map``.

        Args:
            fn: Callable applied to one element from each iterable.
            *iterables: Input iterables, consumed in lockstep.
            timeout: Optional wait limit, as in ``Executor.map``.
            chunksize: Batch-size hint, as in ``Executor.map``.

        Returns:
            An iterable over the results.
        """
        ...
class ChildContextConfigurator(Protocol):
    """
    Protocol for configuring child evaluation contexts.

    The configurator is called once for each child context immediately
    after the parent context has spawned them. It may mutate the child
    context or the parent context in place to configure child-specific
    evaluation behavior, metadata, or resource usage.

    The ``idx_child_ctx`` argument is the absolute index of the current
    child within the spawned batch. ``EvaluateContext.spawn_children``
    passes every argument by keyword, so implementations must keep these
    parameter names.
    """

    def __call__(
        self,
        idx_child_ctx: int,
        child_ctx: EvaluateContext,
        num_children: int,
        parent_ctx: EvaluateContext,
    ): ...
class EvaluateContext:
    def __init__(
        self,
        config: SimpleNamespace | None = None,
        shared: SimpleNamespace | None = None,
        executor: ExecutorLike | None = None,
    ):
        """
        Container for per-evaluation state.

        A new instance of ``EvaluateContext`` should generally be created for
        each evaluation of an objective function or quantity computation.
        Implementations write all per-call information into the context
        rather than storing it in the objective instance. This makes
        evaluation easier to reason about and compatible with concurrent
        execution.

        The context may also own a single batch of child contexts representing
        nested sub-evaluations. Such child contexts can be created explicitly
        with ``spawn_children()`` or managed with the ``child_contexts()``
        context manager.

        Args:
            config:
                Optional child-local evaluation configuration. This
                namespace is deep-copied into spawned child contexts so that
                children inherit parent defaults but can be configured
                independently. A fresh empty namespace is used if ``None``.
            shared:
                Optional namespace for shared read-only state that may be
                reused across related contexts, such as parent/child
                evaluations. Spawned children receive the same object, not a
                copy. A fresh empty namespace is used if ``None``.
            executor:
                Optional executor-like object that can be used by evaluation
                code to schedule parallel work. Spawned children receive the
                same executor reference.

        Attributes:
            quantities (dict[str, Any] | None): Intermediate quantities
                computed during evaluation. Implementations may leave this
                as None if no quantities are produced.
            parameters (dict[str, Any] | None): Parameter dictionary used
                for this evaluation.
            loss (float | None): Final scalar loss value. Set by
                ``ObjectiveFunctor`` implementations.
            meta (dict[str, Any]): Free-form metadata dictionary.
                Implementations may add diagnostic or structural
                information here as needed. Metadata from child contexts
                may be collected into the parent under the ``"children"``
                key by ``collect_child_meta_data()``.
            temp (SimpleNamespace): Scratch space for temporary values
                during evaluation. Nothing stored here is part of the
                public API. It is omitted from ``to_meta_data()``.
            config:
                Child-local evaluation configuration for this context.
            shared:
                Shared state or resources reused across related contexts.
            executor (ExecutorLike | None): Executor given at construction.
                It is not part of the pickled state, so it is ``None`` after
                unpickling.
        """
        self._set_defaults(config, shared)
        # ``_set_defaults`` is shared with ``__setstate__`` and resets the
        # executor to None, so the constructor argument is applied afterwards.
        self.executor: ExecutorLike | None = executor

    def to_meta_data(self) -> dict[str, Any]:
        """
        Return a dictionary summarizing the evaluation state.

        The values are the context's own objects (not copies); ``temp``,
        ``config``, ``shared`` and ``executor`` are not included.

        Returns:
            dict[str, Any]: A dictionary containing the fields
            ``quantities``, ``parameters``, ``loss``, and ``meta``.
        """
        return {
            "quantities": self.quantities,
            "parameters": self.parameters,
            "loss": self.loss,
            "meta": self.meta,
        }

    def _set_defaults(
        self, config: SimpleNamespace | None, shared: SimpleNamespace | None
    ):
        """Reset every attribute to its initial value (used by ``__init__``
        and ``__setstate__``)."""
        self.quantities: dict[str, Any] | None = None
        self.parameters: dict[str, Any] | None = None
        self.loss: float | None = None
        self.temp = SimpleNamespace()
        self.config = SimpleNamespace() if config is None else config
        self.shared = SimpleNamespace() if shared is None else shared
        self.meta: dict[str, Any] = {}
        self.executor = None
        self._children: list[EvaluateContext] = []

    def spawn_children(
        self, n_children: int, configurator: ChildContextConfigurator | None = None
    ) -> list[EvaluateContext]:
        """
        Create child contexts linked to this context.

        Each child receives a deep copy of ``config``, while sharing the
        same ``shared`` namespace and executor reference as the parent.

        An ``EvaluateContext`` is intended to manage at most one batch of
        child contexts per evaluation. Calling ``spawn_children()`` again on
        the same context replaces the previous child batch.

        In many cases, ``child_contexts()`` is the preferred interface, since
        it automatically collects child metadata when the nested evaluation
        scope exits.

        Args:
            n_children: Number of child contexts to create.
            configurator:
                Optional configurator applied once to each spawned child
                context immediately after creation. It is called with keyword
                arguments ``idx_child_ctx``, ``child_ctx``, ``num_children``
                and ``parent_ctx``.

        Returns:
            The newly created child contexts (the same list this context
            stores internally).
        """
        self._children = [
            EvaluateContext(
                config=copy.deepcopy(self.config),
                shared=self.shared,
                executor=self.executor,
            )
            for _ in range(n_children)
        ]
        if configurator is not None:
            for idx_child, child_ctx in enumerate(self._children):
                configurator(
                    idx_child_ctx=idx_child,
                    child_ctx=child_ctx,
                    num_children=n_children,
                    parent_ctx=self,
                )
        return self._children

    def collect_child_meta_data(self, recursive: bool = True):
        """
        Collect metadata from child contexts.

        The collected child metadata is stored in ``self.meta["children"]``.
        If this context has no children, ``self.meta`` is left untouched.

        Components that spawn child contexts are generally expected to collect
        their child metadata before returning to their caller. The
        ``child_contexts()`` context manager provides a convenient scoped way
        to do this automatically.

        Args:
            recursive: If ``True``, metadata from all descendants is collected
                before serializing the immediate children. This produces a
                fully materialized metadata tree.
                If ``False``, only the immediate children are serialized. This
                can be useful when nested components manage their own metadata
                collection and have already populated their ``meta`` fields.

        Notes:
            In most cases ``recursive=True`` is the safest choice, since it
            ensures that nested child contexts are fully represented in the
            resulting metadata structure.
        """
        if not self._children:
            return
        if recursive:
            # Populate each child's own ``meta["children"]`` first, so the
            # serialization below already contains the full subtree.
            for child in self._children:
                child.collect_child_meta_data(recursive)
        self.meta["children"] = [child.to_meta_data() for child in self._children]

    @contextlib.contextmanager
    def child_contexts(
        self,
        n_children: int,
        configurator: ChildContextConfigurator | None = None,
        recursive: bool = True,
    ):
        """
        Create a scoped child-context batch and collect its metadata on exit.

        This context manager is a convenience wrapper around
        ``spawn_children()`` and ``collect_child_meta_data()``. It is intended
        for nested evaluations where the component spawning child contexts is
        also responsible for collecting their metadata before returning.

        Args:
            n_children: Number of child contexts to create.
            configurator:
                Optional configurator applied to each spawned child context.
            recursive:
                Passed to ``collect_child_meta_data()`` when the scope exits.

        Yields:
            The list of spawned child contexts.

        Notes:
            Child metadata is collected automatically when the context manager
            exits, even if an exception is raised inside the managed block;
            the exception is then propagated.
        """
        children = self.spawn_children(n_children, configurator=configurator)
        try:
            yield children
        finally:
            self.collect_child_meta_data(recursive)

    def __getstate__(self) -> dict[str, Any]:
        """
        Return the worker-input state for this context.

        Only ``parameters``, ``config``, ``shared`` and ``meta`` are included.
        ``loss``, ``quantities``, ``temp``, the executor and any child
        contexts are deliberately left out, so a pickled/unpickled context
        starts without results. Use ``to_result_state()`` to transfer results.

        Returns:
            Dictionary containing the context state needed to initialize the
            context in a worker or other execution environment.
        """
        return {
            "parameters": self.parameters,
            "config": self.config,
            "shared": self.shared,
            "meta": self.meta,
        }

    def __setstate__(self, state: dict[str, Any]):
        """
        Restore worker-input state into this context.

        All attributes not present in ``state`` are reset to their defaults;
        in particular ``executor`` becomes ``None``.

        Args:
            state:
                State previously produced by ``__getstate__()``.
        """
        self._set_defaults(config=state["config"], shared=state["shared"])
        self.parameters = state["parameters"]
        self.meta = state["meta"]

    def to_result_state(self) -> dict[str, Any]:
        """
        Return the result-bearing state of this context.

        Returns:
            Dictionary containing the evaluation results recorded in this
            context. This state is intended for child/worker-to-parent
            synchronization and does not include shared resources or child
            context objects.
        """
        return {
            "parameters": self.parameters,
            "loss": self.loss,
            "quantities": self.quantities,
            "meta": self.meta,
        }

    def apply_result_state(self, state: dict[str, Any]):
        """
        Apply result-bearing state from another context.

        Args:
            state:
                State previously produced by ``to_result_state()``.

        Side Effects:
            Updates ``parameters``, ``loss``, ``quantities``, and ``meta``
            on this context (by reference, without copying).
        """
        self.parameters = state["parameters"]
        self.loss = state["loss"]
        self.quantities = state["quantities"]
        self.meta = state["meta"]
class ObjectiveFunctor:
    def __call__(
        self, parameters: dict[str, Any], ctx: EvaluateContext | None = None
    ) -> float:
        """
        Evaluate the objective for one parameter set.

        Subclasses turn ``parameters`` into a scalar loss. Everything that
        belongs to this particular evaluation has to be recorded on ``ctx``;
        when the caller supplies no context, the implementation is expected
        to create its own ``EvaluateContext``.

        Args:
            parameters (dict[str, Any]): Parameter names mapped to their
                float values.
            ctx (EvaluateContext | None): Context to populate, or ``None``
                to let the implementation allocate a fresh one.

        Returns:
            float: The scalar loss for ``parameters``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override ``__call__``.

        Notes:
            - Leave ``self`` unchanged during the call and put all
              per-evaluation information on ``ctx`` instead.
            - The call is synchronous. To evaluate concurrently, give every
              call its own ``EvaluateContext`` and invoke this method from
              several threads/tasks.
        """
        raise NotImplementedError()
# Placeholder binding for ``QuantityComputerObjectiveFunction`` so the name
# exists before ``QuantityComputer.with_loss`` (which returns it) is defined.
# The real class is defined further below and rebinds this module-level name,
# which is why that definition carries ``# noqa: F811``.
# NOTE(review): with ``from __future__ import annotations`` in effect, the
# return annotation of ``with_loss`` is never evaluated eagerly, and the call
# inside ``with_loss`` looks the name up at call time. This placeholder
# therefore looks removable; confirm no tooling relies on it before deleting.
class QuantityComputerObjectiveFunction:  # type: ignore
    ...
# Accepted loss-function shapes for ``QuantityComputerObjectiveFunction``:
# either ``loss(quantities)`` or ``loss(quantities, parameters)``, both
# returning a scalar float. This is an ordinary module-level assignment, so
# the union is built at import time (it is not affected by
# ``from __future__ import annotations``).
LossFunction = (
    Callable[[dict[str, Any]], float]
    | Callable[[dict[str, Any], dict[str, Any]], float]
)
class QuantityComputer:
    def __init__(self):
        """
        Set up a quantity computer.

        A ``QuantityComputer`` turns a parameter dictionary into a dictionary
        of intermediate quantities, typically consumed by an objective
        function. Instances may carry static configuration, but must not
        keep per-evaluation state on ``self``.

        Attributes:
            static_meta_data (dict[str, Any]): Metadata that does not vary
                between calls; it is merged into ``ctx.meta`` on every call.
        """
        # Static (call-independent) metadata, merged into ctx.meta per call.
        self.static_meta_data: dict[str, Any] = {}

    def __call__(
        self, parameters: dict[str, Any], ctx: EvaluateContext | None = None
    ) -> dict[str, Any]:
        """
        Compute the quantities belonging to ``parameters``.

        Args:
            parameters (dict[str, Any]): Parameter dictionary.
            ctx (EvaluateContext | None): Context to populate; a fresh
                ``EvaluateContext`` is created when ``None``.

        Returns:
            dict[str, Any]: The quantity dictionary produced by ``_compute``.

        Notes:
            ``_compute`` implementations must leave ``self`` untouched and
            record per-evaluation information on ``ctx``.

        Side Effects:
            Sets ``ctx.parameters`` to ``parameters``, merges
            ``self.static_meta_data`` into ``ctx.meta`` and stores the
            result in ``ctx.quantities``.
        """
        context = EvaluateContext() if ctx is None else ctx
        context.parameters = parameters
        context.meta.update(self.static_meta_data)
        context.quantities = self._compute(parameters, context)
        return context.quantities

    def _compute(
        self, parameters: dict[str, Any], ctx: EvaluateContext
    ) -> dict[str, Any]:
        """Return the quantity dictionary for ``parameters`` (subclass hook)."""
        raise NotImplementedError

    def with_loss(
        self, loss_function: Callable[..., float], /, **kwargs: Any
    ) -> QuantityComputerObjectiveFunction:
        """
        Wrap this computer and a loss into a ``QuantityComputerObjectiveFunction``.

        Args:
            loss_function (LossFunction): Loss applied to the computed
                quantities; called as ``loss(quantities)`` or
                ``loss(quantities, parameters)``.
            **kwargs: Extra keyword arguments bound to ``loss_function`` via
                ``functools.partial`` and passed on every call.

        Returns:
            QuantityComputerObjectiveFunction: Objective combining ``self``
            with the (partially applied) loss function.
        """
        bound_loss = partial(loss_function, **kwargs)
        return QuantityComputerObjectiveFunction(
            loss_function=bound_loss,
            quantity_computer=self,
        )
class QuantityComputerObjectiveFunction(ObjectiveFunctor):  # noqa: F811
    def __init__(
        self,
        loss_function: LossFunction,
        quantity_computer: QuantityComputer,
    ) -> None:
        """
        Objective function composed of a ``QuantityComputer`` and a loss function.

        This objective first computes intermediate quantities using
        ``quantity_computer`` and then applies ``loss_function`` to
        obtain a scalar loss.

        Args:
            loss_function (Callable): A function with signature
                ``loss_function(quantities) -> float``
                or
                ``loss_function(quantities, parameters) -> float``.
            quantity_computer (QuantityComputer): Object responsible for
                computing intermediate quantities.

        Attributes:
            static_meta_data (dict[str, Any]): Static metadata associated
                with this objective. Merged into ``ctx.meta`` on each call.
        """
        super().__init__()
        self.quantity_computer = quantity_computer
        self.static_meta_data: dict[str, Any] = {}
        self.loss_function = loss_function

    def __call__(
        self, parameters: dict[str, Any], ctx: EvaluateContext | None = None
    ) -> float:
        """
        Compute the objective loss.

        This method:
            1. Computes intermediate quantities using the quantity computer.
            2. Applies the loss function.
            3. Stores results in the evaluation context.

        Args:
            parameters (dict[str, Any]): Parameter dictionary.
            ctx (EvaluateContext | None): Optional context. If None, a
                new one is created.

        Returns:
            float: The computed scalar loss.

        Raises:
            Any exception raised by the quantity computer or the loss
            function is propagated unchanged.

        Side Effects:
            Stores the computed loss in ``ctx.loss``.
            Updates ``ctx.meta`` with ``self.static_meta_data`` after the
            wrapped ``QuantityComputer`` may have already added metadata.
            Populates ``ctx.quantities`` and ``ctx.parameters`` via the
            wrapped ``QuantityComputer``.

        Notes:
            ``loss_function`` may accept either ``(quantities)`` or
            ``(quantities, parameters)`` as positional arguments; the form
            is chosen from its signature (see ``_evaluate_loss``).
        """
        if ctx is None:
            ctx = EvaluateContext()
        quantities = self.quantity_computer(parameters, ctx)
        # Applied after the quantity computer ran, so objective-level static
        # metadata overrides identically named keys the computer wrote.
        ctx.meta.update(self.static_meta_data)
        ctx.loss = self._evaluate_loss(quantities, parameters)
        return ctx.loss

    def _evaluate_loss(
        self, quantities: dict[str, Any], parameters: dict[str, Any]
    ) -> float:
        """Call ``loss_function`` with one or two positional arguments.

        If the loss function's signature accepts ``quantities`` alone, it is
        called that way; otherwise it receives ``(quantities, parameters)``.
        Deciding from the signature (instead of catching ``TypeError``)
        keeps a ``TypeError`` raised *inside* a one-argument loss function
        from being mistaken for an arity mismatch and retried with the
        wrong arguments.
        """
        loss_function: Callable[..., float] = self.loss_function
        try:
            signature = inspect.signature(loss_function)
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) expose no signature;
            # only for those fall back to trial-and-error dispatch.
            try:
                return loss_function(quantities)
            except TypeError:
                return loss_function(quantities, parameters)
        try:
            signature.bind(quantities)
        except TypeError:
            # Cannot be called with quantities alone -> two-argument form.
            return loss_function(quantities, parameters)
        return loss_function(quantities)