"""Combine several weighted objective terms into a single objective (``chemfit.combined_objective_function``)."""

from __future__ import annotations

import inspect
import math
from collections.abc import Sequence
from typing import Any, Callable, Protocol, cast

from typing_extensions import Self

from chemfit.abstract_objective_function import (
    ChildContextConfigurator,
    EvaluateContext,
    ObjectiveFunctor,
)
from chemfit.wrap_funcs import WrappedObjectiveFunctor


def transform_generic_callables(
    list_of_callables: Sequence[Callable[[dict[str, Any]], float]],
) -> list[ObjectiveFunctor]:
    """Normalize a sequence of callables for use as objective terms.

    Items that already are ``ObjectiveFunctor`` instances are passed through
    unchanged; every other callable is wrapped in ``WrappedObjectiveFunctor``.

    Args:
        list_of_callables: Objective functors and/or plain callables.

    Returns:
        A new list, in input order, holding either the original functors or
        their ``WrappedObjectiveFunctor`` wrappers.
    """
    return [
        candidate
        if isinstance(candidate, ObjectiveFunctor)
        else WrappedObjectiveFunctor(candidate)
        for candidate in list_of_callables
    ]
class Reducer(Protocol):
    """Structural type for a callable that collapses term values into one float.

    Examples in this module: ``sum_reducer``, ``mean_reducer``,
    ``root_mean_reducer``.
    """

    def __call__(self, terms: list[float]) -> float:
        """Reduce ``terms`` to a single scalar."""
        ...
class Aggregator(Protocol):
    """Structural type for a context-aware reduction of term values.

    Unlike ``Reducer``, an aggregator also receives a list of quantity
    dictionaries and the parent ``EvaluateContext``. In
    ``CombinedObjectiveFunction.apply_reduction`` the quantities are those of
    the non-skipped child contexts.
    """

    def __call__(
        self,
        terms: list[float],
        quantities: list[dict[str, Any]],
        ctx: EvaluateContext,
    ) -> float:
        """Combine ``terms`` (optionally using ``quantities`` and ``ctx``) into a float."""
        ...
def sum_reducer(terms: list[float]) -> float:
    """Return the plain sum of ``terms`` (``0`` for an empty list)."""
    total = sum(terms)
    return total
def mean_reducer(terms: list[float]) -> float:
    """Return the arithmetic mean of ``terms``.

    An empty list yields ``math.nan`` instead of raising
    ``ZeroDivisionError``. This can happen when every term was dropped by an
    exception handler that returns ``None``.

    Args:
        terms: Term values to average.

    Returns:
        ``sum(terms) / len(terms)``, or ``math.nan`` when ``terms`` is empty.
    """
    if not terms:
        # The mean of nothing is undefined; NaN follows the same "no usable
        # value" convention as nan_exception_handler instead of crashing.
        return math.nan
    return sum(terms) / len(terms)
def root_mean_reducer(terms: list[float]) -> float:
    """Return the square root of ``mean_reducer(terms)``.

    When the terms are squared residuals, this is a root-mean-square loss.
    """
    mean_value = mean_reducer(terms)
    return math.sqrt(mean_value)
class ExceptionHandler(Protocol):
    """Structural type for callbacks that handle a failing objective term.

    ``CombinedObjectiveFunction.evaluate_term`` calls the handler with the
    raised exception, the term's child context, and the term index. A
    returned float is used as the term value; note that this value is not
    multiplied by the term weight. Returning ``None`` skips the term.
    """

    def __call__(
        self, exception: Exception, ctx: EvaluateContext, idx: int
    ) -> float | None:
        """Handle ``exception`` raised while evaluating term ``idx``."""
        ...
def raising_exception_handler(
    exception: Exception,
    ctx: EvaluateContext,  # noqa: ARG001
    idx: int,  # noqa: ARG001
) -> float | None:
    """Re-raise ``exception`` unchanged, aborting the combined evaluation.

    This is the default ``exception_handler`` of ``CombinedObjectiveFunction``.
    ``ctx`` and ``idx`` are accepted only to satisfy ``ExceptionHandler``.
    """
    raise exception
def nan_exception_handler(
    exception: Exception,  # noqa: ARG001
    ctx: EvaluateContext,  # noqa: ARG001
    idx: int,  # noqa: ARG001
) -> float | None:
    """Replace a failed term with ``math.nan``.

    With sum or mean reducers, the NaN carries through to the reduced loss,
    so the failure stays visible without raising. All arguments are ignored.
    """
    replacement = math.nan
    return replacement
