Source code for chemfit.fitter

from __future__ import annotations

import logging
import math
import time
from concurrent.futures import ThreadPoolExecutor
from numbers import Real
from typing import TYPE_CHECKING, Any, Callable, cast

import nevergrad as ng
import numpy as np
import numpy.typing as npt
from pydictnest import (
    flatten_dict,
    unflatten_dict,
)
from scipy.optimize import OptimizeResult, minimize

from chemfit.abstract_objective_function import (
    EvaluateContext,
    ExecutorLike,
    ObjectiveFunctor,
)
from chemfit.executor_utils import map_with_context
from chemfit.utils import check_params_near_bounds
from chemfit.wrap_funcs import WrappedObjectiveFunctor

if TYPE_CHECKING:
    from collections.abc import Iterable

# Module-level logger shared by the classes below: fit start/end, near-bound
# warnings, non-convergence and objective failures are reported through it.
logger = logging.getLogger(__name__)


class FitterEvaluateContext(EvaluateContext):
    """
    Evaluation context that additionally records the progress of a fit.

    On top of whatever ``EvaluateContext`` stores, instances count objective
    evaluations and keep the best (lowest) loss seen so far together with the
    parameters and metadata that produced it.
    """

    # Attributes introduced by this subclass, in the order in which they are
    # written to and restored from the pickled state.
    _FITTER_STATE_KEYS: tuple[str, ...] = (
        "n_evals",
        "opt_loss",
        "opt_params",
        "opt_meta",
    )

    def __init__(self):
        """
        Create a context with an empty optimization record.

        ``n_evals`` starts at zero; ``opt_loss``, ``opt_params`` and
        ``opt_meta`` stay ``None`` until a result is recorded (e.g. by
        ``FitterObjectiveFunctor``).
        """
        super().__init__()
        self.n_evals: int = 0
        self.opt_loss: float | None = None
        self.opt_params: dict[str, Any] | None = None
        self.opt_meta: dict[str, Any] | None = None

    def __getstate__(self) -> dict[str, Any]:
        """Return the base-class state extended with the fit-tracking fields."""
        state = super().__getstate__()
        for key in self._FITTER_STATE_KEYS:
            state[key] = getattr(self, key)
        return state

    def __setstate__(self, state: dict[str, Any]):
        """Restore the base-class state first, then the fit-tracking fields."""
        super().__setstate__(state)
        for key in self._FITTER_STATE_KEYS:
            setattr(self, key, state[key])
