# Source code for metatensor.torch.atomistic.ase_calculator
import logging
import os
import pathlib
import warnings
from typing import Dict, List, Optional, Union
import numpy as np
import vesin
import torch
from torch.profiler import record_function
from .. import Labels, TensorBlock, TensorMap
from . import (
    MetatensorAtomisticModel,
    ModelEvaluationOptions,
    ModelMetadata,
    ModelOutput,
    System,
    load_atomistic_model,
    register_autograd_neighbors,
)
import ase  # isort: skip
import ase.neighborlist  # isort: skip
import ase.calculators.calculator  # isort: skip
from ase.calculators.calculator import (  # isort: skip
    InputError,
    PropertyNotImplementedError,
    all_properties as ALL_ASE_PROPERTIES,
)
if os.environ.get("METATENSOR_IMPORT_FOR_SPHINX", "0") == "0":
    # this can not be imported when building the documentation
    from .. import sum_over_samples  # isort: skip
# Everything accepted as a path to a saved model file
FilePath = Union[str, bytes, pathlib.PurePath]

LOGGER = logging.getLogger(__name__)

# Translation table from the `dtype` string found in model capabilities to the
# corresponding torch dtype
STR_TO_DTYPE = dict(
    float32=torch.float32,
    float64=torch.float64,
)
class MetatensorCalculator(ase.calculators.calculator.Calculator):
    """
    The :py:class:`MetatensorCalculator` class implements ASE's
    :py:class:`ase.calculators.calculator.Calculator` API using metatensor atomistic
    models to compute energy, forces and any other supported property.

    This class can be initialized with any :py:class:`MetatensorAtomisticModel`, and
    used to run simulations using ASE's MD facilities.

    Neighbor lists are computed using the fast
    `vesin <https://luthaf.fr/vesin/latest/index.html>`_ neighbor list library,
    unless the system has mixed periodic and non-periodic boundary conditions (which
    are not yet supported by ``vesin``), in which case the slower ASE neighbor list
    is used.
    """

    additional_outputs: Dict[str, TensorMap]
    """
    Additional outputs computed by :py:meth:`calculate` are stored in this dictionary.

    The keys will match the keys of the ``additional_outputs`` parameters to the
    constructor; and the values will be the corresponding raw
    :py:class:`metatensor.torch.TensorMap` produced by the model.
    """

    def __init__(
        self,
        model: Union[FilePath, MetatensorAtomisticModel],
        *,
        additional_outputs: Optional[Dict[str, ModelOutput]] = None,
        extensions_directory=None,
        check_consistency=False,
        device=None,
    ):
        """
        :param model: model to use for the calculation. This can be a file path, a
            Python instance of :py:class:`MetatensorAtomisticModel`, or the output of
            :py:func:`torch.jit.script` on :py:class:`MetatensorAtomisticModel`.
        :param additional_outputs: Dictionary of additional outputs to be computed by
            the model. These outputs will always be computed whenever the
            :py:meth:`calculate` function is called (e.g. by
            :py:meth:`ase.Atoms.get_potential_energy`,
            :py:meth:`ase.optimize.optimize.Dynamics.run`, *etc.*) and stored in the
            :py:attr:`additional_outputs` attribute. If you want more control over when
            and how to compute specific outputs, you should use :py:meth:`run_model`
            instead.
        :param extensions_directory: if the model uses extensions, we will try to load
            them from this directory
        :param check_consistency: should we check the model for consistency when
            running, defaults to False.
        :param device: torch device to use for the calculation. If ``None``, we will try
            the options in the model's ``supported_device`` in order.

        :raises InputError: if ``model`` is a path that does not exist, or a
            TorchScript module that is not a ``MetatensorAtomisticModel``
        :raises TypeError: if ``model`` has an unknown type, or if
            ``additional_outputs`` is not a dictionary of ``str`` to ``ModelOutput``
        :raises ValueError: if the requested ``device`` is not supported by the model,
            or if the model declares an unexpected ``dtype``
        """
        super().__init__()

        self.parameters = {
            "check_consistency": check_consistency,
        }

        # Load the model
        if isinstance(model, (str, bytes, pathlib.PurePath)):
            if not os.path.exists(model):
                raise InputError(f"given model path '{model}' does not exist")

            # `os.fsdecode` turns str/bytes/path-like into a plain `str`; using
            # `str()` here would store `"b'...'"` for bytes paths and break the
            # `todict`/`fromdict` round-trip
            self.parameters["model_path"] = os.fsdecode(model)

            model = load_atomistic_model(
                model, extensions_directory=extensions_directory
            )

        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "MetatensorAtomisticModel":
                raise InputError(
                    "torch model must be 'MetatensorAtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
        elif isinstance(model, MetatensorAtomisticModel):
            # nothing to do
            pass
        else:
            raise TypeError(f"unknown type for model: {type(model)}")

        # store the device as requested by the user (before resolving it), so that
        # `fromdict` reproduces the same choice
        self.parameters["device"] = str(device) if device is not None else None

        # check if the model supports the requested device
        capabilities = model.capabilities()
        if device is None:
            device = _find_best_device(capabilities.supported_devices)
        else:
            device = torch.device(device)
            device_is_supported = False

            for supported in capabilities.supported_devices:
                try:
                    supported = torch.device(supported)
                except RuntimeError as e:
                    warnings.warn(
                        "the model contains an invalid device in `supported_devices`: "
                        f"{e}",
                        stacklevel=2,
                    )
                    continue

                if supported.type == device.type:
                    device_is_supported = True
                    break

            if not device_is_supported:
                raise ValueError(
                    f"This model does not support the requested device ({device}), "
                    "the following devices are supported: "
                    f"{capabilities.supported_devices}"
                )

        if capabilities.dtype in STR_TO_DTYPE:
            self._dtype = STR_TO_DTYPE[capabilities.dtype]
        else:
            raise ValueError(
                f"found unexpected dtype in model capabilities: {capabilities.dtype}"
            )

        # explicit checks instead of `assert`, which is removed under `python -O`
        if additional_outputs is None:
            self._additional_output_requests = {}
        else:
            if not isinstance(additional_outputs, dict):
                raise TypeError(
                    "`additional_outputs` must be a dictionary, "
                    f"got {type(additional_outputs)}"
                )
            for name, output in additional_outputs.items():
                if not isinstance(name, str):
                    raise TypeError(
                        f"keys of `additional_outputs` must be str, got {type(name)}"
                    )
                # ModelOutput is a TorchScript class, recognized by its methods
                if not (
                    isinstance(output, torch.ScriptObject)
                    and "explicit_gradients_setter" in output._method_names()
                ):
                    raise TypeError("outputs must be ModelOutput instances")

            # copy so later changes to the caller's dict do not affect us
            self._additional_output_requests = dict(additional_outputs)

        self._device = device
        self._model = model.to(device=self._device)

        # We do our own check to verify if a property is implemented in `calculate()`,
        # so we pretend to be able to compute all properties ASE knows about.
        self.implemented_properties = ALL_ASE_PROPERTIES

    def todict(self):
        """
        Get the parameters needed to re-create this calculator with
        :py:meth:`fromdict`.

        :raises RuntimeError: if the calculator was not created from a model file
        """
        if "model_path" not in self.parameters:
            raise RuntimeError(
                "can not save metatensor model in ASE `todict`, please initialize "
                "`MetatensorCalculator` with a path to a saved model file if you need "
                "to use `todict`"
            )

        # return a copy, callers must not be able to mutate our parameters
        return dict(self.parameters)

    @classmethod
    def fromdict(cls, data):
        """
        Re-create a calculator from the output of :py:meth:`todict`.

        ``check_consistency`` and ``device`` are optional in ``data`` and fall back
        to the constructor defaults.
        """
        return cls(
            model=data["model_path"],
            check_consistency=data.get("check_consistency", False),
            device=data.get("device"),
        )

    def metadata(self) -> ModelMetadata:
        """Get the metadata of the underlying model"""
        return self._model.metadata()

    def _add_neighbor_lists(self, system: System, atoms: ase.Atoms) -> None:
        """
        Compute every neighbor list requested by the model for ``atoms``, register
        it for autograd and attach it to ``system``.
        """
        check_consistency = self.parameters["check_consistency"]
        for options in self._model.requested_neighbor_lists():
            neighbors = _compute_ase_neighbors(
                atoms, options, dtype=self._dtype, device=self._device
            )
            register_autograd_neighbors(
                system,
                neighbors,
                check_consistency=check_consistency,
            )
            system.add_neighbor_list(options, neighbors)

    def run_model(
        self,
        atoms: ase.Atoms,
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """
        Run the model on the given ``atoms``, computing the requested ``outputs`` and
        only these.

        The output of the model is returned directly, and as such the blocks' ``values``
        will be :py:class:`torch.Tensor`.

        This is intended as an easy way to run metatensor models on
        :py:class:`ase.Atoms` when the model can compute outputs not supported by the
        standard ASE's calculator interface.

        All the parameters have the same meaning as the corresponding ones in
        :py:meth:`metatensor.torch.atomistic.ModelInterface.forward`.

        :param atoms: system on which to run the model
        :param outputs: outputs of the model that should be predicted
        :param selected_atoms: subset of atoms on which to run the calculation
        """
        types, positions, cell, pbc = _ase_to_torch_data(
            atoms=atoms, dtype=self._dtype, device=self._device
        )
        system = System(types, positions, cell, pbc)

        # Compute the neighbors lists requested by the model
        self._add_neighbor_lists(system, atoms)

        options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=outputs,
            selected_atoms=selected_atoms,
        )
        return self._model(
            systems=[system],
            options=options,
            check_consistency=self.parameters["check_consistency"],
        )

    def calculate(
        self,
        atoms: ase.Atoms,
        properties: List[str],
        system_changes: List[str],
    ) -> None:
        """
        Compute some ``properties`` with this calculator, and return them in the format
        expected by ASE.

        This is not intended to be called directly by users, but to be an implementation
        detail of ``atoms.get_energy()`` and related functions. See
        :py:meth:`ase.calculators.calculator.Calculator.calculate` for more information.

        Results are stored in ``self.results``; the raw outputs requested through
        ``additional_outputs`` in the constructor are stored in
        :py:attr:`additional_outputs`.
        """
        super().calculate(
            atoms=atoms,
            properties=properties,
            system_changes=system_changes,
        )

        # In the next few lines, we decide which properties to calculate among energy,
        # forces and stress. In addition to the requested properties, we calculate the
        # energy if any of the three is requested, as it is an intermediate step in the
        # calculation of the other two. We also calculate the forces if the stress is
        # requested, and vice-versa (the latter only when the cell has a non-zero
        # volume). The overhead for the latter operation is also small, assuming that
        # the majority of the model computes forces and stresses by backward
        # propagation as opposed to forward-mode differentiation.
        calculate_energy = (
            "energy" in properties
            or "energies" in properties
            or "forces" in properties
            or "stress" in properties
        )
        calculate_energies = "energies" in properties
        calculate_forces = "forces" in properties or "stress" in properties
        # Without this volume check, a forces-only request on a system without a
        # cell (zero volume) would divide by zero and cache an inf/NaN stress.
        calculate_stress = "stress" in properties or (
            "forces" in properties and atoms.cell.volume > 0
        )

        if "stresses" in properties:
            raise NotImplementedError("'stresses' are not implemented yet")

        with record_function("ASECalculator::prepare_inputs"):
            requested_outputs = _ase_properties_to_metatensor_outputs(properties)
            requested_outputs.update(self._additional_output_requests)

            capabilities = self._model.capabilities()
            for name in requested_outputs:
                if name not in capabilities.outputs:
                    raise ValueError(
                        f"you asked for the calculation of {name}, but this model "
                        "does not support it"
                    )

            types, positions, cell, pbc = _ase_to_torch_data(
                atoms=atoms, dtype=self._dtype, device=self._device
            )

            do_backward = False
            if calculate_forces:
                do_backward = True
                positions.requires_grad_(True)

            if calculate_stress:
                do_backward = True

                # differentiate the energy w.r.t. a homogeneous strain applied to
                # both positions and cell; `retain_grad` keeps the gradient of the
                # (now non-leaf) strained positions, used for the forces
                strain = torch.eye(
                    3, requires_grad=True, device=self._device, dtype=self._dtype
                )

                positions = positions @ strain
                positions.retain_grad()

                cell = cell @ strain

            run_options = ModelEvaluationOptions(
                length_unit="angstrom",
                outputs=requested_outputs,
                selected_atoms=None,
            )

        with record_function("ASECalculator::compute_neighbors"):
            # convert from ase.Atoms to metatensor.torch.atomistic.System
            system = System(types, positions, cell, pbc)
            self._add_neighbor_lists(system, atoms)

        # no `record_function` here, this will be handled by MetatensorAtomisticModel
        model_outputs = self._model(
            [system],
            run_options,
            check_consistency=self.parameters["check_consistency"],
        )
        energy = model_outputs["energy"]

        with record_function("ASECalculator::sum_energies"):
            if run_options.outputs["energy"].per_atom:
                assert len(energy) == 1
                assert energy.sample_names == ["system", "atom"]
                assert torch.all(energy.block().samples["system"] == 0)
                # the reference indices must live on the same device as the samples,
                # otherwise the comparison fails for GPU models
                assert torch.all(
                    energy.block().samples["atom"]
                    == torch.arange(positions.shape[0], device=self._device)
                )
                energies = energy.block().values
                assert energies.shape == (len(atoms), 1)

                energy = sum_over_samples(energy, sample_names=["atom"])

            assert len(energy.block().gradients_list()) == 0
            energy = energy.block().values
            assert energy.shape == (1, 1)

        with record_function("ASECalculator::run_backward"):
            if do_backward:
                energy.backward()

        with record_function("ASECalculator::convert_outputs"):
            self.results = {}

            if calculate_energies:
                energies_values = energies.detach().reshape(-1)
                energies_values = energies_values.to(device="cpu").to(
                    dtype=torch.float64
                )
                self.results["energies"] = energies_values.numpy()

            if calculate_energy:
                energy_values = energy.detach()
                energy_values = energy_values.to(device="cpu").to(dtype=torch.float64)
                self.results["energy"] = energy_values.numpy()[0, 0]

            if calculate_forces:
                forces_values = -system.positions.grad.reshape(-1, 3)
                forces_values = forces_values.to(device="cpu").to(dtype=torch.float64)
                self.results["forces"] = forces_values.numpy()

            if calculate_stress:
                stress_values = strain.grad.reshape(3, 3) / atoms.cell.volume
                stress_values = stress_values.to(device="cpu").to(dtype=torch.float64)
                self.results["stress"] = _full_3x3_to_voigt_6_stress(
                    stress_values.numpy()
                )

            self.additional_outputs = {
                name: model_outputs[name] for name in self._additional_output_requests
            }
def _find_best_device(devices: List[str]) -> torch.device:
    """
    Find the best device from the list of ``devices`` that is available to the current
    PyTorch installation.
    """
    for device in devices:
        if device == "cpu":
            return torch.device("cpu")
        elif device == "cuda":
            if torch.cuda.is_available():
                return torch.device("cuda")
            else:
                LOGGER.warning(
                    "the model suggested to use CUDA devices before CPU, "
                    "but we are unable to find it"
                )
        elif device == "mps":
            if (
                hasattr(torch.backends, "mps")
                and torch.backends.mps.is_built()
                and torch.backends.mps.is_available()
            ):
                return torch.device("mps")
            else:
                LOGGER.warning(
                    "the model suggested to use MPS devices before CPU, "
                    "but we are unable to find it"
                )
        else:
            warnings.warn(
                f"unknown device in the model's `supported_devices`: '{device}'",
                stacklevel=2,
            )
    warnings.warn(
        "could not find a valid device in the model's `supported_devices`, "
        "falling back to CPU",
        stacklevel=2,
    )
    return torch.device("cpu")
def _ase_properties_to_metatensor_outputs(properties):
    """
    Translate ASE property names into the metatensor outputs the model must compute.

    Every supported ASE property (``energy``, ``energies``, ``forces``, ``stress``,
    ``stresses``) is derived from the model's ``"energy"`` output, so the returned
    dictionary always contains a single ``"energy"`` entry.

    :param properties: ASE property names requested by the caller
    :return: dictionary mapping ``"energy"`` to the corresponding
        :py:class:`ModelOutput`
    :raises PropertyNotImplementedError: if any requested property is not supported
        by this calculator
    """
    for p in properties:
        if p not in ("energy", "energies", "forces", "stress", "stresses"):
            raise PropertyNotImplementedError(
                f"property '{p}' is not yet supported by this calculator, "
                "even if it might be supported by the model"
            )

    output = ModelOutput()
    output.quantity = "energy"
    output.unit = "ev"
    # per-atom outputs are needed for per-atom energies and per-atom stresses
    output.per_atom = "energies" in properties or "stresses" in properties
    # for per-atom stresses, request explicit gradients with respect to the cell
    output.explicit_gradients = ["cell"] if "stresses" in properties else []

    return {"energy": output}
def _compute_ase_neighbors(atoms, options, dtype, device):
    """
    Compute the neighbor list described by ``options`` for ``atoms`` and return it
    as a metatensor :py:class:`TensorBlock`.

    Raw pairs come from ``vesin`` when the boundary conditions are either fully
    periodic or fully non-periodic, and from ``ase.neighborlist.neighbor_list``
    otherwise (mixed boundary conditions, which ``vesin`` does not support yet).
    The raw pairs are first reduced to a half list, where each pair is kept once.
    If ``options.full_list`` is true, the mirrored pairs (``j, i``, with negated
    cell shift and pair vector) are then appended after the half list.

    :param atoms: :py:class:`ase.Atoms` to compute neighbors for
    :param options: neighbor list options requested by the model. The cutoff comes
        from ``options.engine_cutoff(engine_length_unit="angstrom")``, and
        ``options.full_list`` selects a half or full list.
    :param dtype: torch dtype of the returned pair vectors
    :param device: torch device of the returned block
    :return: block with samples ``first_atom, second_atom, cell_shift_a,
        cell_shift_b, cell_shift_c``, one ``xyz`` component of size 3 and a single
        ``distance`` property; values have shape ``(n_pairs, 3, 1)``
    """
    # options.strict is ignored by this function, since `ase.neighborlist.neighbor_list`
    # only computes strict NL, and these are valid even with `strict=False`
    if np.all(atoms.pbc) or np.all(~atoms.pbc):
        # "ijSD": first atom, second atom, integer cell shift, pair vector
        nl_i, nl_j, nl_S, nl_D = vesin.ase_neighbor_list(
            "ijSD",
            atoms,
            cutoff=options.engine_cutoff(engine_length_unit="angstrom"),
        )
    else:
        # mixed periodic/non-periodic boundaries: fall back to the ASE implementation
        nl_i, nl_j, nl_S, nl_D = ase.neighborlist.neighbor_list(
            "ijSD",
            atoms,
            cutoff=options.engine_cutoff(engine_length_unit="angstrom"),
        )

    # The pair selection code here below avoids a relatively slow loop over
    # all pairs to improve performance
    reject_condition = (
        # we want a half neighbor list, so drop all duplicated neighbors
        (nl_j < nl_i)
        | (
            (nl_i == nl_j)
            & (
                # only create pairs with the same atom twice if the pair spans more
                # than one unit cell
                ((nl_S[:, 0] == 0) & (nl_S[:, 1] == 0) & (nl_S[:, 2] == 0))
                |
                # When creating pairs between an atom and one of its periodic images,
                # the code generates multiple redundant pairs
                # (e.g. with shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
                # these. We keep the pair in the positive half plane of shifts.
                # Tie-break: reject a negative shift sum; for a zero sum, reject a
                # negative c component, then (c == 0) a negative b component, so that
                # exactly one of S / -S survives.
                (
                    (nl_S.sum(axis=1) < 0)
                    | (
                        (nl_S.sum(axis=1) == 0)
                        & ((nl_S[:, 2] < 0) | ((nl_S[:, 2] == 0) & (nl_S[:, 1] < 0)))
                    )
                )
            )
        )
    )
    selected = np.logical_not(reject_condition)
    n_pairs = np.sum(selected)

    # a full list stores every half-list pair twice (once in each direction)
    if options.full_list:
        distances = np.empty((2 * n_pairs, 3), dtype=np.float64)
        samples = np.empty((2 * n_pairs, 5), dtype=np.int32)
    else:
        distances = np.empty((n_pairs, 3), dtype=np.float64)
        samples = np.empty((n_pairs, 5), dtype=np.int32)

    # first n_pairs rows: the half list as selected above
    samples[:n_pairs, 0] = nl_i[selected]
    samples[:n_pairs, 1] = nl_j[selected]
    samples[:n_pairs, 2:] = nl_S[selected]
    distances[:n_pairs] = nl_D[selected]

    if options.full_list:
        # remaining rows: the same pairs seen from the other atom, so atom indices
        # are swapped and both the cell shift and the pair vector change sign
        samples[n_pairs:, 0] = nl_j[selected]
        samples[n_pairs:, 1] = nl_i[selected]
        samples[n_pairs:, 2:] = -nl_S[selected]
        distances[n_pairs:] = -nl_D[selected]

    distances = torch.from_numpy(distances).to(dtype=dtype).to(device=device)

    return TensorBlock(
        values=distances.reshape(-1, 3, 1),
        samples=Labels(
            names=[
                "first_atom",
                "second_atom",
                "cell_shift_a",
                "cell_shift_b",
                "cell_shift_c",
            ],
            values=torch.from_numpy(samples),
        ).to(device=device),
        components=[Labels.range("xyz", 3).to(device)],
        properties=Labels.range("distance", 1).to(device),
    )
def _ase_to_torch_data(atoms, dtype, device):
    """Get the positions, cell and pbc from ASE atoms as torch tensors"""
    types = torch.from_numpy(atoms.numbers).to(dtype=torch.int32, device=device)
    positions = torch.from_numpy(atoms.positions).to(dtype=dtype, device=device)
    cell = torch.zeros((3, 3), dtype=dtype, device=device)
    pbc = torch.tensor(atoms.pbc, dtype=torch.bool, device=device)
    cell[pbc] = torch.tensor(atoms.cell[atoms.pbc], dtype=dtype, device=device)
    return types, positions, cell, pbc
def _full_3x3_to_voigt_6_stress(stress):
    """
    Re-implementation of ``ase.stress.full_3x3_to_voigt_6_stress`` which does not do the
    stress symmetrization correctly (they do ``(stress[1, 2] + stress[1, 2]) / 2.0``)
    """
    return np.array(
        [
            stress[0, 0],
            stress[1, 1],
            stress[2, 2],
            (stress[1, 2] + stress[2, 1]) / 2.0,
            (stress[0, 2] + stress[2, 0]) / 2.0,
            (stress[0, 1] + stress[1, 0]) / 2.0,
        ]
    )