Source code for metatomic_torchsim._model

"""TorchSim wrapper for metatomic atomistic models.

Adapts metatomic models to the TorchSim ModelInterface protocol, allowing them to
be used within the torch-sim simulation framework for MD and other simulations.

Supports batched computations for multiple systems simultaneously, computing
energies, forces, and stresses via autograd.  Also supports output variants,
non-conservative forces/stress, energy uncertainty warnings, and additional
model outputs.
"""

import logging
import os
import pathlib
import warnings
from typing import Dict, List, Literal, Optional, Union

import torch
from metatensor.torch import TensorMap

from metatomic.torch import (
    AtomisticModel,
    ModelEvaluationOptions,
    ModelOutput,
    System,
    load_atomistic_model,
    pick_device,
    pick_output,
)

from ._neighbors import AllNeighborsCalculator


try:
    import torch_sim as ts
    from torch_sim.models.interface import ModelInterface
except ImportError as e:
    raise ImportError(
        "the torch_sim package is required for metatomic-torchsim: "
        "pip install torch-sim-atomistic"
    ) from e


# Anything accepted as a filesystem location: text, raw bytes, or a pathlib path.
FilePath = Union[str, bytes, pathlib.PurePath]

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Lookup from a dtype name (string) to the torch dtype object of the same name.
STR_TO_DTYPE = {name: getattr(torch, name) for name in ("float32", "float64")}


