# Source code for metatomic_torchsim._model
"""TorchSim wrapper for metatomic atomistic models.
Adapts metatomic models to the TorchSim ModelInterface protocol, allowing them to
be used within the torch-sim simulation framework for MD and other simulations.
Supports batched computations for multiple systems simultaneously, computing
energies, forces, and stresses via autograd. Also supports output variants,
non-conservative forces/stress, energy uncertainty warnings, and additional
model outputs.
"""
import logging
import os
import pathlib
import warnings
from typing import Dict, List, Optional, Union
import torch
from metatensor.torch import TensorMap
from metatomic.torch import (
AtomisticModel,
ModelEvaluationOptions,
ModelOutput,
System,
load_atomistic_model,
pick_device,
pick_output,
)
from ._neighbors import _compute_requested_neighbors
try:
import torch_sim as ts
from torch_sim.models.interface import ModelInterface
except ImportError as e:
raise ImportError(
"the torch_sim package is required for metatomic-torchsim: "
"pip install torch-sim-atomistic"
) from e
# Any value accepted as a filesystem path to a saved model.
FilePath = Union[str, bytes, pathlib.PurePath]

# Logger named after this module.
LOGGER = logging.getLogger(__name__)

# Lookup from the dtype names reported in model capabilities to torch dtypes.
STR_TO_DTYPE = dict(float32=torch.float32, float64=torch.float64)
# [docs] (documentation-page link marker, not part of the module)
class MetatomicModel(ModelInterface):
    """TorchSim wrapper for metatomic atomistic models.

    Wraps a metatomic model to compute energies, forces, and stresses within the
    TorchSim framework. Handles the translation between TorchSim's batched
    ``SimState`` and metatomic's list-of-``System`` convention, and uses autograd
    for force/stress derivatives.

    Neighbor lists are computed with vesin, or with nvalchemiops on CUDA when
    available and the model requests full neighbor lists.
    """

    def __init__(
        self,
        model: Union[FilePath, AtomisticModel, "torch.jit.RecursiveScriptModule"],
        *,
        extensions_directory: Optional[FilePath] = None,
        device: Optional[Union[torch.device, str]] = None,
        check_consistency: bool = False,
        compute_forces: bool = True,
        compute_stress: bool = True,
        variants: Optional[Dict[str, Optional[str]]] = None,
        non_conservative: bool = False,
        uncertainty_threshold: Optional[float] = 0.1,
        additional_outputs: Optional[Dict[str, ModelOutput]] = None,
    ) -> None:
        """
        :param model: Model to use. Accepts a file path to a ``.pt`` saved
            model, a Python :py:class:`AtomisticModel` instance, or a
            TorchScript :py:class:`torch.jit.RecursiveScriptModule`.
        :param extensions_directory: Directory containing compiled TorchScript
            extensions required by the model, if any.
        :param device: Torch device for evaluation. When ``None``, the best
            device is selected from the model's ``supported_devices``.
        :param check_consistency: Run consistency checks during model evaluation.
            Useful for debugging but hurts performance.
        :param compute_forces: Compute atomic forces via autograd.
        :param compute_stress: Compute stress tensors via the strain trick.
        :param variants: Dictionary mapping output names to a variant that should
            be used. Setting ``{"energy": "pbe"}`` selects the ``"energy/pbe"``
            output. The energy variant propagates to uncertainty and
            non-conservative outputs unless overridden (e.g.
            ``{"energy": "pbe", "energy_uncertainty": "r2scan"}`` would select
            ``energy/pbe`` and ``energy_uncertainty/r2scan``).
        :param non_conservative: If ``True``, the model will be asked to compute
            non-conservative forces and stresses. This can afford a speed-up,
            potentially at the expense of physical correctness (especially in
            molecular dynamics simulations). Only the non-conservative outputs
            enabled through ``compute_forces`` / ``compute_stress`` are looked
            up in the model.
        :param uncertainty_threshold: Threshold for per-atom energy uncertainty
            in eV. When the model supports ``energy_uncertainty`` with
            ``per_atom=True``, atoms exceeding this threshold trigger a warning.
            Set to ``None`` to disable.
        :param additional_outputs: Dictionary of extra :py:class:`ModelOutput`
            to request from the model. Results are stored in
            :py:attr:`additional_outputs` after each forward call.

        :raises ValueError: if the model path does not exist, the model dtype is
            not recognised, the model has no energy output, the non-conservative
            variants are inconsistent, the uncertainty threshold is not
            positive, or the model requests extra inputs.
        :raises TypeError: if ``model`` is not one of the accepted types.
        """
        super().__init__()

        self._check_consistency = check_consistency

        # Load the model, following the same patterns as ase_calculator.py
        if isinstance(model, (str, bytes, pathlib.PurePath)):
            model_path = str(model)
            if not os.path.exists(model_path):
                raise ValueError(f"given model path '{model_path}' does not exist")

            model = load_atomistic_model(
                model_path, extensions_directory=extensions_directory
            )
        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "AtomisticModel":
                raise TypeError(
                    "torch model must be 'AtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
        elif isinstance(model, AtomisticModel):
            pass
        else:
            raise TypeError(f"unknown type for model: {type(model)}")

        capabilities = model.capabilities()

        # Resolve device
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
            self._device = device
        else:
            self._device = torch.device(
                pick_device(capabilities.supported_devices, None)
            )

        # Resolve dtype from model capabilities
        if capabilities.dtype in STR_TO_DTYPE:
            self._dtype = STR_TO_DTYPE[capabilities.dtype]
        else:
            raise ValueError(
                f"unexpected dtype in model capabilities: {capabilities.dtype}"
            )

        # Resolve output keys based on requested variants: the energy variant
        # is the default for every related output that has no explicit entry.
        variants = variants or {}
        default_variant = variants.get("energy")
        resolved_variants = {
            key: variants.get(key, default_variant)
            for key in [
                "energy",
                "energy_uncertainty",
                "non_conservative_forces",
                "non_conservative_stress",
            ]
        }

        outputs = capabilities.outputs
        if not self._has_output(outputs, "energy"):
            raise ValueError(
                "model does not have an 'energy' output. "
                "Only models with energy outputs can be used with TorchSim."
            )

        self._energy_key = pick_output("energy", outputs, resolved_variants["energy"])

        # Uncertainty: matched with the same exact-or-variant rule as the
        # energy output, so unrelated outputs that merely contain the substring
        # "energy_uncertainty" do not trigger a lookup.
        has_energy_uq = self._has_output(outputs, "energy_uncertainty")
        if has_energy_uq and uncertainty_threshold is not None:
            self._energy_uq_key = pick_output(
                "energy_uncertainty",
                outputs,
                resolved_variants["energy_uncertainty"],
            )
        else:
            self._energy_uq_key = None

        # Non-conservative outputs
        self._non_conservative = non_conservative
        self._nc_forces_key = None
        self._nc_stress_key = None
        if non_conservative:
            if (
                "non_conservative_stress" in variants
                and "non_conservative_forces" in variants
                and (
                    (variants["non_conservative_stress"] is None)
                    != (variants["non_conservative_forces"] is None)
                )
            ):
                raise ValueError(
                    "if both 'non_conservative_stress' and "
                    "'non_conservative_forces' are present in `variants`, they "
                    "must either be both `None` or both not `None`."
                )

            # Only look up the outputs that will actually be requested, so a
            # model providing e.g. direct forces but no direct stress can still
            # be used with ``compute_stress=False``.
            if compute_forces:
                self._nc_forces_key = pick_output(
                    "non_conservative_forces",
                    outputs,
                    resolved_variants["non_conservative_forces"],
                )
            if compute_stress:
                self._nc_stress_key = pick_output(
                    "non_conservative_stress",
                    outputs,
                    resolved_variants["non_conservative_stress"],
                )

        # Additional outputs
        if additional_outputs is None:
            self._additional_output_requests: Dict[str, ModelOutput] = {}
        else:
            assert isinstance(additional_outputs, dict)
            for name, output in additional_outputs.items():
                assert isinstance(name, str)
                assert isinstance(output, torch.ScriptObject), (
                    "outputs must be ModelOutput instances"
                )
            # Shallow copy: the evaluation options below are built once, so
            # later mutations of the caller's dict must not change the set of
            # names read back in ``forward``.
            self._additional_output_requests = dict(additional_outputs)

        self._model = model.to(device=self._device)

        self._compute_forces = compute_forces
        self._compute_stress = compute_stress

        self._uncertainty_threshold = uncertainty_threshold
        # ``_energy_uq_key`` is only set when a threshold was given and the
        # model advertises an uncertainty output; the warning additionally
        # needs that output to be per-atom.
        self._calculate_uncertainty = bool(
            self._energy_uq_key is not None
            and self._energy_uq_key in outputs
            and outputs[self._energy_uq_key].per_atom
        )

        if self._calculate_uncertainty:
            if uncertainty_threshold <= 0.0:
                raise ValueError(
                    f"`uncertainty_threshold` is {uncertainty_threshold} but must "
                    "be positive"
                )

        self._requested_neighbor_lists = self._model.requested_neighbor_lists()

        self._requested_inputs = self._model.requested_inputs()
        if len(self._requested_inputs) != 0:
            raise ValueError(
                "this model requests extra inputs "
                f"({', '.join(self._requested_inputs.keys())}), which are not "
                "implemented in metatomic-torchsim. Please open an issue if "
                "you need them!"
            )

        # Precompute the outputs dict (immutable after __init__)
        run_outputs: Dict[str, ModelOutput] = {
            self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False),
        }
        if self._calculate_uncertainty:
            run_outputs[self._energy_uq_key] = ModelOutput(
                quantity="energy", unit="eV", per_atom=True
            )
        if self._non_conservative:
            if self._compute_forces:
                run_outputs[self._nc_forces_key] = ModelOutput(
                    quantity="force", unit="eV/Angstrom", per_atom=True
                )
            if self._compute_stress:
                run_outputs[self._nc_stress_key] = ModelOutput(
                    quantity="pressure", unit="eV/Angstrom^3", per_atom=False
                )
        run_outputs.update(self._additional_output_requests)

        self._evaluation_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=run_outputs,
        )

        self.additional_outputs: Dict[str, TensorMap] = {}
        """
        Additional outputs computed by :py:meth:`forward` are stored here.
        Keys match the ``additional_outputs`` parameter to the constructor;
        values are raw :py:class:`metatensor.torch.TensorMap` from the model.
        """

    @staticmethod
    def _has_output(outputs, base: str) -> bool:
        """Check whether ``outputs`` has ``base`` itself or a ``base/<variant>`` key."""
        prefix = f"{base}/"
        return any(key == base or key.startswith(prefix) for key in outputs.keys())

    def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given simulation state.

        :param state: TorchSim simulation state
        :returns: Dictionary with ``"energy"`` (shape ``[n_systems]``),
            ``"forces"`` (shape ``[n_atoms, 3]``, if ``compute_forces``), and
            ``"stress"`` (shape ``[n_systems, 3, 3]``, if ``compute_stress``).
        :raises TypeError: if the positions dtype differs from the model dtype.
        :raises ValueError: if the per-atom uncertainty returned by the model
            does not have shape ``(n_atoms, 1)``.
        """
        positions = state.positions
        cell = state.row_vector_cell
        atomic_nums = state.atomic_numbers

        if positions.dtype != self._dtype:
            raise TypeError(
                f"positions dtype {positions.dtype} does not match "
                f"model dtype {self._dtype}"
            )

        # Determine whether autograd is needed
        do_autograd_forces = self._compute_forces and not self._non_conservative
        do_autograd_stress = self._compute_stress and not self._non_conservative

        n_systems = len(cell)

        # Normalise ``pbc`` (bool, sequence or tensor) to a boolean tensor that
        # lives on the same device as the positions it is bundled with in
        # ``System``; a plain ``torch.tensor(...)`` would land on the default
        # device and break evaluation on accelerators.
        pbc = state.pbc
        if isinstance(pbc, bool):
            pbc = [pbc, pbc, pbc]
        pbc = torch.as_tensor(pbc, dtype=torch.bool, device=positions.device)

        # One boolean atom mask per system, computed once and reused below.
        system_masks = [state.system_idx == sys_idx for sys_idx in range(n_systems)]

        # Build per-system System objects. Metatomic expects a list of System
        # rather than a single batched graph.
        systems: List[System] = []
        strains: List[torch.Tensor] = []
        for sys_idx, mask in enumerate(system_masks):
            sys_positions = positions[mask]
            sys_cell = cell[sys_idx]
            sys_types = atomic_nums[mask]

            if do_autograd_forces:
                sys_positions = sys_positions.detach().requires_grad_(True)

            if do_autograd_stress:
                # Strain trick: apply an identity strain to positions and cell;
                # the energy gradient with respect to it is used for the stress.
                strain = torch.eye(
                    3,
                    device=self._device,
                    dtype=self._dtype,
                    requires_grad=True,
                )
                sys_positions = sys_positions @ strain
                sys_cell = sys_cell @ strain
                strains.append(strain)

            systems.append(
                System(
                    positions=sys_positions,
                    types=sys_types,
                    cell=sys_cell,
                    pbc=pbc,
                )
            )

        # Compute neighbor lists
        systems = _compute_requested_neighbors(
            systems=systems,
            requested_options=self._requested_neighbor_lists,
            check_consistency=self._check_consistency,
        )

        # Run the model (evaluation options precomputed in __init__)
        model_outputs = self._model(
            systems=systems,
            options=self._evaluation_options,
            check_consistency=self._check_consistency,
        )

        energy_values = model_outputs[self._energy_key].block().values

        results: Dict[str, torch.Tensor] = {}
        results["energy"] = energy_values.detach().squeeze(-1)

        # Uncertainty warning
        if self._calculate_uncertainty:
            uncertainty = model_outputs[self._energy_uq_key].block().values
            n_total_atoms = positions.shape[0]
            if uncertainty.shape != (n_total_atoms, 1):
                raise ValueError(
                    f"expected uncertainty shape ({n_total_atoms}, 1), "
                    f"got {uncertainty.shape}"
                )

            threshold = self._uncertainty_threshold
            # Evaluate the comparison once and reuse it for the index lookup.
            above_threshold = uncertainty.detach().squeeze(-1) > threshold
            if torch.any(above_threshold):
                atom_list = torch.where(above_threshold)[0].tolist()
                n_exceeded = len(atom_list)
                if n_exceeded > 20:
                    # Keep the message readable for large systems.
                    atom_list = atom_list[:20]
                    suffix = f" (and {n_exceeded - 20} more)"
                else:
                    suffix = ""
                warnings.warn(
                    "Some of the atomic energy uncertainties are larger than the "
                    f"threshold of {threshold} eV. The prediction is above the "
                    f"threshold for atoms {atom_list}{suffix}.",
                    stacklevel=2,
                )

        # Forces and stresses
        if self._non_conservative:
            if self._compute_forces:
                nc_forces = model_outputs[self._nc_forces_key].block().values.detach()
                nc_forces = nc_forces.reshape(-1, 3)
                # Remove spurious net force per system
                for mask in system_masks:
                    sys_forces = nc_forces[mask]
                    nc_forces[mask] = sys_forces - sys_forces.mean(dim=0, keepdim=True)
                results["forces"] = nc_forces

            if self._compute_stress:
                nc_stress = model_outputs[self._nc_stress_key].block().values.detach()
                nc_stress = nc_stress.reshape(n_systems, 3, 3)
                results["stress"] = nc_stress

        elif do_autograd_forces or do_autograd_stress:
            # Differentiate the energies once with respect to every requested
            # input: per-system positions first, then per-system strains.
            grad_inputs: List[torch.Tensor] = []
            if do_autograd_forces:
                grad_inputs.extend(system.positions for system in systems)
            if do_autograd_stress:
                grad_inputs.extend(strains)

            grads = torch.autograd.grad(
                outputs=energy_values,
                inputs=grad_inputs,
                grad_outputs=torch.ones_like(energy_values),
            )

            # Split the flat gradient tuple back in the order used above.
            if do_autograd_forces and do_autograd_stress:
                n_sys = len(systems)
                force_grads = grads[:n_sys]
                stress_grads = grads[n_sys:]
            elif do_autograd_forces:
                force_grads = grads
                stress_grads = ()
            else:
                force_grads = ()
                stress_grads = grads

            if do_autograd_forces:
                # Forces are the negative energy gradient w.r.t. positions.
                results["forces"] = torch.cat([-g for g in force_grads])

            if do_autograd_stress:
                # Strain gradient divided by the cell volume |det(cell)|.
                results["stress"] = torch.stack(
                    [
                        g / torch.abs(torch.det(system.cell.detach()))
                        for g, system in zip(stress_grads, systems, strict=True)
                    ]
                )

        # Store additional outputs
        self.additional_outputs = {
            name: model_outputs[name] for name in self._additional_output_requests
        }

        return results