def skip_exception_handler(
    exception: Exception,  # noqa: ARG001
    ctx: EvaluateContext,  # noqa: ARG001
    idx: int,  # noqa: ARG001
) -> float | None:
    """Drop a failed term by returning ``None``.

    ``CombinedObjectiveFunction.filter_terms`` removes ``None`` terms and
    records their indices in ``ctx.meta["skipped_indices"]``.
    """
    return None
class WrappedReducer(Aggregator):
    """Adapter that lets a simple ``Reducer`` be used as an ``Aggregator``.

    The extra ``quantities`` and ``ctx`` arguments are accepted and ignored;
    only ``terms`` is forwarded to the wrapped reducer.
    """

    def __init__(self, reducer: Reducer) -> None:
        """Store the reducer that calls will be delegated to."""
        self.reducer = reducer

    def __call__(
        self,
        terms: Sequence[float],
        quantities: Sequence[dict[str, Any]],  # noqa: ARG002
        ctx: EvaluateContext,  # noqa: ARG002
    ) -> float:
        """Delegate to the wrapped reducer with ``terms`` only."""
        delegate = self.reducer
        return delegate(terms)

    def to_reducer(self) -> Reducer:
        """Return the underlying reducer."""
        return self.reducer
def _is_simple_reducer(reduction: Callable[..., float]) -> bool:
    """Decide whether ``reduction`` should be treated as a one-argument ``Reducer``.

    Rules, in order:

    * A callable whose signature cannot be introspected (e.g. the builtin
      ``max``) is assumed to be a reducer.
    * A callable with exactly one parameter is a reducer.
    * Otherwise it is an ``Aggregator`` only if its signature accepts the
      three positional arguments ``(terms, quantities, ctx)``. This keeps
      ``sum`` (``iterable, /, start=0``) or ``def r(terms, scale=2.0)`` from
      being called with three arguments.
    """
    try:
        signature = inspect.signature(reduction)
    except (TypeError, ValueError):
        return True
    if len(signature.parameters) == 1:
        return True
    try:
        signature.bind(None, None, None)
    except TypeError:
        return True
    return False


