# Source code for metatensor.torch.atomistic.ase_calculator
import os
import pathlib
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from .. import Labels, TensorBlock
from . import (
MetatensorAtomisticModel,
ModelEvaluationOptions,
ModelOutput,
System,
check_atomistic_model,
register_autograd_neighbors,
)
import ase # isort: skip
import ase.neighborlist # isort: skip
import ase.calculators.calculator # isort: skip
from ase.calculators.calculator import ( # isort: skip
InputError,
PropertyNotImplementedError,
all_properties as ALL_ASE_PROPERTIES,
)
# import here to get an error early if the user is missing metatensor-operations
from .. import sum_over_samples # isort: skip
# Path-like types accepted for the ``model`` argument of MetatensorCalculator
FilePath = Union[str, bytes, pathlib.PurePath]
# [docs]
class MetatensorCalculator(ase.calculators.calculator.Calculator):
    """
    The :py:class:`MetatensorCalculator` class implements ASE's
    :py:class:`ase.calculators.calculator.Calculator` API using metatensor atomistic
    models to compute energy, forces and any other supported property.

    This class can be initialized with any :py:class:`MetatensorAtomisticModel`, and
    used to run simulations using ASE's MD facilities.
    """

    def __init__(
        self,
        model: Union[
            FilePath,
            MetatensorAtomisticModel,
        ],
        check_consistency: bool = False,
    ):
        """
        :param model: model to use for the calculation. This can be a file path, a
            Python instance of :py:class:`MetatensorAtomisticModel`, or the output of
            :py:func:`torch.jit.script` on :py:class:`MetatensorAtomisticModel`.
        :param check_consistency: should we check the model for consistency when
            running, defaults to False.

        :raises InputError: if ``model`` is a path that does not exist, or a
            scripted module whose original class is not
            ``MetatensorAtomisticModel``
        :raises TypeError: if ``model`` is none of the supported types
        """
        super().__init__()

        if isinstance(model, (str, bytes, pathlib.PurePath)):
            if not os.path.exists(model):
                raise InputError(f"given model path '{model}' does not exists")

            # run the model checks on the file before loading it with TorchScript
            check_atomistic_model(model)
            self._model = torch.jit.load(model)
        elif isinstance(model, torch.jit.RecursiveScriptModule):
            if model.original_name != "MetatensorAtomisticModel":
                raise InputError(
                    "torch model must be 'MetatensorAtomisticModel', "
                    f"got '{model.original_name}' instead"
                )
            self._model = model
        elif isinstance(model, MetatensorAtomisticModel):
            self._model = model
        else:
            raise TypeError(f"unknown type for model: {type(model)}")

        # keep the arguments as given by the user (``model`` may still be a path)
        self.parameters = {
            "model": model,
            "check_consistency": check_consistency,
        }

        # We do our own check to verify if a property is implemented in `calculate()`,
        # so we pretend to be able to compute all properties ASE knows about.
        self.implemented_properties = ALL_ASE_PROPERTIES

    def todict(self):
        # used by ASE to save the calculator
        raise NotImplementedError("todict is not yet implemented")

    @classmethod
    def fromdict(cls, dict):
        # used by ASE to load a saved calculator
        raise NotImplementedError("fromdict is not yet implemented")

    def _add_neighbors_lists(self, system: System, atoms: ase.Atoms) -> None:
        """
        Compute every neighbors list requested by the model with ASE and attach it
        to ``system``.

        :param system: system the neighbors lists are registered with and added to
        :param atoms: ASE atoms used to compute the neighbors lists
        """
        for request in self._model.requested_neighbors_lists(length_unit="angstrom"):
            neighbors = _compute_ase_neighbors(atoms, request)
            # registration happens before the list is added to the system, same
            # order for every caller
            register_autograd_neighbors(
                system,
                neighbors,
                check_consistency=self.parameters["check_consistency"],
            )
            system.add_neighbors_list(request, neighbors)

    def run_model(
        self,
        atoms: ase.Atoms,
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorBlock]:
        """
        Run the model on the given ``atoms``, computing properties according to the
        ``outputs`` and ``selected_atoms`` options.

        The output of the model is returned directly, and as such the blocks' ``values``
        will be :py:class:`torch.Tensor`.

        This is intended as an easy way to run metatensor models on
        :py:class:`ase.Atoms` when the model can predict properties not supported by the
        usual ASE's calculator interface.

        :param atoms: system to run the model on
        :param outputs: outputs requested from the model, forwarded unchanged
        :param selected_atoms: forwarded unchanged to the model evaluation options
        :return: whatever the model returns for this single system
        """
        # NOTE(review): ``calculate`` uses the model output like a tensor map
        # (``len()``, ``.block()``); confirm the ``TensorBlock`` return annotation.
        species, positions, cell = _ase_to_torch_data(atoms)
        system = System(species, positions, cell)

        # Compute the neighbors lists requested by the model using ASE NL
        self._add_neighbors_lists(system, atoms)

        options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=outputs,
            selected_atoms=selected_atoms,
        )
        return self._model(
            systems=[system],
            options=options,
            check_consistency=self.parameters["check_consistency"],
        )

    def calculate(
        self,
        atoms: ase.Atoms,
        properties: List[str],
        system_changes: List[str],
    ) -> None:
        """
        Compute some ``properties`` with this calculator, and store them in
        ``self.results`` in the format expected by ASE. Nothing is returned.

        This is not intended to be called directly by users, but to be an implementation
        detail of ``atoms.get_energy()`` and related functions. See
        :py:meth:`ase.calculators.calculator.Calculator.calculate` for more information.

        :param atoms: system to run the calculation on
        :param properties: names of the ASE properties to compute
        :param system_changes: forwarded to the base class ``calculate``
        :raises NotImplementedError: if ``"stresses"`` is requested
        """
        super().calculate(
            atoms=atoms,
            properties=properties,
            system_changes=system_changes,
        )

        outputs = _ase_properties_to_metatensor_outputs(properties)
        species, positions, cell = _ase_to_torch_data(atoms)

        do_backward = False
        if "forces" in properties:
            do_backward = True
            positions.requires_grad_(True)

        if "stress" in properties:
            do_backward = True

            # identity matrix applied to both positions and cell: its gradient is
            # what the stress is computed from below
            scaling = torch.eye(3, requires_grad=True, dtype=cell.dtype)

            positions = positions @ scaling
            # ``positions`` is now the result of an operation, keep its gradient
            # around since the forces are read from it
            positions.retain_grad()

            cell = cell @ scaling

        if "stresses" in properties:
            raise NotImplementedError("'stresses' are not implemented yet")

        # convert from ase.Atoms to metatensor.torch.atomistic.System
        system = System(species, positions, cell)
        self._add_neighbors_lists(system, atoms)

        run_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs=outputs,
            selected_atoms=None,
        )

        model_outputs = self._model(
            [system],
            run_options,
            check_consistency=self.parameters["check_consistency"],
        )
        energy = model_outputs["energy"]

        if run_options.outputs["energy"].per_atom:
            # sanity checks on the per-atom output: one block, one sample per atom
            # of structure 0, in the same order as the input atoms
            assert len(energy) == 1
            assert energy.sample_names == ["structure", "atom"]
            assert torch.all(energy.block().samples["structure"] == 0)
            assert torch.all(
                energy.block().samples["atom"] == torch.arange(positions.shape[0])
            )
            energies = energy.block().values
            assert energies.shape == (len(atoms), 1)

            # the total energy is the sum of the per-atom contributions
            energy = sum_over_samples(energy, sample_names=["atom"])

        # gradients are obtained through backward() below, the output must not
        # carry any explicit gradient
        assert len(energy.block().gradients_list()) == 0
        energy = energy.block().values
        assert energy.shape == (1, 1)

        self.results = {}

        if "energies" in properties:
            # "energies" in properties implies per_atom above, so ``energies`` is set
            self.results["energies"] = (
                energies.detach().to(device="cpu").numpy().reshape(-1)
            )

        if "energy" in properties:
            self.results["energy"] = energy.detach().to(device="cpu").numpy()[0, 0]

        if do_backward:
            # backward pass seeded with -1: every ``.grad`` below holds minus the
            # derivative of the energy
            energy.backward(-torch.ones_like(energy))

        if "forces" in properties:
            self.results["forces"] = (
                system.positions.grad.to(device="cpu").numpy().reshape(-1, 3)
            )

        if "stress" in properties:
            volume = atoms.cell.volume
            # undo the -1 seed to get the derivative of the energy with respect to
            # the scaling matrix, then normalize by the cell volume
            scaling_grad = -scaling.grad.to(device="cpu").numpy().reshape(3, 3)
            self.results["stress"] = scaling_grad / volume