class FitterObjectiveFunctor(ObjectiveFunctor):
    """
    Objective wrapper that adds robustness and bookkeeping for optimizers.

    On every evaluation the wrapped objective is called and its result is
    sanitized (non-``Real`` or NaN values are replaced by ``value_bad_params``)
    before being recorded in the attached ``FitterEvaluateContext``.
    """

    def __init__(
        self,
        wrap_me: ObjectiveFunctor,
        swallow_exceptions: bool = False,
        log_exceptions: bool = True,
        value_bad_params: float = 1e5,
    ):
        """
        Initialize a fitter-specific objective wrapper.

        This wrapper sits between a raw objective and an optimizer. It adds
        basic robustness and tracking behavior on top of the wrapped objective:

        - exceptions may be logged and optionally swallowed
        - non-scalar or NaN return values are replaced by a large penalty
        - the attached ``FitterEvaluateContext`` is updated with the number of
          evaluations and the best loss/parameters seen so far

        Args:
            wrap_me: Underlying objective functor to evaluate.
            swallow_exceptions: If ``True``, exceptions raised by the wrapped
                objective are converted into the penalty value instead of
                being re-raised.
            log_exceptions: If ``True``, exceptions raised by the wrapped
                objective are logged.
            value_bad_params (float, optional): Penalty loss returned in place
                of an invalid objective value (a non-``Real`` result, NaN, or a
                swallowed exception). Defaults to 1e5.
        """
        self.wrap_me = wrap_me
        self.value_bad_params = value_bad_params
        self.swallow_exceptions: bool = swallow_exceptions
        self.log_exceptions: bool = log_exceptions

    def post_process_return_value(
        self,
        parameters: dict[str, Any],
        value: Any,
        ctx: FitterEvaluateContext,
    ) -> float:
        """
        Sanitize a raw objective value and record it in ``ctx``.

        Args:
            parameters: Parameters that produced ``value``.
            value: Raw value returned by the objective.
            ctx: Context whose evaluation counter and best-so-far record are
                updated.

        Returns:
            The loss as a ``float``. A non-``Real`` or NaN ``value`` is
            replaced by ``value_bad_params``.
        """
        # Every call counts as an evaluation, including penalized ones.
        ctx.n_evals += 1

        if not isinstance(value, Real):
            logger.debug(
                "Objective function did not return a single float, but returned `%s` with type %s. Clipping loss to %s",
                value,
                type(value),
                self.value_bad_params,
            )
            value = float(self.value_bad_params)

        if math.isnan(value):
            logger.debug(
                "Objective function returned NaN. Clipping loss to %s",
                self.value_bad_params,
            )
            value = self.value_bad_params

        loss = float(value)

        # Keep shallow snapshots so later evaluations cannot overwrite the
        # record of the best point.
        if ctx.opt_loss is None or loss < ctx.opt_loss:
            ctx.opt_loss = loss
            ctx.opt_params = dict(parameters)
            ctx.opt_meta = dict(ctx.meta)

        return loss

    def __call__(  # type: ignore
        self, parameters: dict[str, Any], ctx: FitterEvaluateContext | None = None
    ) -> float:
        """
        Evaluate the wrapped objective and return the sanitized loss.

        Args:
            parameters: Parameter dictionary passed to the wrapped objective.
            ctx: Evaluation context; a fresh ``FitterEvaluateContext`` is used
                if ``None``.

        Returns:
            The post-processed loss.

        Raises:
            Exception: Whatever the wrapped objective raised, unless
                ``swallow_exceptions`` is ``True``.
        """
        if ctx is None:
            ctx = FitterEvaluateContext()

        try:
            value = self.wrap_me(parameters, ctx)
        except Exception:
            if self.log_exceptions:
                logger.exception(
                    "Caught exception while evaluating objective function."
                )
            if not self.swallow_exceptions:
                # Bare raise re-raises the active exception with its
                # original traceback untouched.
                raise
            # A swallowed failure is reported as NaN so that
            # post_process_return_value substitutes the penalty value.
            value = float("nan")

        return self.post_process_return_value(
            parameters=parameters, value=value, ctx=ctx
        )
