from __future__ import annotations
import logging
import math
import time
from dataclasses import dataclass
from functools import wraps
from numbers import Real
from typing import Any, Callable
import nevergrad as ng
import numpy as np
import numpy.typing as npt
from pydictnest import (
flatten_dict,
unflatten_dict,
)
from scipy.optimize import OptimizeResult, minimize
from chemfit.utils import check_params_near_bounds
# Module-level logger, named after this module so applications can tune its verbosity.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass()
class FitInfo:
    """
    Bookkeeping record for a single optimization run.

    All timing/value fields start out as ``None`` and are filled in by the
    fitter while (and after) it runs; ``n_evals`` starts at zero.
    """

    # Objective value at the initial parameters (None until a fit starts).
    initial_value: float | None = None
    # Objective value at the returned optimum (None until a fit finishes).
    final_value: float | None = None
    # Duration of the fit in seconds (None until a fit finishes).
    time_taken: float | None = None
    # Number of objective-function evaluations counted by the fitter.
    n_evals: int = 0
@dataclass()
class CallbackInfo:
    """
    Snapshot of the optimization state handed to user callbacks.

    All fields are required; instances are created by the fitter each time
    a registered callback is due.
    """

    # Best parameters known at the time of the callback.
    opt_params: dict[str, Any]
    # Loss associated with ``opt_params``.
    opt_loss: float
    # Parameters evaluated most recently.
    cur_params: dict[str, Any]
    # Loss associated with ``cur_params``.
    cur_loss: float
    # Number of optimization steps performed so far (not the number of evaluations).
    step: int
    # The fitter's live ``FitInfo`` object.
    info: FitInfo
