from __future__ import annotations
import logging
import math
from enum import Enum
from typing import Any
from mpi4py import MPI
from chemfit.abstract_objective_function import EvaluateContext, ObjectiveFunctor
from chemfit.combined_objective_function import CombinedObjectiveFunction
from chemfit.debug_utils import log_all_methods
logger = logging.getLogger(__name__)
[docs]
def slice_up_range(n: int, n_ranks: int):
    """
    Split a range of length ``n`` into contiguous rank-local chunks.

    Every rank except possibly the trailing ones receives
    ``ceil(n / n_ranks)`` items. The last non-empty chunk may be shorter,
    and ranks beyond the end of the data receive an empty ``(n, n)`` range.

    Args:
        n: Total number of items to distribute.
        n_ranks: Number of MPI ranks.

    Yields:
        Tuples ``(start, end)`` defining half-open index ranges, one per
        rank in rank order. ``0 <= start <= end <= n`` always holds for
        non-negative ``n``.

    Raises:
        ZeroDivisionError: If ``n_ranks`` is 0.
    """
    chunk_size = math.ceil(n / n_ranks)
    for rank in range(n_ranks):
        # Clamp the start as well as the end: without this, ranks past the
        # end of the data would get ranges such as (6, 5) with start > end.
        start = min(rank * chunk_size, n)
        end = min(start + chunk_size, n)
        yield (start, end)
[docs]
class Signal(Enum):
    """Control messages broadcast from rank 0 to the worker ranks."""

    # Tells a worker running ``worker_loop`` to leave its loop.
    ABORT = -1