def _ase_properties_to_metatensor_outputs(properties):
    """
    Build the model outputs to request in order to compute the given ASE
    ``properties``.

    All the supported properties are obtained from a single ``"energy"`` output;
    only its ``per_atom`` and ``explicit_gradients`` options depend on what was
    requested.

    :param properties: names of the ASE properties to compute
    :return: dictionary with a single ``"energy"`` entry holding the
        corresponding ``ModelOutput``
    :raises PropertyNotImplementedError: if any of the ``properties`` is not one
        of ``energy``, ``energies``, ``forces``, ``stress`` or ``stresses``
    """
    supported = ("energy", "energies", "forces", "stress", "stresses")
    for name in properties:
        if name not in supported:
            raise PropertyNotImplementedError(
                f"property '{name}' it not yet supported by this calculator, "
                "even if it might be supported by the model"
            )

    output = ModelOutput()
    output.quantity = "energy"
    output.unit = "ev"
    # per-atom values are only requested for the per-atom properties
    output.per_atom = "energies" in properties or "stresses" in properties
    # only "stresses" asks the model for an explicit gradient
    output.explicit_gradients = ["cell"] if "stresses" in properties else []

    return {"energy": output}
def _compute_ase_neighbors(atoms, options):
    """
    Compute the neighbors list described by ``options`` for ``atoms`` using ASE,
    and return it as a ``TensorBlock``.

    The block has one sample per pair (``first_atom``, ``second_atom`` and the three
    ``cell_shift_*`` integers), one ``xyz`` component of size 3, and a single
    ``distance`` property. The values are the vectors from the first to the second
    atom, including the cell shift.

    :param atoms: ASE atoms to compute the neighbors list for
    :param options: neighbors list request; ``engine_cutoff`` and ``full_list``
        are read from it
    """
    cutoff = options.engine_cutoff
    neighbor_list = ase.neighborlist.NeighborList(
        cutoffs=[cutoff] * len(atoms),
        skin=0.0,
        sorted=False,
        self_interaction=False,
        bothways=options.full_list,
        primitive=ase.neighborlist.NewPrimitiveNeighborList,
    )
    neighbor_list.update(atoms)

    cell = torch.from_numpy(atoms.cell[:])
    positions = torch.from_numpy(atoms.positions)
    squared_cutoff = cutoff * cutoff

    pair_labels = []
    pair_vectors = []
    for first in range(len(atoms)):
        neighbor_indices, cell_shifts = neighbor_list.get_neighbors(first)
        for second, shift in zip(neighbor_indices, cell_shifts):
            vector = positions[second] - positions[first] + shift.dot(cell)

            # Pairs returned by ASE can be further apart than the cutoff
            # (presumably ASE compares against the sum of both atoms' cutoffs,
            # TODO confirm), so they are filtered again here.
            if torch.dot(vector, vector).item() > squared_cutoff:
                continue

            pair_labels.append((first, second, shift[0], shift[1], shift[2]))
            pair_vectors.append(vector.to(dtype=torch.float64))

    if pair_vectors:
        samples = torch.tensor(pair_labels, dtype=torch.int32)
        values = torch.vstack(pair_vectors)
    else:
        # no pair at all: build empty tensors with the right number of columns
        values = torch.zeros((0, 3), dtype=positions.dtype)
        samples = torch.zeros((0, 5), dtype=torch.int32)

    sample_labels = Labels(
        names=[
            "first_atom",
            "second_atom",
            "cell_shift_a",
            "cell_shift_b",
            "cell_shift_c",
        ],
        values=samples,
    )

    return TensorBlock(
        values=values.reshape(-1, 3, 1),
        samples=sample_labels,
        components=[Labels.range("xyz", 3)],
        properties=Labels.range("distance", 1),
    )
def _ase_to_torch_data(atoms):
"""Get the positions and cell from ASE atoms as torch tensors"""
species = torch.from_numpy(atoms.numbers).to(dtype=torch.int32)
positions = torch.from_numpy(atoms.positions)
if np.all(atoms.pbc):
cell = torch.from_numpy(atoms.cell[:])
elif np.any(atoms.pbc):
raise ValueError(
f"partial PBC ({atoms.pbc}) are not currently supported in "
"metatensor atomistic models"
)
else:
cell = torch.zeros((3, 3), dtype=torch.float64)
return species, positions, cell