class MetatomicModel(ModelInterface):
    """TorchSim wrapper for metatomic atomistic models.

    Wraps a metatomic model to compute energies, forces, and stresses within the
    TorchSim framework. Handles the translation between TorchSim's batched
    ``SimState`` and metatomic's list-of-``System`` convention, and uses autograd
    for force/stress derivatives.

    Neighbor lists are computed with vesin, or with nvalchemiops on CUDA when
    available and the model requests full neighbor lists.
    """

    def __init__(
        self,
        model: Union[FilePath, AtomisticModel, "torch.jit.RecursiveScriptModule"],
        *,
        extensions_directory: Optional[FilePath] = None,
        device: Optional[Union[torch.device, str]] = None,
        check_consistency: bool = False,
        compute_forces: bool = True,
        compute_stress: bool = True,
        variants: Optional[Dict[str, Optional[str]]] = None,
        non_conservative: Union[bool, Literal["forces", "stress"]] = False,
        uncertainty_threshold: Optional[float] = 0.1,
        additional_outputs: Optional[Dict[str, ModelOutput]] = None,
    ) -> None:
        """
        :param model: Model to use. Accepts a file path to a ``.pt`` saved model, a
            Python :py:class:`AtomisticModel` instance, or a TorchScript
            :py:class:`torch.jit.RecursiveScriptModule`.
        :param extensions_directory: Directory containing compiled TorchScript
            extensions required by the model, if any.
        :param device: Torch device for evaluation. When ``None``, the best device
            is selected from the model's ``supported_devices``.
        :param check_consistency: Run consistency checks during model evaluation.
            Useful for debugging but hurts performance.
        :param compute_forces: Compute atomic forces via autograd.
        :param compute_stress: Compute stress tensors via the strain trick.
        :param variants: Dictionary mapping output names to a variant that should
            be used. Setting ``{"energy": "pbe"}`` selects the ``"energy/pbe"``
            output. The energy variant propagates to uncertainty and
            non-conservative outputs unless overridden (e.g.
            ``{"energy": "pbe", "energy_uncertainty": "r2scan"}`` would select
            ``energy/pbe`` and ``energy_uncertainty/r2scan``). The key
            ``"non_conservative_forces"`` is a deprecated alias of
            ``"non_conservative_force"``; passing both is an error. The
            dictionary is copied and never modified.
        :param non_conservative: controls which outputs are obtained directly from
            the model rather than via autograd on the energy. Accepted values are:

            - ``False`` (default): conservative mode; forces and stress are both
              derived from the gradient of the energy.
            - ``True``: both forces and stress are read directly from the model's
              ``non_conservative_force`` and ``non_conservative_stress`` outputs.
            - ``"forces"``: forces come from the model's
              ``non_conservative_force`` output; stress is still obtained via
              autograd.
            - ``"stress"``: stress comes from the model's
              ``non_conservative_stress`` output; forces are still obtained via
              autograd.

            Using any value other than ``False`` can afford a speed-up,
            potentially at the expense of physical correctness (especially in
            molecular dynamics simulations).
        :param uncertainty_threshold: Threshold for per-atom energy uncertainty in
            eV. When the model supports ``energy_uncertainty`` with
            ``sample_kind="atom"``, atoms exceeding this threshold trigger a
            warning. Set to ``None`` to disable.
        :param additional_outputs: Dictionary of extra :py:class:`ModelOutput` to
            request from the model. Results are stored in
            :py:attr:`additional_outputs` after each forward call.

        :raises ValueError: if the model path does not exist, the dtype in the
            model capabilities is not supported, the model has no energy output,
            ``non_conservative`` or ``variants`` are invalid,
            ``uncertainty_threshold`` is not positive while uncertainties are
            evaluated, or the model requests extra inputs.
        :raises TypeError: if ``model`` has an unsupported type.
        """
        super().__init__()

        self._check_consistency = check_consistency

        # Load the model, following the same patterns as ase_calculator.py
        if isinstance(model, (str, bytes, pathlib.PurePath)):
            # `os.fsdecode` handles str, bytes and path objects alike, whereas
            # `str(b"model.pt")` would produce the literal text "b'model.pt'".
            model_path = os.fsdecode(model)
            if not os.path.exists(model_path):
                raise ValueError(f"given model path '{model_path}' does not exist")

            model = load_atomistic_model(
                model_path, extensions_directory=extensions_directory
            )
        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "AtomisticModel":
                raise TypeError(
                    "torch model must be 'AtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
        elif not isinstance(model, AtomisticModel):
            raise TypeError(f"unknown type for model: {type(model)}")

        capabilities = model.capabilities()

        # Resolve device
        if device is None:
            self._device = torch.device(
                pick_device(capabilities.supported_devices, None)
            )
        else:
            self._device = torch.device(device) if isinstance(device, str) else device

        # Resolve dtype from model capabilities
        dtype = STR_TO_DTYPE.get(capabilities.dtype)
        if dtype is None:
            raise ValueError(
                f"unexpected dtype in model capabilities: {capabilities.dtype}"
            )
        self._dtype = dtype

        # Work on a private copy so the caller's dictionary is never modified.
        variants = dict(variants) if variants else {}

        # Fold the deprecated spelling into the current one *before* resolving
        # defaults, so every later check only has to know about the new name.
        if "non_conservative_forces" in variants:
            warnings.warn(
                "variant name 'non_conservative_forces' is deprecated, please use "
                "'non_conservative_force' instead",
                stacklevel=2,
            )
            # Only an explicit user-provided duplicate is an error; the check
            # must look at the user input, not at the defaulted dictionary.
            if "non_conservative_force" in variants:
                raise ValueError(
                    "you can not specify both 'non_conservative_force' and "
                    "'non_conservative_forces' in `variants`"
                )
            variants["non_conservative_force"] = variants.pop(
                "non_conservative_forces"
            )

        # Resolve output keys based on requested variants; outputs without an
        # explicit variant inherit the one requested for the energy.
        default_variant = variants.get("energy")
        resolved_variants = {
            key: variants.get(key, default_variant)
            for key in [
                "energy",
                "energy_uncertainty",
                "non_conservative_force",
                "non_conservative_stress",
            ]
        }

        outputs = capabilities.outputs

        has_energy = any(
            key == "energy" or key.startswith("energy/") for key in outputs
        )
        if not has_energy:
            raise ValueError(
                "model does not have an 'energy' output. "
                "Only models with energy outputs can be used with TorchSim."
            )

        self._energy_key = pick_output("energy", outputs, resolved_variants["energy"])

        # Uncertainty: match the output name exactly (optionally with a variant
        # suffix), the same way the energy output is detected above.
        has_energy_uq = any(
            key == "energy_uncertainty" or key.startswith("energy_uncertainty/")
            for key in outputs
        )
        if has_energy_uq and uncertainty_threshold is not None:
            self._energy_uq_key = pick_output(
                "energy_uncertainty",
                outputs,
                resolved_variants["energy_uncertainty"],
            )
        else:
            self._energy_uq_key = None

        # Non-conservative outputs
        _valid_nc = (True, False, "forces", "stress")
        if non_conservative not in _valid_nc:
            raise ValueError(
                f"non_conservative must be one of {list(_valid_nc)}, "
                f"got {non_conservative!r}"
            )

        self._nc_forces = non_conservative in (True, "forces")
        self._nc_stress = non_conservative in (True, "stress")

        if self._nc_forces and self._nc_stress:
            if (
                "non_conservative_stress" in variants
                and "non_conservative_force" in variants
                and (
                    (variants["non_conservative_stress"] is None)
                    != (variants["non_conservative_force"] is None)
                )
            ):
                raise ValueError(
                    "if both 'non_conservative_stress' and "
                    "'non_conservative_force' are present in `variants`, they "
                    "must either be both `None` or both not `None`."
                )

        if self._nc_forces:
            self._nc_forces_key = pick_output(
                "non_conservative_force",
                outputs,
                resolved_variants["non_conservative_force"],
            )
        else:
            self._nc_forces_key = None

        if self._nc_stress:
            self._nc_stress_key = pick_output(
                "non_conservative_stress",
                outputs,
                resolved_variants["non_conservative_stress"],
            )
        else:
            self._nc_stress_key = None

        # Additional outputs
        if additional_outputs is None:
            self._additional_output_requests: Dict[str, ModelOutput] = {}
        else:
            assert isinstance(additional_outputs, dict)
            for name, output in additional_outputs.items():
                assert isinstance(name, str)
                assert isinstance(output, torch.ScriptObject), (
                    "outputs must be ModelOutput instances"
                )
            # Keep a private copy: the evaluation options are precomputed below,
            # so later changes to the caller's dictionary must not leak in.
            self._additional_output_requests = dict(additional_outputs)

        self._model = model.to(device=self._device)

        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._uncertainty_threshold = uncertainty_threshold

        # Uncertainties are only evaluated for per-atom uncertainty outputs.
        # `_energy_uq_key` is already `None` when the threshold is `None`.
        self._calculate_uncertainty = (
            self._energy_uq_key is not None
            and self._energy_uq_key in outputs
            and outputs[self._energy_uq_key].sample_kind == "atom"
        )

        if self._calculate_uncertainty and uncertainty_threshold <= 0.0:
            raise ValueError(
                f"`uncertainty_threshold` is {uncertainty_threshold} but must "
                "be positive"
            )

        self._requested_inputs = self._model.requested_inputs(use_new_names=True)
        if len(self._requested_inputs) != 0:
            raise ValueError(
                "this model requests extra inputs "
                f"({', '.join(self._requested_inputs.keys())}), which are not "
                "implemented in metatomic-torchsim. Please open an issue if "
                "you need them!"
            )

        # Precompute the outputs dict (immutable after __init__)
        run_outputs: Dict[str, ModelOutput] = {
            self._energy_key: ModelOutput(unit="eV", sample_kind="system")
        }
        if self._calculate_uncertainty:
            run_outputs[self._energy_uq_key] = ModelOutput(
                unit="eV", sample_kind="atom"
            )
        if self._nc_forces and self._compute_forces:
            run_outputs[self._nc_forces_key] = ModelOutput(
                unit="eV/Angstrom", sample_kind="atom"
            )
        if self._nc_stress and self._compute_stress:
            run_outputs[self._nc_stress_key] = ModelOutput(
                unit="eV/Angstrom^3", sample_kind="system"
            )
        run_outputs.update(self._additional_output_requests)

        self._evaluation_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=run_outputs,
        )

        self._nl_calculators = AllNeighborsCalculator(
            requested_options=self._model.requested_neighbor_lists(),
            check_consistency=check_consistency,
        )

        self.additional_outputs: Dict[str, TensorMap] = {}
        """
        Additional outputs computed by :py:meth:`forward` are stored here.

        Keys match the ``additional_outputs`` parameter to the constructor; values
        are raw :py:class:`metatensor.torch.TensorMap` from the model.
        """

    def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given simulation state.

        :param state: TorchSim simulation state
        :returns: Dictionary with ``"energy"`` (shape ``[n_systems]``),
            ``"forces"`` (shape ``[n_atoms, 3]``, if ``compute_forces``), and
            ``"stress"`` (shape ``[n_systems, 3, 3]``, if ``compute_stress``).
        :raises TypeError: if the dtype of ``state.positions`` differs from the
            model dtype.
        :raises ValueError: if the per-atom uncertainty returned by the model does
            not have shape ``(n_atoms, 1)``.
        """
        positions = state.positions
        cell = state.row_vector_cell
        atomic_nums = state.atomic_numbers

        if positions.dtype != self._dtype:
            raise TypeError(
                f"positions dtype {positions.dtype} does not match "
                f"model dtype {self._dtype}"
            )

        # Determine whether autograd is needed
        do_autograd_forces = self._compute_forces and not self._nc_forces
        do_autograd_stress = self._compute_stress and not self._nc_stress

        # Build per-system System objects. Metatomic expects a list of System
        # rather than a single batched graph.
        systems: List[System] = []
        strains: List[torch.Tensor] = []
        n_systems = len(cell)

        # When a new tensor has to be built for `pbc`, create it next to the
        # positions instead of on the default (CPU) device.
        pbc = state.pbc
        if isinstance(pbc, bool):
            pbc = torch.tensor([pbc, pbc, pbc], device=positions.device)
        elif not isinstance(pbc, torch.Tensor):
            pbc = torch.tensor(pbc, device=positions.device)

        for sys_idx in range(n_systems):
            mask = state.system_idx == sys_idx
            sys_positions = positions[mask]
            sys_cell = cell[sys_idx]
            sys_types = atomic_nums[mask]

            if do_autograd_forces:
                sys_positions = sys_positions.detach().requires_grad_(True)

            if do_autograd_stress:
                # Strain trick: multiply positions and cell by an identity
                # matrix that requires grad; dE/d(strain) is used for the
                # stress further down.
                strain = torch.eye(
                    3,
                    device=self._device,
                    dtype=self._dtype,
                    requires_grad=True,
                )
                sys_positions = sys_positions @ strain
                sys_cell = sys_cell @ strain
                strains.append(strain)

            systems.append(
                System(
                    positions=sys_positions,
                    types=sys_types,
                    cell=sys_cell,
                    pbc=pbc,
                )
            )

        # Compute neighbor lists
        systems = self._nl_calculators.compute(systems=systems)

        # Run the model (evaluation options precomputed in __init__)
        model_outputs = self._model(
            systems=systems,
            options=self._evaluation_options,
            check_consistency=self._check_consistency,
        )

        energy_values = model_outputs[self._energy_key].block().values

        results: Dict[str, torch.Tensor] = {}
        results["energy"] = energy_values.detach().squeeze(-1)

        # Uncertainty warning
        if self._calculate_uncertainty:
            uncertainty = model_outputs[self._energy_uq_key].block().values
            n_total_atoms = positions.shape[0]
            if uncertainty.shape != (n_total_atoms, 1):
                raise ValueError(
                    f"expected uncertainty shape ({n_total_atoms}, 1), "
                    f"got {uncertainty.shape}"
                )

            threshold = self._uncertainty_threshold
            # Evaluate the comparison once and reuse it for the indices.
            above_threshold = uncertainty.squeeze(-1) > threshold
            if torch.any(above_threshold):
                exceeded = torch.where(above_threshold)[0]
                atom_list = exceeded.tolist()
                # Cap the message at 20 indices to keep the warning readable.
                if len(atom_list) > 20:
                    atom_list = atom_list[:20]
                    suffix = f" (and {len(exceeded) - 20} more)"
                else:
                    suffix = ""

                warnings.warn(
                    "Some of the atomic energy uncertainties are larger than the "
                    f"threshold of {threshold} eV. The prediction is above the "
                    f"threshold for atoms {atom_list}{suffix}.",
                    stacklevel=2,
                )

        # Forces and stresses; NC outputs are read directly from model
        if self._compute_forces and self._nc_forces:
            # Clone so the in-place correction below never writes into the
            # tensor owned by the model output (a detached tensor shares its
            # storage with the original).
            nc_forces = (
                model_outputs[self._nc_forces_key]
                .block()
                .values.detach()
                .reshape(-1, 3)
                .clone()
            )

            # Remove spurious net force per system
            # NOTE(review): assumes the model returns atoms in the same order as
            # `state.positions` (atoms grouped by system index) -- confirm.
            for sys_idx in range(n_systems):
                mask = state.system_idx == sys_idx
                sys_forces = nc_forces[mask]
                nc_forces[mask] = sys_forces - sys_forces.mean(dim=0, keepdim=True)

            results["forces"] = nc_forces

        if self._compute_stress and self._nc_stress:
            nc_stress = model_outputs[self._nc_stress_key].block().values.detach()
            results["stress"] = nc_stress.reshape(n_systems, 3, 3)

        # Forces and stresses; autograd outputs
        if do_autograd_forces or do_autograd_stress:
            # All position tensors come first, then all strain tensors, so that
            # the returned gradients can be split by position below.
            grad_inputs: List[torch.Tensor] = []
            if do_autograd_forces:
                grad_inputs.extend(system.positions for system in systems)
            if do_autograd_stress:
                grad_inputs.extend(strains)

            grads = torch.autograd.grad(
                outputs=energy_values,
                inputs=grad_inputs,
                grad_outputs=torch.ones_like(energy_values),
            )

            if do_autograd_forces and do_autograd_stress:
                n_sys = len(systems)
                force_grads = grads[:n_sys]
                stress_grads = grads[n_sys:]
            elif do_autograd_forces:
                force_grads = grads
                stress_grads = ()
            else:
                force_grads = ()
                stress_grads = grads

            if do_autograd_forces:
                # Forces are minus the energy gradient, concatenated system by
                # system.
                # NOTE(review): this matches the atom order of `state.positions`
                # only if atoms are grouped by increasing system index -- confirm.
                results["forces"] = torch.cat([-g for g in force_grads])

            if do_autograd_stress:
                # Divide the strain gradient by |det(cell)| of each system.
                results["stress"] = torch.stack(
                    [
                        g / torch.abs(torch.det(system.cell.detach()))
                        for g, system in zip(stress_grads, systems, strict=True)
                    ]
                )

        # Store additional outputs
        self.additional_outputs = {
            name: model_outputs[name] for name in self._additional_output_requests
        }

        return results