Source code for metatomic_torchsim._model
"""TorchSim wrapper for metatomic atomistic models.
Adapts metatomic models to the TorchSim ModelInterface protocol, allowing them to
be used within the torch-sim simulation framework for MD and other simulations.
Supports batched computations for multiple systems simultaneously, computing
energies, forces, and stresses via autograd.
"""
import logging
import os
import pathlib
from typing import Dict, List, Optional, Union
import torch
import vesin.metatomic
from metatensor.torch import Labels, TensorBlock
from metatomic.torch import (
AtomisticModel,
ModelEvaluationOptions,
ModelOutput,
NeighborListOptions,
System,
load_atomistic_model,
pick_device,
)
try:
from nvalchemiops.torch.neighbors import neighbor_list as nvalchemi_neighbor_list
HAS_NVALCHEMIOPS = True
except ImportError:
HAS_NVALCHEMIOPS = False
try:
import torch_sim as ts
from torch_sim.models.interface import ModelInterface
except ImportError as e:
raise ImportError(
"the torch_sim package is required for metatomic-torchsim: "
"pip install torch-sim-atomistic"
) from e
# Anything accepted as a filesystem path by this module.
FilePath = Union[str, bytes, pathlib.PurePath]

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Maps the dtype names found in model capabilities to the torch dtypes of the
# same name.
STR_TO_DTYPE = {name: getattr(torch, name) for name in ("float32", "float64")}
[docs]
class MetatomicModel(ModelInterface):
    """TorchSim wrapper for metatomic atomistic models.

    Wraps a metatomic model to compute energies, forces, and stresses within the
    TorchSim framework. Handles the translation between TorchSim's batched
    ``SimState`` and metatomic's list-of-``System`` convention, and uses autograd
    for force/stress derivatives.

    Neighbor lists are computed with vesin, or with nvalchemiops on CUDA when
    available and the model requests full neighbor lists.
    """

    def __init__(
        self,
        model: Union[FilePath, AtomisticModel, "torch.jit.RecursiveScriptModule"],
        *,
        extensions_directory: Optional[FilePath] = None,
        device: Optional[Union[torch.device, str]] = None,
        check_consistency: bool = False,
        compute_forces: bool = True,
        compute_stress: bool = True,
    ) -> None:
        """
        :param model: Model to use. Accepts a file path to a ``.pt`` saved
            model, a Python :py:class:`AtomisticModel` instance, or a
            TorchScript :py:class:`torch.jit.RecursiveScriptModule`.
        :param extensions_directory: Directory containing compiled TorchScript
            extensions required by the model, if any.
        :param device: Torch device for evaluation. When ``None``, the best
            device is selected from the model's ``supported_devices``.
        :param check_consistency: Run consistency checks during model evaluation.
            Useful for debugging but hurts performance.
        :param compute_forces: Compute atomic forces via autograd.
        :param compute_stress: Compute stress tensors via the strain trick.
        :raises ValueError: if the model path does not exist, if the dtype in
            the model capabilities is not one of ``STR_TO_DTYPE``, or if the
            model has no ``"energy"`` output.
        :raises TypeError: if ``model`` is neither a path, an
            :py:class:`AtomisticModel`, nor a TorchScript module whose original
            name is ``"AtomisticModel"``.
        """
        super().__init__()
        self._check_consistency = check_consistency

        # Load the model, following the same patterns as ase_calculator.py
        if isinstance(model, (str, bytes, pathlib.PurePath)):
            # ``os.fsdecode`` accepts ``str``, ``bytes`` and path-like objects.
            # ``str(b"...")`` would produce the ``"b'...'"`` repr instead, which
            # never names an existing file, so bytes paths were always rejected.
            model_path = os.fsdecode(model)
            if not os.path.exists(model_path):
                raise ValueError(f"given model path '{model_path}' does not exist")
            model = load_atomistic_model(
                model_path, extensions_directory=extensions_directory
            )
        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "AtomisticModel":
                raise TypeError(
                    "torch model must be 'AtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
        elif not isinstance(model, AtomisticModel):
            raise TypeError(f"unknown type for model: {type(model)}")

        capabilities = model.capabilities()

        # Resolve device
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
            self._device = device
        else:
            self._device = torch.device(
                pick_device(capabilities.supported_devices, None)
            )

        # Resolve dtype from model capabilities
        if capabilities.dtype in STR_TO_DTYPE:
            self._dtype = STR_TO_DTYPE[capabilities.dtype]
        else:
            raise ValueError(
                f"unexpected dtype in model capabilities: {capabilities.dtype}"
            )

        if "energy" not in capabilities.outputs:
            raise ValueError(
                "model does not have an 'energy' output. "
                "Only models with energy outputs can be used with TorchSim."
            )

        self._model = model.to(device=self._device)
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._requested_neighbor_lists = self._model.requested_neighbor_lists()

        # Only the total (not per-atom) energy is requested from the model;
        # forces and stresses are derived from it in ``forward``.
        self._evaluation_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs={
                "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False)
            },
        )

    def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given simulation state.

        :param state: TorchSim simulation state
        :returns: Dictionary with ``"energy"`` (shape ``[n_systems]``),
            ``"forces"`` (shape ``[n_atoms, 3]``, if ``compute_forces``), and
            ``"stress"`` (shape ``[n_systems, 3, 3]``, if ``compute_stress``).
        :raises TypeError: if the dtype of ``state.positions`` differs from the
            model dtype.
        """
        positions = state.positions
        cell = state.row_vector_cell
        atomic_nums = state.atomic_numbers

        if positions.dtype != self._dtype:
            raise TypeError(
                f"positions dtype {positions.dtype} does not match "
                f"model dtype {self._dtype}"
            )

        # Build per-system System objects. Metatomic expects a list of System
        # rather than a single batched graph.
        systems: List[System] = []
        strains: List[torch.Tensor] = []
        n_systems = len(cell)
        for sys_idx in range(n_systems):
            mask = state.system_idx == sys_idx
            sys_positions = positions[mask]
            sys_cell = cell[sys_idx]
            sys_types = atomic_nums[mask]

            if self._compute_forces:
                # Fresh leaf tensor so the energy can be differentiated with
                # respect to these positions.
                sys_positions = sys_positions.detach().requires_grad_(True)

            if self._compute_stress:
                # Strain trick: route positions and cell through an identity
                # matrix that requires grad; the energy gradient with respect
                # to it is turned into the stress below.
                strain = torch.eye(
                    3,
                    device=self._device,
                    dtype=self._dtype,
                    requires_grad=True,
                )
                sys_positions = sys_positions @ strain
                sys_cell = sys_cell @ strain
                strains.append(strain)

            systems.append(
                System(
                    positions=sys_positions,
                    types=sys_types,
                    cell=sys_cell,
                    pbc=state.pbc,
                )
            )

        # Compute neighbor lists
        systems = _compute_requested_neighbors(
            systems=systems,
            requested_options=self._requested_neighbor_lists,
            check_consistency=self._check_consistency,
        )

        # Run the model
        model_outputs = self._model(
            systems=systems,
            options=self._evaluation_options,
            check_consistency=self._check_consistency,
        )
        energy_values = model_outputs["energy"].block().values

        results: Dict[str, torch.Tensor] = {}
        results["energy"] = energy_values.detach().squeeze(-1)

        # Compute forces and/or stresses via autograd
        if self._compute_forces or self._compute_stress:
            # Gradient inputs are laid out as [positions..., strains...] so the
            # result of a single autograd call can be split by position.
            grad_inputs: List[torch.Tensor] = []
            if self._compute_forces:
                grad_inputs.extend(system.positions for system in systems)
            if self._compute_stress:
                grad_inputs.extend(strains)

            grads = torch.autograd.grad(
                outputs=energy_values,
                inputs=grad_inputs,
                grad_outputs=torch.ones_like(energy_values),
            )

            n_force_grads = len(systems) if self._compute_forces else 0
            force_grads = grads[:n_force_grads]
            stress_grads = grads[n_force_grads:]

            if self._compute_forces:
                # Forces are concatenated system by system.
                # NOTE(review): this matches the atom order of ``state`` only if
                # atoms are grouped by ascending ``system_idx`` -- confirm.
                results["forces"] = torch.cat([-g for g in force_grads])

            if self._compute_stress:
                # Strain gradient divided by the cell volume |det(cell)|.
                results["stress"] = torch.stack(
                    [
                        g / torch.abs(torch.det(system.cell.detach()))
                        for g, system in zip(stress_grads, systems, strict=False)
                    ]
                )

        return results
# -- Neighbor list helpers (shared with ase_calculator.py patterns) ----------
def _compute_requested_neighbors(
    systems: List[System],
    requested_options: List[NeighborListOptions],
    check_consistency: bool = False,
) -> List[System]:
    """Compute every neighbor list requested by the model for the given systems.

    When nvalchemiops is importable and all systems live on a CUDA device, full
    neighbor lists are delegated to nvalchemiops and the remaining (half) lists
    to vesin. In every other case vesin handles all requested lists.

    :param systems: systems to attach neighbor lists to
    :param requested_options: neighbor list options requested by the model
    :param check_consistency: forwarded to the vesin backend
    :returns: the list of systems returned by the backend(s)
    """
    nvalchemi_usable = HAS_NVALCHEMIOPS and all(
        system.device.type == "cuda" for system in systems
    )

    if not nvalchemi_usable:
        return _compute_requested_neighbors_vesin(
            systems=systems,
            requested_options=requested_options,
            check_consistency=check_consistency,
        )

    # Split the requests by kind: nvalchemiops only serves full lists here.
    full_requests = [opts for opts in requested_options if opts.full_list]
    half_requests = [opts for opts in requested_options if not opts.full_list]

    with_full_lists = _compute_requested_neighbors_nvalchemi(
        systems=systems,
        requested_options=full_requests,
    )
    return _compute_requested_neighbors_vesin(
        systems=with_full_lists,
        requested_options=half_requests,
        check_consistency=check_consistency,
    )
def _compute_requested_neighbors_vesin(
    systems: List[System],
    requested_options: List[NeighborListOptions],
    check_consistency: bool = False,
) -> List[System]:
    """Compute the requested neighbor lists with vesin.

    Systems whose device type is neither ``"cpu"`` nor ``"cuda"`` are moved to
    the CPU before calling vesin; afterwards every system is moved (back) to
    the device it originally lived on.

    :param systems: systems to attach neighbor lists to
    :param requested_options: neighbor list options to compute
    :param check_consistency: forwarded to vesin
    :returns: one system per input system, on its original device
    """
    original_devices = [system.device for system in systems]

    # Only cpu and cuda systems are handed to vesin as-is.
    staged_systems = [
        system if system.device.type in ["cpu", "cuda"] else system.to(device="cpu")
        for system in systems
    ]

    vesin.metatomic.compute_requested_neighbors_from_options(
        systems=staged_systems,
        system_length_unit="angstrom",
        options=requested_options,
        check_consistency=check_consistency,
    )

    return [
        staged.to(device=target)
        for staged, target in zip(staged_systems, original_devices, strict=True)
    ]
def _compute_requested_neighbors_nvalchemi(
    systems: List[System],
    requested_options: List[NeighborListOptions],
) -> List[System]:
    """Compute full neighbor lists on CUDA using nvalchemiops.

    For each requested option and each system, the pair indices and cell shifts
    returned by nvalchemiops are turned into distance vectors and stored on the
    system as a ``TensorBlock`` via ``add_neighbor_list``.

    :param systems: systems to attach neighbor lists to; each is asserted to be
        on a CUDA device
    :param requested_options: neighbor list options; each is asserted to be a
        full list
    :returns: the same ``systems`` list, with neighbor lists added in place
    """
    sample_names = [
        "first_atom",
        "second_atom",
        "cell_shift_a",
        "cell_shift_b",
        "cell_shift_c",
    ]

    for nl_options in requested_options:
        assert nl_options.full_list

        for system in systems:
            assert system.device.type == "cuda"

            pairs, _, shifts = nvalchemi_neighbor_list(
                system.positions,
                nl_options.engine_cutoff("angstrom"),
                cell=system.cell,
                pbc=system.pbc,
                return_neighbor_list=True,
            )

            # Vector from the first atom of each pair to the (periodic image of
            # the) second atom.
            first_positions = system.positions[pairs[0]]
            second_positions = system.positions[pairs[1]]
            cell_offsets = shifts.to(system.cell.dtype) @ system.cell
            distances = second_positions - first_positions + cell_offsets

            # One sample row per pair: both atom indices, then the cell shifts.
            sample_values = torch.hstack([pairs.T, shifts])

            xyz_labels = Labels(
                "xyz",
                torch.tensor([[0], [1], [2]], device=system.device),
            )
            distance_labels = Labels(
                "distance",
                torch.tensor([[0]], device=system.device),
            )

            block = TensorBlock(
                distances.reshape(-1, 3, 1),
                samples=Labels(names=sample_names, values=sample_values),
                components=[xyz_labels],
                properties=distance_labels,
            )
            system.add_neighbor_list(nl_options, block)

    return systems