Source code for chemfit.utils
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import numpy as np
from pydictnest import flatten_dict
[docs]
def check_protocol(obj: Any | None, prot: Any):
if obj is not None and not isinstance(obj, prot):
msg = f"{obj} does not implement the {prot} protocol"
raise Exception(msg)
[docs]
def next_free_folder(base: Path) -> Path:
"""If 'path/to/base' does not exist, return 'path/to/base'. Otherwise attempt 'path/to/base_0', 'path/to/base_1', etc. until finding a non-existent Path, then return that."""
base = Path(base)
if not base.exists():
return base
i = 0
while True:
candidate = base.with_name(f"{base.name}_{i}")
if not candidate.exists():
return candidate
i += 1
[docs]
class ExtendedJSONEncoder(json.JSONEncoder):
[docs]
def default(self, o: Any):
if isinstance(o, Path):
return str(o)
super().default(o)
return None
[docs]
def dump_dict_to_file(file: Path, dictionary: dict) -> None:
"""Write `dictionary` as JSON to `file` (with indent=4)."""
file.parent.mkdir(exist_ok=True, parents=True)
with file.open("w") as f:
json.dump(dictionary, f, indent=4, cls=ExtendedJSONEncoder)
[docs]
def check_params_near_bounds(
params: dict[str, Any],
bounds: dict[str, Any],
relative_tol: float,
) -> list[tuple[str, float, float, float]]:
"""
Check if any of the parameters are near or beyond the bounds.
The criterions checked are
1. param_value < lower + relative_tol * (upper - lower)
2. param_value > upper - relative_tol * (upper - lower)
Args:
params(dict): the dict of params to check
bounds(dict): the dict of bounds to check
relative_tol(float):
The tolerance, relative to the span of the bounds.
Positive numbers mean the values must fulfill a stricter bound
Zero means the values must fulfill the exact bound
Negative numbers mean the values must fulfill a looser bound
Returns:
A list of tuples with information about parameters, which violate the constraint.
Each tuple contains
- A string identifying the parameter in a flattened dict
- The value of the parameter
- The lower bound
- The upper bound
"""
flat_params = flatten_dict(params)
flat_bounds = flatten_dict(bounds)
problematic_params = []
for kp, vp in flat_params.items():
# Get the bounds (if they are not specified, we set them both to None)
lower, upper = flat_bounds.get(kp, (None, None))
if lower is not None and upper is not None:
abs_tol = relative_tol * np.abs(upper - lower)
if vp < lower + abs_tol or vp > upper - abs_tol:
problematic_params.append((kp, vp, lower, upper))
return problematic_params