class Fitter:
    """
    Minimize an objective function over a (possibly nested) parameter dictionary.

    The objective function is wrapped by :meth:`ob_func_wrapper`. Evaluations
    that do not return a finite real number are replaced by
    ``value_bad_params``. Parameters are flattened before being handed to the
    backend optimizer (SciPy or Nevergrad) and unflattened again before every
    objective evaluation.
    """

    def __init__(
        self,
        objective_function: Callable[[dict[str, Any]], float],
        initial_params: dict[str, Any],
        bounds: dict[str, Any] | None = None,
        near_bound_tol: float | None = None,
        value_bad_params: float = 1e5,
    ) -> None:
        """
        Initialize a Fitter.

        Args:
            objective_function (Callable[[dict], float]):
                The objective function to be minimized.
            initial_params (dict):
                Initial values of the parameters.
            bounds (Optional[dict]):
                Dictionary specifying bounds for each parameter.
                Each bound is a ``(lower, upper)`` pair; either side may be ``None``.
            near_bound_tol (Optional[float]):
                If specified, checks whether any parameters are too close to their bounds and logs a warning if so.
            value_bad_params (float):
                Threshold value beyond which the objective function is considered to be in a poor or invalid region.
                Evaluations that fail to produce a finite real number are replaced by this value.
        """
        self.objective_function = self.ob_func_wrapper(objective_function)
        self.initial_parameters = initial_params
        self.bounds = {} if bounds is None else bounds
        self.value_bad_params: float = value_bad_params
        self.near_bound_tol = near_bound_tol
        self.info = FitInfo()
        self.callbacks: list[tuple[Callable[[CallbackInfo], None], int]] = []
        # Filled by `hook_post_fit` when `near_bound_tol` is set; always defined
        # so callers can inspect it without an AttributeError.
        self.problematic_params: list = []

    def register_callback(self, func: Callable[[CallbackInfo], None], n_steps: int):
        """
        Register a callback which is executed after every `n_steps` of the optimization.

        Multiple callbacks may be registered. They are executed in the order of registration.

        The callback must be a callable with the following signature:
            func(arg: CallbackInfo)

        The `CallbackInfo` is a dataclass with the following attributes:
            - `opt_params`: The optimal parameters at the time the callback is invoked.
            - `opt_loss`: The loss value corresponding to the optimal parameters.
            - `cur_params`: The parameters tested most recently when the callback is invoked.
            - `cur_loss`: The loss value associated with the most recently tested parameters.
            - `step`: The number of optimization steps performed so far
                (generally not equal to the number of loss function evaluations).
            - `info`: The current `FitInfo` instance of the fitter at the time the callback is invoked.

        Raises:
            ValueError: If `n_steps` is not a positive integer.
        """
        # Validate eagerly: a zero/negative stride would otherwise surface much
        # later as a ZeroDivisionError inside the optimization loop.
        if n_steps != int(n_steps) or n_steps < 1:
            raise ValueError(f"`n_steps` must be a positive integer, got {n_steps!r}")
        self.callbacks.append((func, int(n_steps)))

    def ob_func_wrapper(self, ob_func: Any) -> Callable[[dict[str, Any]], float]:
        """
        Wrap the objective function with validation, clipping and logging.

        The returned callable:

        - logs and re-raises any exception raised by ``ob_func``,
        - increments ``self.info.n_evals`` after every successful evaluation,
        - replaces non-real and non-finite (NaN, +/-inf) return values with
          ``self.value_bad_params``, so optimizers always receive a finite float.

        ``self.value_bad_params`` and ``self.info`` are read at call time, so
        later changes to them (e.g. in :meth:`hook_pre_fit`) take effect.
        """

        @wraps(ob_func)
        def wrapped_ob_func(params: dict[str, Any]) -> float:
            try:
                value = ob_func(params)
            except Exception:
                # An evaluation error is fatal: log with context, then re-raise unchanged
                logger.exception(
                    "Caught exception while evaluating objective function.",
                    stack_info=True,
                    stacklevel=2,
                )
                raise
            self.info.n_evals += 1

            if not isinstance(value, Real):
                logger.debug(
                    "Objective function did not return a single float, but returned `%s` with type %s. Clipping loss to %s",
                    value,
                    type(value),
                    self.value_bad_params,
                )
                return float(self.value_bad_params)

            if not math.isfinite(value):
                # NaN and inf break finite-difference gradients and line searches,
                # and would also poison the `value_bad_params` adjustment in `hook_pre_fit`.
                logger.debug(
                    "Objective function returned non-finite value %s. Clipping loss to %s",
                    value,
                    self.value_bad_params,
                )
                return float(self.value_bad_params)

            return float(value)

        return wrapped_ob_func

    def _produce_callback(
        self,
    ) -> tuple[Callable[[CallbackInfo], None], int] | tuple[None, int]:
        """
        Combine all registered callbacks into a single callback.

        Returns:
            A ``(callback, stride)`` pair, or ``(None, 0)`` when no callbacks are registered.
            The optimization loops invoke ``callback`` only on steps that are multiples of
            ``stride``; each registered callback then filters on its own ``n_steps``.
        """
        if not self.callbacks:
            return None, 0

        # The gcd (not the min) guarantees that every multiple of every registered
        # `n_steps` is also a multiple of the stride, so no callback is skipped.
        # (With min, n_steps = 2 and 3 would only fire the second at multiples of 6.)
        stride = math.gcd(*(n_steps for _, n_steps in self.callbacks))

        def callback(callback_args: CallbackInfo):
            for cb, n_steps in self.callbacks:
                if callback_args.step % n_steps == 0:
                    cb(callback_args)

        return callback, stride

    def hook_pre_fit(self):
        """
        A hook, which is invoked before optimizing.

        Resets ``self.info``, evaluates the objective at the initial parameters, and,
        if that value exceeds ``value_bad_params``, raises ``value_bad_params`` to
        110% of it. The initial evaluation is not counted in ``info.n_evals``, and
        the fit timer starts after it.
        """
        # Overwrite with a fresh FitInfo object
        self.info = FitInfo()

        logger.info("Start fitting")
        self.info.initial_value = self.objective_function(self.initial_parameters)
        logger.info(f"    Initial obj func: {self.info.initial_value}")

        if self.info.initial_value == self.value_bad_params:
            logger.warning(
                f"Starting optimization in a `bad` region. Objective function could not be evaluated properly. Loss has been set to {self.value_bad_params = }"
            )
        elif self.info.initial_value > self.value_bad_params:
            new_value_bad_params = 1.1 * self.info.initial_value
            logger.warning(
                f"Starting optimization in a high loss region. Loss is {self.info.initial_value}, which is greater than {self.value_bad_params = }. Adjusting to {new_value_bad_params = }."
            )
            self.value_bad_params = new_value_bad_params

        # Do not count the initial evaluation
        self.info.n_evals = 0
        self.time_fit_start = time.time()

    def hook_post_fit(self, opt_params: dict[str, Any]):
        """
        A hook, which is invoked after optimizing.

        Records the elapsed time, logs a summary and, when ``near_bound_tol`` is set,
        stores and logs the parameters that are near or outside their bounds in
        ``self.problematic_params``.
        """
        self.time_fit_end = time.time()
        self.info.time_taken = self.time_fit_end - self.time_fit_start

        assert self.info.final_value is not None
        if self.info.final_value >= self.value_bad_params:
            logger.warning(
                f"Ending optimization in a `bad` region. Loss is greater or equal to {self.value_bad_params = }"
            )

        logger.info("End fitting")
        logger.info(f"    Info {self.info}")

        if self.near_bound_tol is not None:
            self.problematic_params = check_params_near_bounds(
                opt_params, self.bounds, self.near_bound_tol
            )

            if len(self.problematic_params) > 0:
                logger.warning(
                    f"The following parameters are near or outside the bounds (tolerance {self.near_bound_tol * 100:.1f}%):"
                )
                for kp, vp, lower, upper in self.problematic_params:
                    logger.warning(
                        f"    parameter = {kp}, lower = {lower}, value = {vp}, upper = {upper}"
                    )

    def fit_nevergrad(
        self, budget: int, optimizer_str: str = "NgIohTuned", **kwargs
    ) -> dict[str, Any]:
        """
        Optimize parameters using a Nevergrad optimizer.

        Args:
            budget (int):
                Number of optimization steps. Step 0 reports the initial parameters
                (already evaluated in :meth:`hook_pre_fit`) to the optimizer; every
                further step asks for and evaluates one candidate.
            optimizer_str (str):
                Name of the optimizer in ``ng.optimizers.registry``.
            **kwargs:
                Currently unused; accepted for interface compatibility.

        Returns:
            dict: The recommended (nested) parameters.

        Raises:
            KeyError: If ``optimizer_str`` is not a registered Nevergrad optimizer.
        """
        self.hook_pre_fit()

        flat_bounds = flatten_dict(self.bounds)
        flat_initial_params = flatten_dict(self.initial_parameters)

        ng_params = ng.p.Dict()
        for k, v in flat_initial_params.items():
            # Parameters without an entry in `bounds` are unbounded on both sides
            lower, upper = flat_bounds.get(k, (None, None))
            ng_params[k] = ng.p.Scalar(init=v, lower=lower, upper=upper)

        instru = ng.p.Instrumentation(ng_params)

        try:
            optimizer_cls = ng.optimizers.registry[optimizer_str]
        except KeyError as e:
            e.add_note(f"Available solvers: {list(ng.optimizers.registry.keys())}")
            raise

        optimizer = optimizer_cls(parametrization=instru, budget=budget)

        def f_ng(p: dict[str, Any]) -> float:
            params = unflatten_dict(p, dict_factory=dict[str, Any])
            return self.objective_function(params)

        callback, n_steps = self._produce_callback()

        assert self.info.initial_value is not None
        opt_loss = self.info.initial_value

        for i in range(budget):
            if i == 0:
                # Re-use the initial evaluation instead of spending a new one
                flat_params = flat_initial_params
                cur_loss = self.info.initial_value
                p = optimizer.parametrization.spawn_child()
                p.value = (  # type: ignore
                    (flat_params,),
                    {},
                )
                optimizer.tell(p, cur_loss)
            else:
                p = optimizer.ask()
                # Instrumentation values are (args, kwargs); our dict is the first positional arg
                ng_args, _ = p.value
                flat_params = ng_args[0]
                cur_loss = f_ng(flat_params)
                optimizer.tell(p, cur_loss)

            opt_loss = min(opt_loss, cur_loss)

            if callback is not None and i % n_steps == 0:
                rec_args, _ = optimizer.provide_recommendation().value
                opt_params = unflatten_dict(rec_args[0], dict_factory=dict[str, Any])
                cur_params = unflatten_dict(flat_params, dict_factory=dict[str, Any])
                callback(
                    CallbackInfo(
                        opt_params=opt_params,
                        opt_loss=opt_loss,
                        cur_params=cur_params,
                        cur_loss=cur_loss,
                        step=i,
                        info=self.info,
                    )
                )

        recommendation = optimizer.provide_recommendation()
        rec_args, _ = recommendation.value
        # Our optimal params are the first positional argument
        flat_opt_params = rec_args[0]
        opt_params = unflatten_dict(flat_opt_params, dict_factory=dict[str, Any])

        # loss is an optional field in the recommendation so we have to test if it has been written
        if recommendation.loss is not None:
            self.info.final_value = recommendation.loss
        else:
            # Otherwise compute it; the objective function expects the *nested* params
            self.info.final_value = self.objective_function(opt_params)

        self.hook_post_fit(opt_params)
        return opt_params

    def fit_scipy(self, method: str = "L-BFGS-B", **kwargs) -> dict[str, Any]:
        """
        Optimize parameters using SciPy's minimize function.

        The fit starts from ``self.initial_parameters`` and respects ``self.bounds``.

        Parameters
        ----------
        method : str
            Name of the ``scipy.optimize.minimize`` method (default ``"L-BFGS-B"``).
        **kwargs
            Additional keyword arguments passed directly to scipy.optimize.minimize.

        Returns
        -------
        dict
            Dictionary of optimized parameter values (same nesting as the initial parameters).

        Warnings
        --------
        If the optimizer does not converge, a warning is logged.

        Example
        -------
        >>> def objective_function(params: dict) -> float:
        ...     return 2.0 * (params["x"] - 2) ** 2 + 3.0 * (params["y"] + 1) ** 2
        >>> fitter = Fitter(objective_function, initial_params=dict(x=0.0, y=0.0))
        >>> optimal_params = fitter.fit_scipy()
        >>> print({k: round(v, 2) for k, v in optimal_params.items()})
        {'x': 2.0, 'y': -1.0}
        """
        self.hook_pre_fit()

        # Scipy expects a function with n real-valued parameters f(x)
        # but our objective function takes a (possibly nested) dictionary of parameters.
        # Therefore, as a first step, we flatten the bounds and initial parameter dicts
        flat_params = flatten_dict(self.initial_parameters)
        flat_bounds = flatten_dict(self.bounds)

        # Capture the order of keys; x0, the bounds and every x scipy hands us follow it
        keys = list(flat_params.keys())
        self._keys = keys

        x0 = np.array([flat_params[k] for k in keys], dtype=float)

        if len(flat_bounds) == 0:
            bounds = None
        else:
            # A plain list of (lower, upper) pairs; `None` marks an unbounded side
            bounds = [flat_bounds.get(k, (None, None)) for k in keys]

        def f_scipy(x: npt.NDArray) -> float:
            p = unflatten_dict(dict(zip(keys, x)), dict_factory=dict[str, Any])
            return self.objective_function(p)

        # Scipy does not mandate that all optimizers report the values our callback
        # system needs, and it uses a different callback signature. This functor
        # does the step bookkeeping and the translation.
        class CallbackScipy:
            def __init__(
                self,
                keys: list[str],
                info: FitInfo,
                callback: Callable[[CallbackInfo], None],
                n_steps: int,
            ) -> None:
                self._step: int = 0
                self._keys = keys
                self._info = info
                self._callback = callback
                self._n_steps: int = n_steps
                # Best iterate reported by scipy so far
                self._opt_x: npt.NDArray | None = None
                self._opt_loss: float = math.inf

            def __call__(self, intermediate_result: OptimizeResult):
                # This callback is executed after *every* iteration.
                # Track the step ourselves, but prefer "nit" when scipy provides it.
                self._step += 1
                if "nit" in intermediate_result:
                    self._step = intermediate_result.nit

                x = intermediate_result.x
                cur_loss = intermediate_result.fun

                # Track the best iterate on every call (not only on callback steps).
                # Copy x: some methods update the same buffer in place.
                if self._opt_x is None or cur_loss < self._opt_loss:
                    self._opt_x = np.array(x, copy=True)
                    self._opt_loss = cur_loss

                if self._step % self._n_steps == 0:
                    cur_params = unflatten_dict(
                        dict(zip(self._keys, x)), dict_factory=dict[str, Any]
                    )
                    opt_params = unflatten_dict(
                        dict(zip(self._keys, self._opt_x)),
                        dict_factory=dict[str, Any],
                    )
                    self._callback(
                        CallbackInfo(
                            opt_params=opt_params,
                            opt_loss=self._opt_loss,
                            cur_params=cur_params,
                            cur_loss=cur_loss,
                            step=self._step,
                            info=self._info,
                        )
                    )

        # First concatenate the list of callbacks into a single function,
        # then wrap it in a way that scipy understands
        callback, n_steps = self._produce_callback()
        if callback is not None:
            callback_scipy = CallbackScipy(
                keys=keys,
                info=self.info,
                callback=callback,
                n_steps=n_steps,
            )
        else:
            callback_scipy = None

        res = minimize(
            f_scipy, x0, method=method, bounds=bounds, callback=callback_scipy, **kwargs
        )

        if not res.success:
            logger.warning(f"Fit did not converge: {res.message}")

        self.info.final_value = float(res.fun)

        # Return plain Python floats rather than numpy scalars
        opt_params = unflatten_dict(
            {k: float(v) for k, v in zip(keys, res.x)}, dict_factory=dict[str, Any]
        )

        self.hook_post_fit(opt_params)
        return opt_params