class CombinedObjectiveFunction(ObjectiveFunctor):
    """Weighted combination of several objective terms, reduced to one scalar."""

    def __init__(
        self,
        objective_functions: Sequence[Callable[[dict[str, Any]], float]],
        weights: Sequence[float] | None = None,
        child_context_configurator: ChildContextConfigurator | None = None,
        reduction: Reducer | Aggregator = sum_reducer,
        exception_handler: ExceptionHandler = raising_exception_handler,
    ) -> None:
        """
        Initialize a combined objective from multiple weighted terms.

        Each objective term is evaluated independently in its own child context.
        The resulting term values are multiplied by their corresponding weights,
        optionally filtered through ``exception_handler`` if evaluation fails,
        and then combined using ``reduction``.

        Generic callables are automatically wrapped as ``ObjectiveFunctor``
        instances.

        Args:
            objective_functions: Sequence of objective functors or compatible
                callables.
            weights: Optional non-negative weight for each objective term. If
                ``None``, all weights default to ``1.0``.
            child_context_configurator: Optional callable used to configure each
                spawned child context before term evaluation.
            reduction: Callable used to reduce the list of weighted term values
                to a single scalar loss. Can be either a simple reducer, or the
                more advanced Aggregator, which can make use of the full context
                and the quantities. A callable is treated as a simple reducer
                if it takes exactly one parameter, cannot accept the three
                positional arguments ``(terms, quantities, ctx)``, or has no
                introspectable signature. Otherwise it is called as an
                Aggregator.
            exception_handler: Callable used to handle exceptions raised during
                term evaluation. It may return a replacement value or ``None``
                to skip the term entirely.

        Raises:
            AssertionError: If the number of weights does not match the number of
                objective functions, or if any weight is negative.
        """
        # Convert to list internally for mutability
        self.objective_functions: list[ObjectiveFunctor] = transform_generic_callables(
            objective_functions
        )
        self.child_context_configurator = child_context_configurator

        # Normalize both reduction flavours to the Aggregator call signature.
        aggregator: Aggregator = (
            WrappedReducer(cast("Reducer", reduction))
            if _is_simple_reducer(reduction)
            else cast("Aggregator", reduction)
        )
        self.reduction: Aggregator = aggregator

        self.exception_handler = exception_handler

        if weights is None:
            # Default each weight to 1.0
            self.weights: list[float] = [1.0] * len(self.objective_functions)
        else:
            self.weights = list(weights)

        # Ensure alignment between objective functions and weights
        assert len(self.weights) == len(self.objective_functions), (
            "Number of weights must match number of objective functions."
        )

        # Ensure all weights are non-negative
        assert all(w >= 0 for w in self.weights), "All weights must be non-negative."

    def n_terms(self) -> int:
        """Return the number of objective terms."""
        return len(self.weights)

    def add(
        self,
        obj_funcs: (
            Sequence[Callable[[dict[str, Any]], float]]
            | Callable[[dict[str, Any]], float]
        ),
        weights: Sequence[float] | float = 1.0,
    ) -> Self:
        """
        Add one or more objective terms to the combined objective.

        Each added callable is converted to an ``ObjectiveFunctor`` if needed
        and appended to the existing term list. The corresponding weights are
        appended in the same order.

        All validation happens before any state is modified. A failed check
        therefore leaves ``objective_functions`` and ``weights`` unchanged and
        still aligned.

        Args:
            obj_funcs: A single objective callable or a sequence of objective
                callables to add.
            weights: Either a single non-negative weight applied to every added
                callable, or a sequence of non-negative weights whose length
                matches the number of added callables.

        Returns:
            The current instance.

        Raises:
            AssertionError: If a sequence of weights is provided with a length
                that does not match the number of added callables, or if any
                provided weight is negative.
        """
        # Determine how many new functions are being added
        if isinstance(obj_funcs, Sequence) and not callable(obj_funcs):
            raw_funcs = list(obj_funcs)  # type: ignore[assignment]
        else:
            raw_funcs = [obj_funcs]  # type: ignore[assignment]
        funcs_to_add = transform_generic_callables(raw_funcs)

        # Resolve and validate the weights BEFORE mutating self, so a failed
        # assertion cannot leave objective_functions longer than weights.
        if isinstance(weights, Sequence) and not isinstance(weights, (str, bytes)):
            weights_to_add = list(weights)  # type: ignore[assignment]
            # Must match number of new functions
            assert len(weights_to_add) == len(funcs_to_add), (
                "Length of weights sequence must equal number of functions added."
            )
        else:
            # Single weight repeated for each new function
            weights_to_add = [float(weights) for _ in funcs_to_add]

        # Ensure all new weights are non-negative
        assert all(w >= 0 for w in weights_to_add), "All weights must be non-negative."

        # Only now commit both lists together.
        self.objective_functions.extend(funcs_to_add)
        self.weights.extend(weights_to_add)

        # Final sanity check that lists remain aligned
        assert len(self.weights) == len(self.objective_functions), (
            "After adding, weights and objective_functions must remain the same length."
        )
        return self

    @classmethod
    def add_flat(
        cls,
        combined_objective_functions_list: Sequence[Self],
        weights: Sequence[float] | None = None,
    ) -> Self:
        """
        Flatten multiple combined objectives into a single instance.

        The objective functions from all input instances are concatenated into
        one flat list. The weights of each input instance are scaled by the
        corresponding entry in ``weights`` before concatenation.

        Warning:
            This method **does not preserve execution policy** from the input
            combined objectives. The resulting instance uses the default
            ``reduction``, ``exception_handler``, and
            ``child_context_configurator`` unless they are explicitly set
            afterward.

        Args:
            combined_objective_functions_list: Combined objective instances to
                flatten.
            weights: Optional non-negative scaling weights, one per input
                combined objective. If ``None``, all scaling weights default to
                ``1.0``.

        Returns:
            A new combined objective containing all flattened terms and scaled
            weights.

        Raises:
            AssertionError: If the number of scaling weights does not match the
                number of combined objectives, or if any scaling weight is
                negative.
        """
        if weights is None:
            weights = [1.0 for _ in combined_objective_functions_list]

        # Ensure we have one scaling weight per sub-instance
        assert len(combined_objective_functions_list) == len(weights), (
            "Must supply exactly one weight per CombinedObjectiveFunction."
        )

        # Ensure all scaling weights are non-negative
        assert all(w >= 0 for w in weights), "All scaling weights must be non-negative."

        total_objective_functions: list[Callable[[dict[str, Any]], float]] = []
        total_weights: list[float] = []

        for sub_cob, scale in zip(combined_objective_functions_list, weights):
            total_objective_functions.extend(sub_cob.objective_functions)
            # Scale each internal weight
            total_weights.extend([w * scale for w in sub_cob.weights])

        # Ensure no negative weights after scaling
        assert all(w >= 0 for w in total_weights), (
            "Resulting weights must be non-negative."
        )

        return cls(total_objective_functions, total_weights)

    def evaluate_term(
        self,
        parameters: dict[str, Any],
        idx: int,
        ctx: EvaluateContext,
    ) -> float | None:
        """
        Evaluate a single weighted objective term.

        The selected objective function is evaluated with the provided
        parameters and child context, then multiplied by its corresponding
        weight. If evaluation raises an exception, the configured
        ``exception_handler`` is called. Its return value is used as-is and is
        not multiplied by the weight.

        Args:
            parameters: Parameter dictionary for the current evaluation.
            idx: Absolute index of the objective term to evaluate.
            ctx: Child evaluation context for this term.

        Returns:
            The weighted term value, or ``None`` if the exception handler
            chooses to skip the term.
        """
        try:
            return self.objective_functions[idx](parameters, ctx) * self.weights[idx]
        except Exception as e:
            return self.exception_handler(e, ctx, idx)

    def filter_terms(
        self, terms: list[float | None], ctx: EvaluateContext
    ) -> list[float]:
        """
        Filter out terms that are ``None``, recording their indices.

        Side effects:
            - Writes the list of skipped indices to ``ctx.meta['skipped_indices']``.
        """
        skipped_indices = []
        filtered_terms = []
        for i, t in enumerate(terms):
            if t is None:
                skipped_indices.append(i)
            else:
                filtered_terms.append(t)
        ctx.meta["skipped_indices"] = skipped_indices
        return filtered_terms

    def apply_reduction(self, terms: Sequence[float], ctx: EvaluateContext) -> float:
        """
        Reduce the already-filtered ``terms`` with ``self.reduction``.

        The ``"quantities"`` entry of every child in ``ctx.meta["children"]``
        whose index is not in ``ctx.meta["skipped_indices"]`` is collected in
        order and passed alongside ``terms``. This keeps the quantities aligned
        with the filtered terms. ``ctx.meta["children"]`` is presumably filled
        by ``ctx.child_contexts`` on exit; confirm against ``EvaluateContext``.
        """
        # Set lookup keeps this linear in the number of children.
        skipped = set(ctx.meta["skipped_indices"])
        child_quantities = [
            child["quantities"]
            for idx, child in enumerate(ctx.meta["children"])
            if idx not in skipped
        ]
        return self.reduction(terms, child_quantities, ctx)

    def evaluate_terms(
        self, parameters: dict[str, Any], ctx: EvaluateContext
    ) -> list[float]:
        """
        Evaluate the objective terms.

        This method prepares child contexts, evaluates each term in its own
        context, and drops any terms for which ``evaluate_term()`` returns
        ``None``.

        Args:
            parameters: Parameter dictionary for the current evaluation.
            ctx: Parent evaluation context.

        Returns:
            List of weighted term values that were successfully evaluated and
            not skipped by the exception handler.
        """
        ctx.meta.update({"n_terms": self.n_terms()})

        with ctx.child_contexts(
            n_children=self.n_terms(), configurator=self.child_context_configurator
        ) as child_ctxs:
            terms = [
                self.evaluate_term(parameters, idx, ctx_term)
                for idx, ctx_term in enumerate(child_ctxs)
            ]

        # Filter after the child-context block has exited.
        return self.filter_terms(terms, ctx)

    def __call__(
        self,
        parameters: dict[str, Any],
        ctx: EvaluateContext | None = None,
    ) -> float:
        """
        Evaluate the combined objective.

        Each objective term is evaluated in its own child context using the same
        parameter dictionary. The weighted term values are combined using
        ``self.reduction``. After evaluation, child metadata is collected into
        the parent context and the reduced loss is stored in ``ctx.loss``.

        Args:
            parameters: Parameter dictionary for the evaluation.
            ctx: Optional parent evaluation context. If ``None``, a new
                ``EvaluateContext`` is created.

        Returns:
            The reduced scalar loss computed from the evaluated terms.

        Side Effects:
            - Populates ``ctx.parameters``.
            - Spawns child contexts in ``ctx``.
            - Collects child metadata into ``ctx.meta["children"]``.
            - Stores the final reduced loss in ``ctx.loss``.
        """
        if ctx is None:
            ctx = EvaluateContext()

        ctx.parameters = parameters

        filtered_terms = self.evaluate_terms(parameters=parameters, ctx=ctx)
        ctx.loss = self.apply_reduction(filtered_terms, ctx)
        return ctx.loss