"""Source code for chemfit.ase_objective_function."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.io import read
from ase.optimize import BFGS

from chemfit.abstract_objective_function import EvaluateContext, QuantityComputer
from chemfit.utils import check_protocol

if TYPE_CHECKING:
    from pathlib import Path

# Logger for this module; its name is the module's import path (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)


@runtime_checkable
class CalculatorFactory(Protocol):
    """
    Structural type for callables that equip an ``Atoms`` object with a calculator.

    A conforming callable builds (or configures) an ASE calculator suitable for
    the ``Atoms`` object it receives and stores it on ``atoms.calc``. It returns
    nothing; the effect is the assignment.
    """

    def __call__(self, atoms: Atoms) -> None:
        """Build a calculator for ``atoms`` and store it on ``atoms.calc``."""
@runtime_checkable
class ParameterApplier(Protocol):
    """Structural type for callables that push a parameter set into an ASE calculator."""

    def __call__(self, atoms: Atoms, params: dict[str, Any]) -> None:
        """Update ``atoms.calc`` in place using the values in ``params``."""
@runtime_checkable
class AtomsPostProcessor(Protocol):
    """Structural type for callables that adjust an ASE ``Atoms`` object after creation."""

    def __call__(self, atoms: Atoms) -> None:
        """Mutate ``atoms`` directly; the return value is not used."""
@runtime_checkable
class AtomsFactory(Protocol):
    """Structural type for zero-argument callables that produce an ASE ``Atoms`` object."""

    def __call__(self) -> Atoms:
        """Return a newly built ``Atoms`` object."""
@runtime_checkable
class QuantityProcessor(Protocol):
    """
    Structural type for callables that pull quantities out of a finished ASE evaluation.

    After a calculator has evaluated an ``Atoms`` object, each quantity processor
    receives that (calculator, atoms) pair. It returns a dictionary whose entries
    are merged into the final output.
    """

    def __call__(self, calc: Calculator, atoms: Atoms) -> dict[str, Any]:
        """
        Derive quantities from an already-evaluated calculator/atoms pair.

        Args:
            calc: Calculator whose evaluation of ``atoms`` has completed.
            atoms: The atoms object that was evaluated.

        Returns:
            Mapping from quantity names to their values.
        """
class PathAtomsFactory(AtomsFactory):
    """Atoms factory that reads a single structure from a filesystem path."""

    def __init__(self, path: str | Path, index: int | str | None = None) -> None:
        """
        Initialize the factory.

        Args:
            path: Path to a structure file readable by ASE.
            index: Optional ASE index selecting which image to read. It is
                forwarded unchanged to ``ase.io.read``. The selection must
                resolve to a single ``Atoms`` object.
        """
        self.path = path
        self.index = index

    def __call__(self) -> Atoms:
        """
        Read the structure at ``self.path``.

        Returns:
            The ``Atoms`` object selected by ``self.index``.

        Raises:
            ValueError: If ``self.index`` makes ``ase.io.read`` return a list of
                images instead of a single ``Atoms`` object.
        """
        # parallel=False: every process reads the file itself instead of ASE's
        # default read-on-master-and-broadcast behaviour.
        atoms = read(self.path, self.index, parallel=False)
        if isinstance(atoms, list):
            msg = (
                f"Index {self.index} selects multiple images from path {self.path}. "
                "This is not compatible with AtomsFactory."
            )
            # ValueError (a subclass of Exception) keeps existing
            # `except Exception` handlers working while being more specific.
            raise ValueError(msg)
        return atoms
class DefaultQuantityProcessor:
    """Quantity processor that returns every calculator result plus the atom count."""

    def __init__(self, filter_keys: list[str] | None = None) -> None:
        """
        Initialize a default quantity processor that returns all of the `results` of the calculator.

        The returned quantity dictionary contains all entries from ``calc.results``
        plus ``"n_atoms"``. Any keys listed in ``filter_keys`` are excluded from
        the returned dictionary. Listed keys that are not present are ignored.

        Args:
            filter_keys: Optional list of keys to exclude from the returned
                quantity dictionary.
        """
        self.filter_keys = filter_keys

    def __call__(self, calc: Calculator, atoms: Atoms) -> dict[str, Any]:
        """
        Collect the calculator results for an evaluated structure.

        Args:
            calc: Calculator that has already evaluated ``atoms``.
            atoms: Evaluated atoms object; only its length is read.

        Returns:
            A new dictionary with the entries of ``calc.results`` and
            ``"n_atoms"``, minus any keys in ``filter_keys``. ``calc.results``
            itself is not modified.
        """
        res = {**calc.results, "n_atoms": len(atoms)}
        if self.filter_keys is not None:
            for key in self.filter_keys:
                # A filtered key may be absent from the results (e.g. a quantity
                # this calculator did not produce); excluding it is then a no-op
                # rather than a KeyError.
                res.pop(key, None)
        return res
class SinglePointASEComputer(QuantityComputer):
    """
    ASE-based quantity computer for single-point evaluations.

    This class evaluates quantities for a parameterized ASE calculation using an
    atoms factory, optional atoms post-processing, a calculator factory, a
    parameter applier, and one or more quantity processors.
    """

    def __init__(
        self,
        calc_factory: CalculatorFactory,
        param_applier: ParameterApplier,
        atoms_factory: AtomsFactory,
        atoms_post_processor: AtomsPostProcessor | None = None,
        quantity_processors: list[QuantityProcessor] | None = None,
        tag: str | None = None,
    ) -> None:
        """
        Initialize the computer.

        Args:
            calc_factory: Callable that attaches a calculator to an ``Atoms``
                object.
            param_applier: Callable that applies a parameter dictionary to the
                calculator attached to an ``Atoms`` object.
            atoms_factory: Callable that creates the base ``Atoms`` object.
            atoms_post_processor: Optional callable that modifies the base atoms
                object before it is cached and copied for evaluation.
            quantity_processors: Optional list of callables that extract
                quantities from the evaluated calculator and atoms pair. The list
                is copied. If ``None``, a ``DefaultQuantityProcessor`` is used.
            tag: Optional label for this computer. If ``None`` (or empty),
                ``"tag_None"`` is used.
        """
        super().__init__()

        # Make sure all the protocols are properly implemented
        check_protocol(calc_factory, CalculatorFactory)
        check_protocol(param_applier, ParameterApplier)
        check_protocol(atoms_factory, AtomsFactory)
        # The post-processor is optional: only validate it when one is supplied,
        # so the ``None`` default is never checked against the protocol.
        if atoms_post_processor is not None:
            check_protocol(atoms_post_processor, AtomsPostProcessor)

        self.calc_factory = calc_factory
        self.param_applier = param_applier
        self.atoms_factory = atoms_factory
        self.atoms_post_processor = atoms_post_processor

        if quantity_processors is None:
            self.quantity_processors: list[QuantityProcessor] = [
                DefaultQuantityProcessor()
            ]
        else:
            # Copy so later changes to the caller's list do not affect this computer.
            self.quantity_processors = list(quantity_processors)

        for qp in self.quantity_processors:
            check_protocol(qp, QuantityProcessor)

        self.tag = tag or "tag_None"

        # Base structure, built lazily on the first call to ``prepare_ctx``.
        self._atoms: Atoms | None = None

        self.static_meta_data = {
            "tag": self.tag,
            "type": type(self).__name__,
        }

    def prepare_ctx(self, parameters: dict[str, Any], ctx: EvaluateContext) -> None:
        """
        Prepare the evaluation context for a single-point calculation.

        This method lazily creates and caches a base atoms object using
        ``atoms_factory``. If provided, ``atoms_post_processor`` is applied once
        to that base object before it is cached. For each evaluation, the cached
        atoms object is copied into ``ctx.temp.atoms``. A fresh calculator is then
        attached, and the provided parameters are applied.

        Args:
            parameters: Parameter dictionary for the current evaluation.
            ctx: Evaluation context to populate.
        """
        if self._atoms is None:
            atoms = self.atoms_factory()
            if self.atoms_post_processor is not None:
                self.atoms_post_processor(atoms)
            # Cache only after post-processing succeeded. If the post-processor
            # raises, the next call rebuilds the structure instead of reusing a
            # half-prepared one.
            self._atoms = atoms

        # Since the calculation may change the internal state of the calculator
        # we create a new calculator and a new atoms object in the context
        ctx.temp.atoms = self._atoms.copy()
        self.calc_factory(ctx.temp.atoms)
        self.param_applier(ctx.temp.atoms, parameters)

    def _compute(
        self,
        parameters: dict[str, Any],
        ctx: EvaluateContext,
    ) -> dict[str, Any]:
        """
        Compute quantities from a single-point ASE evaluation.

        This implementation prepares the evaluation context and runs the
        calculator on the atoms object. It then merges the outputs of all
        configured quantity processors; later processors overwrite keys returned
        by earlier ones.

        Args:
            parameters: Mapping of parameter names to parameter values.
            ctx: Evaluation context for the current call.

        Returns:
            A dictionary containing the merged quantities returned by the
            configured quantity processors.

        Raises:
            RuntimeError: If ``calc_factory`` did not attach a calculator.
        """
        self.prepare_ctx(parameters, ctx)

        calc = ctx.temp.atoms.calc
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if calc is None:
            msg = "calc_factory did not attach a calculator to the atoms object."
            raise RuntimeError(msg)
        calc.calculate(ctx.temp.atoms)

        quants: dict[str, Any] = {}
        for qp in self.quantity_processors:
            quants.update(qp(calc, ctx.temp.atoms))
        return quants
class MinimizationASEComputer(SinglePointASEComputer):
    """
    ASE-based quantity computer using a locally optimized structure.

    This computer evaluates quantities after performing a local geometry
    optimization with the ASE BFGS optimizer. Quantities are extracted from the
    relaxed structure using the configured quantity processors.
    """

    def __init__(
        self,
        dt: float = 1e-2,
        fmax: float = 1e-5,
        max_steps: int = 2000,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a MinimizationASEComputer.

        All additional keyword arguments are forwarded to
        ``SinglePointASEComputer.__init__``.

        Args:
            dt: Relaxation step-size parameter retained for compatibility with
                earlier implementations. Currently unused.
            fmax: Force convergence criterion passed to the optimizer.
            max_steps: Maximum number of optimization steps.
            **kwargs: Additional keyword arguments forwarded to the parent
                initializer.
        """
        self.dt = dt
        self.fmax = fmax
        self.max_steps = max_steps
        super().__init__(**kwargs)

    def _compute(
        self, parameters: dict[str, Any], ctx: EvaluateContext
    ) -> dict[str, Any]:
        """
        Compute quantities after local geometry optimization.

        This method prepares the evaluation context and runs ASE's BFGS optimizer
        on ``ctx.temp.atoms``. It then extracts quantities from the resulting
        structure using the configured quantity processors. If the optimizer
        reports that ``fmax`` was not reached within ``max_steps``, a warning is
        logged and the quantities of the last (unconverged) structure are
        returned.

        Args:
            parameters: Mapping of parameter names to parameter values.
            ctx: Evaluation context for the current call.

        Returns:
            A dictionary containing the merged quantities returned by the
            configured quantity processors.

        Side Effects:
            - Creates and stores an ``Atoms`` object in ``ctx.temp.atoms``.
            - Attaches a fresh calculator to the atoms object.
            - Moves the positions of ``ctx.temp.atoms`` during the optimization.
        """
        self.prepare_ctx(parameters, ctx)

        # logfile=None keeps the optimizer from writing per-step output.
        optimizer = BFGS(ctx.temp.atoms, logfile=None)
        # ASE's Optimizer.run reports whether the force criterion was met.
        converged = optimizer.run(fmax=self.fmax, steps=self.max_steps)
        if not converged:
            logger.warning(
                "BFGS relaxation for tag %r did not reach fmax=%s within %s steps; "
                "quantities are taken from the last, unconverged structure.",
                self.tag,
                self.fmax,
                self.max_steps,
            )

        quants: dict[str, Any] = {}
        for qp in self.quantity_processors:
            quants.update(qp(ctx.temp.atoms.calc, ctx.temp.atoms))
        return quants