class Fitter:
    """
    Driver class for parameter optimization.

    Wraps an objective in a ``FitterObjectiveFunctor`` and offers
    ``fit_nevergrad`` and ``fit_scipy`` to minimize it over a (possibly
    nested) parameter dictionary.
    """

    def __init__(
        self,
        objective_function: Callable[[dict[str, Any]], float] | ObjectiveFunctor,
        initial_params: dict[str, Any],
        bounds: dict[str, Any] | None = None,
        near_bound_tol: float | None = None,
        value_bad_params: float = 1e5,
        swallow_exceptions: bool = False,
        log_exceptions: bool = True,
    ) -> None:
        """
        Driver class for parameter optimization.

        A `Fitter` wraps an objective (either a plain callable or an
        `ObjectiveFunctor`) in a `FitterObjectiveFunctor` and exposes
        convenience methods for running optimizations with nevergrad and SciPy.

        Args:
            objective_function (Callable | ObjectiveFunctor): Objective to be
                minimized. A plain callable is wrapped in a
                `WrappedObjectiveFunctor` with ``pass_ctx=False``, i.e. it is
                called with the parameter dictionary only.
            initial_params (dict[str, Any]): Initial parameter values. May be
                nested.
            bounds (dict[str, Any] | None, optional): ``(lower, upper)``
                bounds for each parameter. The structure must mirror
                ``initial_params``, but may omit bounds for parameters.
                Defaults to None (no bounds).
            near_bound_tol (float | None, optional): If provided, parameters
                whose optimized values lie within this relative distance of
                their bounds will trigger a warning after each fit. Defaults
                to None.
            value_bad_params (float, optional): Penalty loss used in place of
                invalid objective values (non-real results, NaN, or swallowed
                exceptions). Defaults to 1e5.
            swallow_exceptions (bool, optional): If ``True``, exceptions
                raised by the objective are converted into the penalty loss
                instead of aborting the fit. Defaults to False.
            log_exceptions (bool, optional): If ``True``, exceptions raised by
                the objective are logged. Defaults to True.
        """
        # Make sure that we have an ObjectiveFunctor instance
        if not isinstance(objective_function, ObjectiveFunctor):
            objective_function = WrappedObjectiveFunctor(
                func=objective_function, pass_ctx=False
            )

        self.objective_function = FitterObjectiveFunctor(
            objective_function,
            swallow_exceptions=swallow_exceptions,
            log_exceptions=log_exceptions,
            value_bad_params=value_bad_params,
        )

        self.initial_parameters = initial_params
        self.bounds = {} if bounds is None else bounds
        self.value_bad_params: float = value_bad_params
        self.near_bound_tol = near_bound_tol

        self.contexts: list[FitterEvaluateContext] = []
        self.callbacks: list[
            tuple[Callable[[int, list[FitterEvaluateContext]], None], int]
        ] = []

    def register_callback(
        self, func: Callable[[int, list[FitterEvaluateContext]], None], n_steps: int
    ):
        """
        Register a callback to be executed during optimization.

        The callback is invoked every ``n_steps`` iterations (or
        nevergrad/SciPy "steps", depending on the backend), and receives the
        current step index and the list of `FitterEvaluateContext` instances
        used by the fitter.

        Args:
            func (Callable[[int, list[FitterEvaluateContext]], None]):
                Callback function of the form ``func(step, contexts)``.
            n_steps (int): Interval (in steps) at which the callback is
                invoked. Must be at least 1.

        Raises:
            ValueError: If ``n_steps`` is smaller than 1. Rejecting it here
                avoids a modulo-by-zero in the middle of a fit.
        """
        if n_steps < 1:
            raise ValueError(f"`n_steps` must be a positive integer, got {n_steps}")
        self.callbacks.append((func, n_steps))

    def _unify_callbacks(
        self,
    ) -> (
        tuple[Callable[[int, list[FitterEvaluateContext]], None], int]
        | tuple[None, int]
    ):
        """
        Merge all registered callbacks into a single callback.

        Returns:
            ``(callback, n_steps)``. The caller must invoke
            ``callback(step, ctxs)`` on every step divisible by ``n_steps``,
            and the callback dispatches to each registered callback whose own
            interval divides ``step``. Returns ``(None, 0)`` if no callbacks
            are registered.
        """
        if not self.callbacks:
            return None, 0

        # The merged callback has to run on every step at which *any*
        # registered callback is due. The minimum interval is not enough
        # (intervals 2 and 3 would never reach steps 3, 9, ...). The gcd
        # divides every interval, so no due step is skipped.
        common_n_steps = math.gcd(*(n_steps for _, n_steps in self.callbacks))

        def callback(step: int, ctxs: list[FitterEvaluateContext]):
            for cb, n_steps in self.callbacks:
                if step % n_steps == 0:
                    cb(step, ctxs)

        return callback, common_n_steps

    def _hook_pre_fit(self):
        """Run bookkeeping steps before starting an optimization."""
        logger.info("Start fitting")
        self.time_fit_start = time.time()

    def _hook_post_fit(self, opt_params: dict[str, Any]):
        """
        Run bookkeeping steps after an optimization.

        This method records the fit end time, logs completion, and optionally
        warns if any optimized parameters lie near or outside their bounds.

        Args:
            opt_params: Optimized parameter dictionary returned by the optimizer.
        """
        self.time_fit_end = time.time()
        logger.info("End fitting")

        if self.near_bound_tol is not None:
            self.problematic_params = check_params_near_bounds(
                opt_params, self.bounds, self.near_bound_tol
            )

            if self.problematic_params:
                logger.warning(
                    "The following parameters are near or outside the bounds (tolerance %.1f%%):",
                    self.near_bound_tol * 100,
                )
                for kp, vp, lower, upper in self.problematic_params:
                    logger.warning(
                        " parameter = %s, lower = %s, value = %s, upper = %s",
                        kp,
                        lower,
                        vp,
                        upper,
                    )

    @staticmethod
    def _is_within_bounds(
        flat_params: dict[str, Any], flat_bounds: dict[str, Any]
    ) -> bool:
        """
        Return whether every bounded entry of ``flat_params`` lies inside its bounds.

        Parameters absent from ``flat_params`` and ``None`` bound entries
        (one-sided or unbounded parameters) impose no constraint.
        """
        for key, (lower, upper) in flat_bounds.items():
            value = flat_params.get(key, None)
            if value is None:
                continue
            if lower is not None and value < lower:
                return False
            if upper is not None and value > upper:
                return False
        return True

    def _replay_initial_observations(
        self,
        optimizer: Any,
        initial_observations: Iterable[tuple[dict[str, Any], float | None]],
        flat_bounds: dict[str, Any],
    ) -> None:
        """
        Seed a nevergrad optimizer with previously evaluated points.

        Each in-bounds ``(parameters, loss)`` pair is suggested, asked back and
        told to the optimizer. Out-of-bounds pairs are skipped. The loss passes
        through ``post_process_return_value``, so it is sanitized like a fresh
        evaluation and recorded on ``self.contexts[0]``.
        """
        for restart_params, restart_loss_value in initial_observations:
            flat_params = flatten_dict(restart_params)
            if not self._is_within_bounds(flat_params, flat_bounds):
                continue

            optimizer.suggest(flat_params)
            asked_params = optimizer.ask()

            # The recorded loss value may be changed by our wrapper.
            # Also we record the side effects on the context this way.
            post_processed_loss_value = (
                self.objective_function.post_process_return_value(
                    parameters=restart_params,
                    value=restart_loss_value,
                    ctx=self.contexts[0],
                )
            )
            optimizer.tell(asked_params, post_processed_loss_value)

    def fit_nevergrad(
        self,
        budget: int,
        optimizer_str: str = "NgIohTuned",
        num_workers: int = 1,
        contexts: list[FitterEvaluateContext] | None = None,
        executor: ExecutorLike | None = None,
        initial_observations: Iterable[tuple[dict[str, Any], float | None]]
        | None = None,
    ) -> dict[str, Any]:
        """
        Optimize parameters using a nevergrad optimizer.

        This method drives nevergrad's ask/tell interface and can evaluate
        multiple candidate points in parallel through an ``ExecutorLike``
        instance. One ``FitterEvaluateContext`` is used per worker so that
        evaluation-side state can be tracked independently.

        Args:
            budget: Total number of objective evaluations to allow. The
                optimization runs ``budget // num_workers`` ask/tell steps of
                ``num_workers`` evaluations each. A remainder of
                ``budget % num_workers`` evaluations is not used.
            optimizer_str: Name of the nevergrad optimizer to use. Must be a key
                in ``ng.optimizers.registry``.
            num_workers: Number of points to evaluate in parallel per ask/tell
                step.
            contexts: Optional list of per-worker fitter contexts. If provided,
                its length must equal ``num_workers``.
            executor: Optional executor used for parallel evaluation when
                ``num_workers > 1``. If ``None``, a ``ThreadPoolExecutor`` is
                created for the duration of this call and shut down before it
                returns (also on error). A caller-supplied executor is never
                shut down.
            initial_observations: Optional iterable of previously evaluated
                ``(parameters, loss)`` pairs used to seed the optimizer.
                These observations are replayed into the optimizer before the
                main optimization loop begins. This allows approximate
                continuation of a previous run or warm-starting a new
                optimization. If any bounded parameter of a set lies outside
                its bounds, that set is skipped; ``None`` bound entries are
                treated as unbounded. These observations do not consume
                evaluations from the main budget and do not trigger
                callbacks, but they are counted in ``n_evals`` of the first
                context. This does not restore the exact internal state of
                the optimizer. Only the provided observations are injected.

        Returns:
            Dictionary of optimized parameter values.

        Raises:
            KeyError: If ``optimizer_str`` is not found in the nevergrad
                optimizer registry.
            AssertionError: If ``contexts`` is provided and its length does not
                equal ``num_workers``.

        Side Effects:
            - Initializes fitter bookkeeping via ``_hook_pre_fit()``.
            - Populates ``self.contexts`` with one context per worker.
            - Invokes registered callbacks during optimization.
            - Runs post-fit checks via ``_hook_post_fit()``.
        """
        # An executor we create is ours to shut down; otherwise its worker
        # threads would outlive the fit.
        owned_executor: ThreadPoolExecutor | None = None
        if num_workers != 1 and executor is None:
            owned_executor = ThreadPoolExecutor(num_workers)
            executor = owned_executor

        try:
            return self._run_nevergrad(
                budget=budget,
                optimizer_str=optimizer_str,
                num_workers=num_workers,
                contexts=contexts,
                executor=executor,
                initial_observations=initial_observations,
            )
        finally:
            if owned_executor is not None:
                owned_executor.shutdown(wait=True, cancel_futures=True)

    def _run_nevergrad(
        self,
        budget: int,
        optimizer_str: str,
        num_workers: int,
        contexts: list[FitterEvaluateContext] | None,
        executor: ExecutorLike | None,
        initial_observations: Iterable[tuple[dict[str, Any], float | None]]
        | None,
    ) -> dict[str, Any]:
        """Run the nevergrad ask/tell loop (``fit_nevergrad`` owns the executor lifetime)."""
        self._hook_pre_fit()

        flat_bounds = flatten_dict(self.bounds)
        flat_initial_params = flatten_dict(self.initial_parameters)

        # One scalar per flattened parameter. Parameters without bounds get
        # lower=None and upper=None.
        ng_params = ng.p.Dict()
        for key, value in flat_initial_params.items():
            lower, upper = flat_bounds.get(key, (None, None))
            ng_params[key] = ng.p.Scalar(init=value, lower=lower, upper=upper)
        instru = ng.p.Instrumentation(ng_params)

        try:
            OptimizerCls = ng.optimizers.registry[optimizer_str]
        except KeyError as e:
            e.add_note(f"Available solvers: {list(ng.optimizers.registry.keys())}")
            raise

        optimizer = OptimizerCls(
            parametrization=instru, budget=budget, num_workers=num_workers
        )

        def f_ng(parameters: dict[str, Any], ctx: FitterEvaluateContext) -> float:
            params = unflatten_dict(parameters, dict_factory=dict)
            return self.objective_function(params, ctx)

        callback, n_steps = self._unify_callbacks()

        # We need one context per worker
        if contexts is None:
            self.contexts = [FitterEvaluateContext() for _ in range(num_workers)]
        else:
            assert len(contexts) == num_workers
            self.contexts = contexts

        # This applies restart parameters, **if** they are within the bounds
        if initial_observations is not None:
            self._replay_initial_observations(
                optimizer, initial_observations, flat_bounds
            )

        for step in range(budget // num_workers):
            # On the first evaluation we ensure that the optimizer suggests
            # the initial params
            if step == 0:
                optimizer.suggest(flat_initial_params)

            # Ask for num_workers parameters to evaluate in parallel
            asked_params = [optimizer.ask() for _ in range(num_workers)]
            # Instrumentation values are ``(args, kwargs)``; our Dict is the
            # first positional argument.
            flat_params = [p.value[0][0] for p in asked_params]

            if num_workers == 1:
                losses = [f_ng(flat_params[0], self.contexts[0])]
            else:
                assert executor is not None
                losses = map_with_context(
                    executor, f_ng, flat_params, ctxs=self.contexts
                )

            for params, loss in zip(asked_params, losses, strict=True):
                optimizer.tell(params, loss)

            if callback is not None and step % n_steps == 0:
                callback(step, self.contexts)

        recommendation = optimizer.provide_recommendation()
        args, _ = recommendation.value

        # Our optimal params are the first positional argument
        flat_opt_params = args[0]
        opt_params = unflatten_dict(flat_opt_params, dict_factory=dict[str, Any])

        self._hook_post_fit(opt_params)
        return opt_params

    def fit_scipy(
        self,
        method: str = "L-BFGS-B",
        ctx: FitterEvaluateContext | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Optimize parameters using ``scipy.optimize.minimize``.

        The parameter dictionary is flattened into a vector representation for
        SciPy and reconstructed on each objective evaluation. Because SciPy's
        ``minimize`` interface is synchronous, a single
        ``FitterEvaluateContext`` is used for the full optimization run.

        Args:
            method: Optimization method passed to ``scipy.optimize.minimize``.
            ctx: Optional fitter evaluation context to reuse during the fit.
                If ``None``, a new one is created.
            **kwargs: Additional keyword arguments forwarded to
                ``scipy.optimize.minimize``.

        Returns:
            Dictionary of optimized parameter values.

        Warning:
            If the optimizer does not converge, a warning is logged.

        Side Effects:
            - Initializes fitter bookkeeping via ``_hook_pre_fit()``.
            - Populates ``self.contexts`` with a single context.
            - Invokes registered callbacks during optimization.
            - Runs post-fit checks via ``_hook_post_fit()``.
        """
        self._hook_pre_fit()

        # SciPy expects f(x) over a real vector, while our objective takes a
        # (possibly nested) parameter dictionary. Flatten both the initial
        # parameters and the bounds first.
        flat_params = flatten_dict(self.initial_parameters)
        flat_bounds = flatten_dict(self.bounds)

        # Freeze the key order once: it defines which entry of ``x`` belongs
        # to which parameter for the entire run.
        keys = list(flat_params.keys())
        self._keys = keys

        x0 = np.array([flat_params[k] for k in keys])

        # SciPy's documented bounds format is a sequence of (min, max) pairs,
        # with None meaning "unbounded" on that side.
        if not flat_bounds:
            bounds = None
        else:
            bounds = [flat_bounds.get(k, (None, None)) for k in keys]

        # Since scipy.optimize works synchronously, a single context is kept
        # alive for the whole run.
        self.contexts = [FitterEvaluateContext() if ctx is None else ctx]

        def f_scipy(x: npt.NDArray) -> float:
            # Rebuild the flat dict from ``x`` and the frozen keys, then
            # unflatten it for the objective.
            p = unflatten_dict(dict(zip(keys, x)), dict_factory=dict[str, Any])
            return self.objective_function(p, ctx=self.contexts[0])

        # First concatenate the list of callbacks into a single function
        callback, n_steps = self._unify_callbacks()

        # The parameter must be named ``intermediate_result``: SciPy inspects
        # the name to decide whether to pass an OptimizeResult.
        def callback_scipy(intermediate_result: OptimizeResult):
            # Prefer the iteration counter; methods that do not report one
            # fall back to the number of objective evaluations.
            if "nit" in intermediate_result:
                step = intermediate_result.nit
            else:
                step = self.contexts[0].n_evals
            if callback is not None and step % n_steps == 0:
                callback(step, self.contexts)

        res = minimize(
            f_scipy, x0, method=method, bounds=bounds, **kwargs, callback=callback_scipy
        )

        if not res.success:
            logger.warning("Fit did not converge: %s", res.message)

        opt_params = unflatten_dict(dict(zip(keys, res.x)))

        self._hook_post_fit(opt_params)
        return opt_params