Source code for metatensor.torch.atomistic.ase_calculator

import logging
import os
import pathlib
import warnings
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import vesin
from torch.profiler import record_function

from .. import Labels, TensorBlock, TensorMap
from . import (
    MetatensorAtomisticModel,
    ModelEvaluationOptions,
    ModelMetadata,
    ModelOutput,
    System,
    load_atomistic_model,
    register_autograd_neighbors,
)


import ase  # isort: skip
import ase.neighborlist  # isort: skip
import ase.calculators.calculator  # isort: skip
from ase.calculators.calculator import (  # isort: skip
    InputError,
    PropertyNotImplementedError,
    all_properties as ALL_ASE_PROPERTIES,
)


if os.environ.get("METATENSOR_IMPORT_FOR_SPHINX", "0") == "0":
    # this can not be imported when building the documentation
    from .. import sum_over_samples  # isort: skip

# Anything accepted as a path to a serialized model file on disk.
FilePath = Union[str, bytes, pathlib.PurePath]

# Logger used by this module (named after the module itself).
LOGGER = logging.getLogger(__name__)


# Translation from the dtype names stored in model capabilities to torch dtypes.
STR_TO_DTYPE = {name: getattr(torch, name) for name in ("float32", "float64")}


class MetatensorCalculator(ase.calculators.calculator.Calculator):
    """
    The :py:class:`MetatensorCalculator` class implements ASE's
    :py:class:`ase.calculators.calculator.Calculator` API using metatensor atomistic
    models to compute energy, forces and any other supported property.

    This class can be initialized with any :py:class:`MetatensorAtomisticModel`, and
    used to run simulations using ASE's MD facilities.

    Neighbor lists are computed using the fast `vesin
    <https://luthaf.fr/vesin/latest/index.html>`_ neighbor list library, unless the
    system has mixed periodic and non-periodic boundary conditions (which are not yet
    supported by ``vesin``), in which case the slower ASE neighbor list is used.
    """

    additional_outputs: Dict[str, TensorMap]
    """
    Additional outputs computed by :py:meth:`calculate` are stored in this dictionary.

    The keys will match the keys of the ``additional_outputs`` parameters to the
    constructor; and the values will be the corresponding raw
    :py:class:`metatensor.torch.TensorMap` produced by the model.
    """

    def __init__(
        self,
        model: Union[FilePath, MetatensorAtomisticModel],
        *,
        additional_outputs: Optional[Dict[str, ModelOutput]] = None,
        extensions_directory=None,
        check_consistency=False,
        device=None,
    ):
        """
        :param model: model to use for the calculation. This can be a file path, a
            Python instance of :py:class:`MetatensorAtomisticModel`, or the output of
            :py:func:`torch.jit.script` on :py:class:`MetatensorAtomisticModel`.
        :param additional_outputs: Dictionary of additional outputs to be computed by
            the model. These outputs will always be computed whenever the
            :py:meth:`calculate` function is called (e.g. by
            :py:meth:`ase.Atoms.get_potential_energy`,
            :py:meth:`ase.optimize.optimize.Dynamics.run`, *etc.*) and stored in the
            :py:attr:`additional_outputs` attribute. If you want more control over when
            and how to compute specific outputs, you should use :py:meth:`run_model`
            instead.
        :param extensions_directory: if the model uses extensions, we will try to load
            them from this directory
        :param check_consistency: should we check the model for consistency when
            running, defaults to False.
        :param device: torch device to use for the calculation. If ``None``, we will try
            the options in the model's ``supported_devices`` in order.

        :raises InputError: if ``model`` is a path that does not exist, or a scripted
            module that is not a ``MetatensorAtomisticModel``
        :raises TypeError: if ``model`` or ``additional_outputs`` have an unexpected
            type
        :raises ValueError: if the model does not support the requested ``device``, or
            declares an unknown ``dtype`` in its capabilities
        """
        super().__init__()

        # plain dict of the constructor arguments, used by `todict`/`fromdict`
        self.parameters = {
            "check_consistency": check_consistency,
        }

        # Load the model
        if isinstance(model, (str, bytes, pathlib.PurePath)):
            if not os.path.exists(model):
                raise InputError(f"given model path '{model}' does not exist")

            self.parameters["model_path"] = str(model)

            model = load_atomistic_model(
                model, extensions_directory=extensions_directory
            )

        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "MetatensorAtomisticModel":
                raise InputError(
                    "torch model must be 'MetatensorAtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
        elif isinstance(model, MetatensorAtomisticModel):
            # nothing to do
            pass
        else:
            raise TypeError(f"unknown type for model: {type(model)}")

        self.parameters["device"] = str(device) if device is not None else None
        # check if the model supports the requested device
        capabilities = model.capabilities()
        if device is None:
            device = _find_best_device(capabilities.supported_devices)
        else:
            device = torch.device(device)
            device_is_supported = False

            for supported in capabilities.supported_devices:
                try:
                    supported = torch.device(supported)
                except RuntimeError as e:
                    warnings.warn(
                        "the model contains an invalid device in `supported_devices`: "
                        f"{e}",
                        stacklevel=2,
                    )
                    continue

                # only compare the device type, so that e.g. "cuda:1" matches "cuda"
                if supported.type == device.type:
                    device_is_supported = True
                    break

            if not device_is_supported:
                raise ValueError(
                    f"This model does not support the requested device ({device}), "
                    "the following devices are supported: "
                    f"{capabilities.supported_devices}"
                )

        if capabilities.dtype in STR_TO_DTYPE:
            self._dtype = STR_TO_DTYPE[capabilities.dtype]
        else:
            raise ValueError(
                f"found unexpected dtype in model capabilities: {capabilities.dtype}"
            )

        if additional_outputs is None:
            self._additional_output_requests = {}
        else:
            # explicit checks instead of `assert`, which is removed under `python -O`
            if not isinstance(additional_outputs, dict):
                raise TypeError(
                    "additional_outputs must be a dict, "
                    f"got {type(additional_outputs)} instead"
                )

            for name, output in additional_outputs.items():
                if not isinstance(name, str):
                    raise TypeError(
                        f"additional_outputs keys must be strings, got {type(name)}"
                    )

                # ModelOutput is a TorchScript class, so instances are ScriptObject;
                # the presence of this setter identifies it as a ModelOutput
                if not (
                    isinstance(output, torch.ScriptObject)
                    and "explicit_gradients_setter" in output._method_names()
                ):
                    raise TypeError("outputs must be ModelOutput instances")

            self._additional_output_requests = additional_outputs

        self._device = device
        self._model = model.to(device=self._device)

        # filled by `calculate()`; initialized here so the attribute always exists
        self.additional_outputs = {}

        # We do our own check to verify if a property is implemented in `calculate()`,
        # so we pretend to be able to compute all properties ASE knows about.
        self.implemented_properties = ALL_ASE_PROPERTIES

    def todict(self):
        """
        Get the parameters needed to re-create this calculator with
        :py:meth:`fromdict`.

        :raises RuntimeError: if the calculator was not created from a model file
        """
        if "model_path" not in self.parameters:
            raise RuntimeError(
                "can not save metatensor model in ASE `todict`, please initialize "
                "`MetatensorCalculator` with a path to a saved model file if you need "
                "to use `todict`"
            )

        # return a copy, so callers can not modify the calculator's parameters
        return dict(self.parameters)

    @classmethod
    def fromdict(cls, data):
        """Re-create a calculator from the output of :py:meth:`todict`."""
        # use `cls` so that subclasses round-trip to their own type
        return cls(
            model=data["model_path"],
            check_consistency=data["check_consistency"],
            device=data["device"],
        )

    def metadata(self) -> ModelMetadata:
        """Get the metadata of the underlying model"""
        return self._model.metadata()

    def _add_neighbor_lists(self, atoms: ase.Atoms, system: System) -> None:
        """
        Compute every neighbor list requested by the model for ``atoms``, register it
        for autograd, and attach it to ``system``.
        """
        for options in self._model.requested_neighbor_lists():
            neighbors = _compute_ase_neighbors(
                atoms, options, dtype=self._dtype, device=self._device
            )
            register_autograd_neighbors(
                system,
                neighbors,
                check_consistency=self.parameters["check_consistency"],
            )
            system.add_neighbor_list(options, neighbors)

    def run_model(
        self,
        atoms: ase.Atoms,
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """
        Run the model on the given ``atoms``, computing the requested ``outputs`` and
        only these.

        The output of the model is returned directly, and as such the blocks'
        ``values`` will be :py:class:`torch.Tensor`.

        This is intended as an easy way to run metatensor models on
        :py:class:`ase.Atoms` when the model can compute outputs not supported by the
        standard ASE's calculator interface.

        All the parameters have the same meaning as the corresponding ones in
        :py:meth:`metatensor.torch.atomistic.ModelInterface.forward`.

        :param atoms: system on which to run the model
        :param outputs: outputs of the model that should be predicted
        :param selected_atoms: subset of atoms on which to run the calculation
        """
        types, positions, cell, pbc = _ase_to_torch_data(
            atoms=atoms, dtype=self._dtype, device=self._device
        )

        system = System(types, positions, cell, pbc)

        # Compute the neighbors lists requested by the model (vesin, or ASE for
        # mixed periodic boundary conditions)
        self._add_neighbor_lists(atoms, system)

        options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=outputs,
            selected_atoms=selected_atoms,
        )
        return self._model(
            systems=[system],
            options=options,
            check_consistency=self.parameters["check_consistency"],
        )

    def calculate(
        self,
        atoms: ase.Atoms,
        properties: List[str],
        system_changes: List[str],
    ) -> None:
        """
        Compute some ``properties`` with this calculator, and return them in the
        format expected by ASE.

        This is not intended to be called directly by users, but to be an
        implementation detail of ``atoms.get_energy()`` and related functions. See
        :py:meth:`ase.calculators.calculator.Calculator.calculate` for more
        information.

        The stress is normalized by the cell volume, so it is only computed for
        systems whose cell has a non-zero volume; explicitly requesting ``"stress"``
        for a zero-volume cell raises
        :py:class:`ase.calculators.calculator.PropertyNotImplementedError`.
        """
        super().calculate(
            atoms=atoms,
            properties=properties,
            system_changes=system_changes,
        )

        if "stresses" in properties:
            raise NotImplementedError("'stresses' are not implemented yet")

        # The stress is divided by the cell volume below; for a zero-volume cell
        # (e.g. a non-periodic molecule) this would silently store inf/NaN values.
        has_volume = atoms.cell.volume > 0.0
        if "stress" in properties and not has_volume:
            raise PropertyNotImplementedError(
                "can not compute the stress of a system with a zero-volume cell"
            )

        # In the next few lines, we decide which properties to calculate among energy,
        # forces and stress. In addition to the requested properties, we calculate the
        # energy if any of the three is requested, as it is an intermediate step in the
        # calculation of the other two. We also calculate the forces if the stress is
        # requested, and vice-versa (the latter only when the stress is defined). The
        # overhead for the latter operation is also small, assuming that the majority
        # of the model computes forces and stresses by backward propagation as opposed
        # to forward-mode differentiation.
        calculate_energy = (
            "energy" in properties
            or "energies" in properties
            or "forces" in properties
            or "stress" in properties
        )
        calculate_energies = "energies" in properties
        calculate_forces = "forces" in properties or "stress" in properties
        calculate_stress = has_volume and (
            "stress" in properties or "forces" in properties
        )

        with record_function("ASECalculator::prepare_inputs"):
            output_requests = _ase_properties_to_metatensor_outputs(properties)
            output_requests.update(self._additional_output_requests)

            capabilities = self._model.capabilities()
            for name in output_requests:
                if name not in capabilities.outputs:
                    raise ValueError(
                        f"you asked for the calculation of {name}, but this model "
                        "does not support it"
                    )

            types, positions, cell, pbc = _ase_to_torch_data(
                atoms=atoms, dtype=self._dtype, device=self._device
            )

            do_backward = False
            if calculate_forces:
                do_backward = True
                positions.requires_grad_(True)

            if calculate_stress:
                do_backward = True

                # the gradient of the energy w.r.t. this identity "strain" gives the
                # virial, which is converted to the stress below
                strain = torch.eye(
                    3, requires_grad=True, device=self._device, dtype=self._dtype
                )

                positions = positions @ strain
                # `positions` is no longer a leaf tensor, keep its gradient for forces
                positions.retain_grad()

                cell = cell @ strain

            run_options = ModelEvaluationOptions(
                length_unit="angstrom",
                outputs=output_requests,
                selected_atoms=None,
            )

        with record_function("ASECalculator::compute_neighbors"):
            # convert from ase.Atoms to metatensor.torch.atomistic.System
            system = System(types, positions, cell, pbc)
            self._add_neighbor_lists(atoms, system)

        # no `record_function` here, this will be handled by MetatensorAtomisticModel
        model_outputs = self._model(
            [system],
            run_options,
            check_consistency=self.parameters["check_consistency"],
        )
        energy = model_outputs["energy"]

        with record_function("ASECalculator::sum_energies"):
            if run_options.outputs["energy"].per_atom:
                assert len(energy) == 1
                assert energy.sample_names == ["system", "atom"]

                samples = energy.block().samples
                assert torch.all(samples["system"] == 0)

                # build the reference on the same device as the samples: comparing
                # a CUDA tensor with a (non-scalar) CPU tensor raises an error
                atom_index = samples["atom"]
                assert torch.all(
                    atom_index
                    == torch.arange(positions.shape[0], device=atom_index.device)
                )

                energies = energy.block().values
                assert energies.shape == (len(atoms), 1)

                energy = sum_over_samples(energy, sample_names=["atom"])

            assert len(energy.block().gradients_list()) == 0
            energy = energy.block().values
            assert energy.shape == (1, 1)

        with record_function("ASECalculator::run_backward"):
            if do_backward:
                energy.backward()

        with record_function("ASECalculator::convert_outputs"):
            self.results = {}

            if calculate_energies:
                energies_values = energies.detach().reshape(-1)
                energies_values = energies_values.to(device="cpu").to(
                    dtype=torch.float64
                )
                self.results["energies"] = energies_values.numpy()

            if calculate_energy:
                energy_values = energy.detach()
                energy_values = energy_values.to(device="cpu").to(dtype=torch.float64)
                self.results["energy"] = energy_values.numpy()[0, 0]

            if calculate_forces:
                forces_values = -system.positions.grad.reshape(-1, 3)
                forces_values = forces_values.to(device="cpu").to(dtype=torch.float64)
                self.results["forces"] = forces_values.numpy()

            if calculate_stress:
                # `has_volume` guarantees a strictly positive volume here
                stress_values = strain.grad.reshape(3, 3) / atoms.cell.volume
                stress_values = stress_values.to(device="cpu").to(dtype=torch.float64)
                self.results["stress"] = _full_3x3_to_voigt_6_stress(
                    stress_values.numpy()
                )

            self.additional_outputs = {}
            for name in self._additional_output_requests:
                self.additional_outputs[name] = model_outputs[name]
def _find_best_device(devices: List[str]) -> torch.device:
    """
    Find the best device from the list of ``devices`` that is available to the
    current PyTorch installation.

    Devices are tried in the given order; devices that are unavailable or unknown
    are skipped (with a log message or warning), and CPU is used as a fallback if
    none of them can be used.
    """

    for device in devices:
        if device == "cpu":
            return torch.device("cpu")
        elif device == "cuda":
            if torch.cuda.is_available():
                return torch.device("cuda")
            else:
                LOGGER.warning(
                    "the model suggested to use CUDA devices before CPU, "
                    "but we are unable to find it"
                )
        elif device == "mps":
            if (
                hasattr(torch.backends, "mps")
                and torch.backends.mps.is_built()
                and torch.backends.mps.is_available()
            ):
                return torch.device("mps")
            else:
                LOGGER.warning(
                    "the model suggested to use MPS devices before CPU, "
                    "but we are unable to find it"
                )
        else:
            warnings.warn(
                f"unknown device in the model's `supported_devices`: '{device}'",
                stacklevel=2,
            )

    warnings.warn(
        "could not find a valid device in the model's `supported_devices`, "
        "falling back to CPU",
        stacklevel=2,
    )
    return torch.device("cpu")


def _ase_properties_to_metatensor_outputs(properties):
    """
    Translate the ASE ``properties`` requested from the calculator into the model
    outputs needed to compute them.

    Every supported property derives from the model's ``"energy"`` output, which is
    requested per-atom when ``"energies"`` or ``"stresses"`` are asked for.

    :raises PropertyNotImplementedError: if one of the ``properties`` can not be
        computed by this calculator
    """
    for p in properties:
        if p not in ("energy", "energies", "forces", "stress", "stresses"):
            raise PropertyNotImplementedError(
                f"property '{p}' is not yet supported by this calculator, "
                "even if it might be supported by the model"
            )

    output = ModelOutput()
    output.quantity = "energy"
    output.unit = "ev"
    output.explicit_gradients = []

    if "energies" in properties or "stresses" in properties:
        output.per_atom = True
    else:
        output.per_atom = False

    if "stresses" in properties:
        output.explicit_gradients = ["cell"]

    return {"energy": output}


def _compute_ase_neighbors(atoms, options, dtype, device):
    """
    Compute the neighbor list described by ``options`` for ``atoms``.

    The result is a :py:class:`TensorBlock` with one sample per pair, described by
    ``first_atom``, ``second_atom`` and the integer cell shifts ``cell_shift_a/b/c``,
    and values of shape ``(n_pairs, 3, 1)`` containing the pair distance vectors.
    Each pair appears once (half list) unless ``options.full_list`` is set, in which
    case the reversed pair (with negated shift and vector) is also included.
    """
    # options.strict is ignored by this function, since `ase.neighborlist.neighbor_list`
    # only computes strict NL, and these are valid even with `strict=False`
    cutoff = options.engine_cutoff(engine_length_unit="angstrom")

    # vesin does not (yet) support mixed periodic/non-periodic boundary conditions
    if np.all(atoms.pbc) or np.all(~atoms.pbc):
        nl_i, nl_j, nl_S, nl_D = vesin.ase_neighbor_list("ijSD", atoms, cutoff=cutoff)
    else:
        nl_i, nl_j, nl_S, nl_D = ase.neighborlist.neighbor_list(
            "ijSD", atoms, cutoff=cutoff
        )

    # The pair selection code here below avoids a relatively slow loop over
    # all pairs to improve performance
    shift_sum = nl_S.sum(axis=1)
    reject_condition = (
        # we want a half neighbor list, so drop all duplicated neighbors
        (nl_j < nl_i)
        | (
            (nl_i == nl_j)
            & (
                # only create pairs with the same atom twice if the pair spans more
                # than one unit cell
                ((nl_S[:, 0] == 0) & (nl_S[:, 1] == 0) & (nl_S[:, 2] == 0))
                |
                # When creating pairs between an atom and one of its periodic images,
                # the code generates multiple redundant pairs
                # (e.g. with shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
                # these. We keep the pair in the positive half plane of shifts.
                (
                    (shift_sum < 0)
                    | (
                        (shift_sum == 0)
                        & (
                            (nl_S[:, 2] < 0)
                            | ((nl_S[:, 2] == 0) & (nl_S[:, 1] < 0))
                        )
                    )
                )
            )
        )
    )
    selected = np.logical_not(reject_condition)
    n_pairs = np.sum(selected)

    if options.full_list:
        distances = np.empty((2 * n_pairs, 3), dtype=np.float64)
        samples = np.empty((2 * n_pairs, 5), dtype=np.int32)
    else:
        distances = np.empty((n_pairs, 3), dtype=np.float64)
        samples = np.empty((n_pairs, 5), dtype=np.int32)

    samples[:n_pairs, 0] = nl_i[selected]
    samples[:n_pairs, 1] = nl_j[selected]
    samples[:n_pairs, 2:] = nl_S[selected]

    distances[:n_pairs] = nl_D[selected]

    if options.full_list:
        # add the reversed pairs j -> i, with opposite shift and distance vector
        samples[n_pairs:, 0] = nl_j[selected]
        samples[n_pairs:, 1] = nl_i[selected]
        samples[n_pairs:, 2:] = -nl_S[selected]

        distances[n_pairs:] = -nl_D[selected]

    distances = torch.from_numpy(distances).to(dtype=dtype).to(device=device)

    return TensorBlock(
        values=distances.reshape(-1, 3, 1),
        samples=Labels(
            names=[
                "first_atom",
                "second_atom",
                "cell_shift_a",
                "cell_shift_b",
                "cell_shift_c",
            ],
            values=torch.from_numpy(samples),
        ).to(device=device),
        components=[Labels.range("xyz", 3).to(device)],
        properties=Labels.range("distance", 1).to(device),
    )


def _ase_to_torch_data(atoms, dtype, device):
    """
    Get the atomic types, positions, cell and pbc from ASE atoms as torch tensors.

    Cell vectors along non-periodic directions are set to zero.
    """
    types = torch.from_numpy(atoms.numbers).to(dtype=torch.int32, device=device)
    positions = torch.from_numpy(atoms.positions).to(dtype=dtype, device=device)
    cell = torch.zeros((3, 3), dtype=dtype, device=device)
    pbc = torch.tensor(atoms.pbc, dtype=torch.bool, device=device)

    cell[pbc] = torch.tensor(atoms.cell[atoms.pbc], dtype=dtype, device=device)

    return types, positions, cell, pbc


def _full_3x3_to_voigt_6_stress(stress):
    """
    Re-implementation of ``ase.stress.full_3x3_to_voigt_6_stress`` which does not do
    the stress symmetrization correctly (they do ``(stress[1, 2] + stress[1, 2]) / 2.0``)

    The output uses Voigt order ``xx, yy, zz, yz, xz, xy``, averaging each pair of
    off-diagonal elements.
    """
    return np.array(
        [
            stress[0, 0],
            stress[1, 1],
            stress[2, 2],
            (stress[1, 2] + stress[2, 1]) / 2.0,
            (stress[0, 2] + stress[2, 0]) / 2.0,
            (stress[0, 1] + stress[1, 0]) / 2.0,
        ]
    )