[docs]
class MPIWrapperCOB(ObjectiveFunctor):
    """
    MPI-based wrapper for ``CombinedObjectiveFunction``.

    This wrapper distributes the terms of a combined objective across
    MPI ranks. Rank 0 acts as the driver rank: it broadcasts the
    evaluation context to all worker ranks, evaluates its own local
    slice of terms, gathers the worker results, re-raises any captured
    exceptions, collects child metadata, and applies the wrapped
    combined objective's reduction.

    Worker ranks do not call ``__call__`` directly. Instead, they run
    ``worker_loop()``, which waits for broadcast evaluation requests
    from rank 0 and processes the local slice assigned to that rank.

    The wrapper can be used as a context manager; leaving the ``with``
    block calls ``release_workers()``.
    """

    def __init__(
        self,
        cob: CombinedObjectiveFunction,
        comm: Any | None = None,
        mpi_debug_log: bool = False,
    ) -> None:
        """
        Initialize an MPI wrapper for a combined objective.

        Args:
            cob: Combined objective function whose terms are distributed
                across MPI ranks.
            comm: MPI communicator to use. If ``None``, a duplicate of
                ``MPI.COMM_WORLD`` is created.
            mpi_debug_log: If ``True``, wrap the communicator with
                ``log_all_methods`` so that calls on it are reported
                through ``_log_func``.

        Notes:
            Each rank is assigned a contiguous slice of objective terms at
            initialization time.
        """
        self.cob = cob

        if comm is None:
            self.comm = MPI.COMM_WORLD.Dup()
        else:
            self.comm = comm

        # Query rank and size before the communicator is (optionally)
        # wrapped: ``_log_func`` reads ``self.rank``.
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()

        if mpi_debug_log:
            self.comm = log_all_methods(
                self.comm,
                log_func=self._log_func,
                log_args=True,
                log_res=True,
            )

        # ``slice_up_range`` yields one (start, end) pair per rank, in rank order.
        self.start, self.end = list(slice_up_range(self.cob.n_terms(), self.size))[
            self.rank
        ]

    def _log_func(self, msg: str):
        """Log ``msg`` at warning level, prefixed with this rank's number."""
        logger.warning(f"[Rank {self.rank}] {msg}")

    def __enter__(self):
        """Return the wrapper itself; no setup is performed on entry."""
        return self

    def shifted_child_context_configurator(
        self,
        idx_child_ctx: int,
        child_ctx: EvaluateContext,
        num_children: int,  # noqa: ARG002
        parent_ctx: EvaluateContext,
    ):
        """
        Invoke the wrapped objective's child context configurator, accounting for the slicing.

        The goal is to get the same behaviour as on the original combined
        objective function. The configurator therefore has to "see" the
        absolute index of the current child rather than its index within
        this rank's slice. For the same reason the number of children is
        overwritten with the total number of terms.

        Does nothing if the wrapped objective has no configurator.

        Args:
            idx_child_ctx (int): The index of the current child context *within* the slice.
            child_ctx (EvaluateContext): The child context.
            num_children (int): The number of children *within* the current slice (ignored).
            parent_ctx (EvaluateContext): The parent context.
        """
        if self.cob.child_context_configurator is not None:
            self.cob.child_context_configurator(
                idx_child_ctx=idx_child_ctx + self.start,
                child_ctx=child_ctx,
                num_children=self.cob.n_terms(),
                parent_ctx=parent_ctx,
            )

    def evaluate_slice(
        self, params: dict[str, Any], ctx: EvaluateContext
    ) -> list[float | None]:
        """
        Evaluate the terms assigned to this rank.

        Args:
            params: Parameter dictionary passed to every term.
            ctx: Parent evaluation context; one child context is created
                per local term.

        Returns:
            One entry per local term, in term order. If evaluating a term
            raises an ``Exception``, the exception is logged and the
            exception object itself is stored in place of the term value,
            so that it can be sent to rank 0 and raised there.
        """
        idx_slice = slice(self.start, self.end)
        selected_indices = range(self.cob.n_terms())[idx_slice]

        local_terms = []
        with ctx.child_contexts(
            len(selected_indices), configurator=self.shifted_child_context_configurator
        ) as contexts:
            for idx, ctx_term in zip(selected_indices, contexts, strict=True):
                try:
                    local_terms.append(self.cob.evaluate_term(params, idx, ctx_term))
                except Exception as e:  # noqa: PERF203
                    # Do not raise here: a rank that skips the subsequent
                    # gather would leave the other ranks waiting. Log the
                    # error and ship it to rank 0 instead.
                    logger.exception(e)
                    local_terms.append(e)

        return local_terms

    def worker_process_params(self, params: dict[str, Any], ctx: EvaluateContext):
        """
        Evaluate the local slice on a worker rank and send the results to rank 0.

        Args:
            params: Parameter dictionary for the current evaluation.
            ctx: Evaluation context received from rank 0. In the usual
                use-case this is the top-level context on the worker rank.
        """
        local_terms = self.evaluate_slice(params, ctx)

        # The gather must always run: rank 0 performs the matching gather
        # in ``__call__`` and would otherwise wait forever. The return
        # value is only meaningful on the root, so it is discarded here.
        _ = self.comm.gather(local_terms, root=0)

    def worker_loop(self):
        """
        Run the worker-side MPI evaluation loop.

        This method must be called only on nonzero ranks. Each worker rank
        waits for broadcast messages from rank 0. On receiving an
        ``EvaluateContext``, it evaluates its assigned slice of the
        combined objective, gathers the term values to rank 0 and then
        calls ``worker_gather_meta_data``. On receiving ``Signal.ABORT``,
        the loop exits. Any other payload is logged and ignored.

        NOTE(review): ``worker_gather_meta_data`` is not defined in this
        class as shown here; presumably it is provided elsewhere (e.g. by
        the base class) -- confirm.

        Raises:
            RuntimeError: If called on rank 0.
            AssertionError: If a received context has no parameters.
        """
        # Rank 0 is the driver and must never enter the worker loop.
        if self.rank == 0:
            msg = "`worker_loop` cannot be used on rank 0"
            raise RuntimeError(msg)

        # Worker loop: wait for a request, evaluate the slice, gather, repeat.
        while True:
            # Receive the next message from rank 0.
            signal = self.comm.bcast(None, root=0)

            if signal == Signal.ABORT:
                break

            if isinstance(signal, EvaluateContext):
                assert signal.parameters is not None
                params: dict[str, Any] = signal.parameters
                ctx = signal
                self.worker_process_params(params, ctx)
                self.worker_gather_meta_data(ctx)
            else:
                # Unknown payloads were previously dropped without a trace,
                # which makes a mismatched driver very hard to debug.
                logger.warning(
                    "[Rank %s] Ignoring unexpected broadcast payload of type %s",
                    self.rank,
                    type(signal).__name__,
                )

    def __call__(
        self, params: dict[str, Any], ctx: EvaluateContext | None = None
    ) -> float:
        """
        Evaluate the combined objective on rank 0 using MPI.

        Rank 0 broadcasts the evaluation context to all worker ranks,
        evaluates its own assigned slice locally, gathers the term values
        of all ranks, calls ``gather_meta_data``, re-raises the first
        captured exception (if any), and reduces the full list of term
        values using the wrapped combined objective's ``filter_terms`` and
        ``apply_reduction``.

        NOTE(review): ``gather_meta_data`` is not defined in this class as
        shown here; presumably it is provided elsewhere (e.g. by the base
        class) and is the counterpart of ``worker_gather_meta_data`` --
        confirm.

        Args:
            params: Parameter dictionary for the current evaluation.
            ctx: Optional parent evaluation context. If ``None``, a new
                ``EvaluateContext`` is created.

        Returns:
            Reduced scalar loss value.

        Raises:
            RuntimeError: If called on a nonzero rank.
            Exception: Re-raises the first exception that was captured
                while evaluating a term on any rank.

        Side Effects:
            - Stores ``params`` in ``ctx.parameters``.
            - Broadcasts the evaluation context to all worker ranks.
            - Stores the final reduced loss in ``ctx.loss``.
        """
        # Fail fast on the wrong rank, before touching ``ctx`` or the communicator.
        if self.rank != 0:
            msg = "`__call__` can only be used on rank 0"
            raise RuntimeError(msg)

        if ctx is None:
            ctx = EvaluateContext()
        ctx.parameters = params

        # Send the evaluation request (context including params) to the workers.
        self.comm.bcast(ctx, root=0)

        # Fallback so the gather below still has something to contribute
        # if the local evaluation raises.
        local_terms: list[float | None] = []
        try:
            # Compute this rank's slice of the objective function.
            local_terms = self.evaluate_slice(params, ctx=ctx)
        finally:
            # The workers always follow a request with a gather and then
            # ``worker_gather_meta_data``; rank 0 must take part in both
            # even on failure, otherwise the workers would wait forever.
            gathered_terms = self.comm.gather(local_terms, root=0)
            self.gather_meta_data(ctx)

        # ``gathered_terms`` holds one list per rank; flatten it in rank order.
        terms: list[float | None] = []
        if gathered_terms is not None:
            for rank_terms in gathered_terms:
                terms.extend(rank_terms)

        # Exceptions captured in ``evaluate_slice`` travel as list entries;
        # raise the first one here on the driver rank.
        for t in terms:
            if isinstance(t, Exception):
                raise t

        filtered_terms = self.cob.filter_terms(terms, ctx)
        ctx.loss = self.cob.apply_reduction(filtered_terms, ctx)
        return ctx.loss

    def release_workers(self):
        """
        Tell all worker ranks to leave ``worker_loop``.

        Only has an effect on rank 0 and only if there is more than one
        rank; on every other rank this is a no-op.
        """
        if self.rank == 0 and self.size > 1:
            # Broadcast the abort signal so the workers break out of their loop.
            self.comm.bcast(Signal.ABORT, root=0)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object,
    ):
        """Release the workers on exit; exceptions from the block are not suppressed."""
        self.release_workers()