Exporting models

Exporting models to work with any metatensor-compatible simulation engine is done with the MetatensorAtomisticModel class. This class takes an arbitrary torch.nn.Module with a forward function that follows the ModelInterface. In addition to the actual model, you also need to define some information about the model, using ModelMetadata and ModelCapabilities.
- class metatensor.torch.atomistic.ModelInterface[source]
Bases: Module
Interface for models that can be used with MetatensorAtomisticModel.
There are several requirements that models must satisfy to be usable with MetatensorAtomisticModel. The main one concerns the forward() function, which must have the signature defined in this interface.
Additionally, the model can request neighbor lists to be computed by the simulation engine and stored inside the input System. This is done by defining the optional requested_neighbors_lists() method on the model or any of its sub-modules. MetatensorAtomisticModel will check if requested_neighbors_lists is defined for all the sub-modules of the model, then collect and unify identical requests for the simulation engine.
- forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None) → Dict[str, TensorMap] [source]
This function should run the model for the given systems, returning the requested outputs. If selected_atoms is a set of Labels, only the corresponding atoms should be included as "main" atoms in the calculation and the output.
outputs will be a subset of the capabilities that were declared when exporting the model. For example, if a model can compute both an "energy" and a "charge" output, the simulation engine might only request one of them.
The returned dictionary should have the same keys as outputs, and the values should contain the corresponding properties of the systems, as computed for the subset of atoms defined in selected_atoms. For some specific outputs, there are additional constraints on the associated metadata, documented in the Standard model outputs section.
The main use case for selected_atoms is domain decomposition, where the System
given to a model might contain both atoms in the current domain and some atoms from other domains; the calculation should then produce per-atom output only for the atoms in the domain (while still accounting for atoms from the other domains as potential neighbors). A sketch of such a selection is shown after this entry.
- Parameters:
systems (List[System]) – systems on which the model should be run
outputs (Dict[str, ModelOutput]) – outputs requested by the simulation engine
selected_atoms (Labels | None) – if not None, the subset of atoms to include as "main" atoms in the calculation
- Returns:
properties of the systems, as predicted by the machine learning model
- Return type: Dict[str, TensorMap]
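To make the selected_atoms mechanism concrete, here is a minimal sketch of the Labels a domain-decomposed engine might pass; the atom and system indices are hypothetical, and the sample names ["system", "atom"] are assumed to follow the same convention as the Labels in the example further down this page:

>>> import torch
>>> from metatensor.torch import Labels
>>> # hypothetical selection: keep only atoms 0 and 2 of system 0 as "main"
>>> # atoms; other atoms in the System stay visible as potential neighbors
>>> selected_atoms = Labels(
...     names=["system", "atom"],
...     values=torch.tensor([[0, 0], [0, 2]], dtype=torch.int32),
... )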
- requested_neighbors_lists() → List[NeighborsListOptions] [source]
Optional function declaring which neighbor lists this model requires.
This function can be defined on either the root model or any of its sub-modules. A single module can request multiple neighbor lists simultaneously if it needs them. A sketch of a module defining this method is shown below.
It is then the responsibility of the code calling the model to:
- call this function (or more generally MetatensorAtomisticModel.requested_neighbors_lists()) to get the list of requests;
- compute all neighbor lists corresponding to these requests and add them to the systems before calling the model.
- Return type: List[NeighborsListOptions]
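As a minimal sketch, a sub-module requesting a single neighbor list could look like the following. The NeighborsListOptions constructor arguments used here (model_cutoff and full_list) are assumptions and may differ between versions; check the NeighborsListOptions reference for your installed version:

>>> from typing import List
>>> import torch
>>> from metatensor.torch.atomistic import NeighborsListOptions
>>> class ShortRange(torch.nn.Module):
...     """Hypothetical sub-module needing a neighbor list up to a cutoff."""
...
...     def __init__(self, cutoff: float):
...         super().__init__()
...         # `model_cutoff` and `full_list` are assumed constructor
...         # arguments for NeighborsListOptions
...         self._options = NeighborsListOptions(
...             model_cutoff=cutoff, full_list=True
...         )
...
...     def requested_neighbors_lists(self) -> List[NeighborsListOptions]:
...         # the engine computes a matching neighbor list and attaches it
...         # to each System before calling the model's forward()
...         return [self._options]
...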
- class metatensor.torch.atomistic.MetatensorAtomisticModel(module: ModelInterface, metadata: ModelMetadata, capabilities: ModelCapabilities)[source]
MetatensorAtomisticModel is the main entry point for atomistic machine learning based on metatensor. It is the interface between custom, user-defined models and simulation engines. Users should wrap their models with this class, and use export() to save and export the model to a file. The exported models can then be loaded by a simulation engine to compute properties of atomistic systems.
When wrapping a module, you should declare what the model is capable of (using ModelCapabilities). This includes what units the model expects as input and what properties the model can compute (using ModelOutput). The simulation engine will then ask the model to compute some subset of these properties (through a ModelEvaluationOptions), on all or a subset of atoms of an atomistic system.
The wrapped module must follow the interface defined by ModelInterface, should not already be compiled by TorchScript, and should be in "eval" mode (i.e. module.training should be False).
For example, a custom module predicting the energy as a constant value times the number of atoms could look like this:
>>> from typing import Dict, List, Optional
>>> import torch
>>> from metatensor.torch import Labels, TensorBlock, TensorMap
>>> from metatensor.torch.atomistic import ModelOutput, System
>>> class ConstantEnergy(torch.nn.Module):
...     def __init__(self, constant: float):
...         super().__init__()
...         self.constant = torch.tensor(constant).reshape(1, 1)
...
...     def forward(
...         self,
...         systems: List[System],
...         outputs: Dict[str, ModelOutput],
...         selected_atoms: Optional[Labels] = None,
...     ) -> Dict[str, TensorMap]:
...         results: Dict[str, TensorMap] = {}
...         if "energy" in outputs:
...             if outputs["energy"].per_atom:
...                 raise NotImplementedError("per atom energy is not implemented")
...
...             dtype = systems[0].positions.dtype
...             energies = torch.zeros(len(systems), 1, dtype=dtype)
...             for i, system in enumerate(systems):
...                 if selected_atoms is None:
...                     n_atoms = len(system)
...                 else:
...                     n_atoms = len(selected_atoms)
...
...                 energies[i] = self.constant * n_atoms
...
...             systems_idx = torch.tensor([[i] for i in range(len(systems))])
...             energy_block = TensorBlock(
...                 values=energies,
...                 samples=Labels(["system"], systems_idx.to(torch.int32)),
...                 components=torch.jit.annotate(List[Labels], []),
...                 properties=Labels(["energy"], torch.tensor([[0]])),
...             )
...
...             results["energy"] = TensorMap(
...                 keys=Labels(["_"], torch.tensor([[0]])),
...                 blocks=[energy_block],
...             )
...
...         return results
...
Wrapping and exporting this model would then look like this:
>>> import os
>>> import tempfile
>>> from metatensor.torch.atomistic import MetatensorAtomisticModel
>>> from metatensor.torch.atomistic import (
...     ModelCapabilities,
...     ModelOutput,
...     ModelMetadata,
... )
>>> model = ConstantEnergy(constant=3.141592)
>>> # put the model in inference mode
>>> model = model.eval()
>>> # Define the model capabilities
>>> capabilities = ModelCapabilities(
...     outputs={
...         "energy": ModelOutput(
...             quantity="energy",
...             unit="eV",
...             per_atom=False,
...             explicit_gradients=[],
...         ),
...     },
...     atomic_types=[1, 2, 6, 8, 12],
...     interaction_range=0.0,
...     length_unit="angstrom",
...     supported_devices=["cpu"],
...     dtype="float64",
... )
>>> # define metadata about this model
>>> metadata = ModelMetadata(
...     name="model-name",
...     authors=["Some Author", "Another One"],
...     # references and long description can also be added
... )
>>> # wrap the model
>>> wrapped = MetatensorAtomisticModel(model, metadata, capabilities)
>>> # export the model
>>> with tempfile.TemporaryDirectory() as directory:
...     wrapped.export(os.path.join(directory, "constant-energy-model.pt"))
...
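Assuming the exported file is a regular TorchScript archive (which is how the export is described above), it can be loaded back for inspection with torch.jit.load; the file path here is hypothetical, and simulation engines typically do the equivalent of this through their own metatensor interface:

>>> import torch
>>> # sketch: reload the exported TorchScript model from disk
>>> loaded = torch.jit.load("constant-energy-model.pt")  # doctest: +SKIP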
- Parameters:
module (ModelInterface) – The torch module to wrap and export.
metadata (ModelMetadata) – Metadata about this model (name, authors, references, …).
capabilities (ModelCapabilities) – Description of the model capabilities.
- wrapped_module() → Module [source]
Get the module wrapped in this MetatensorAtomisticModel.
- Return type: Module
- capabilities() → ModelCapabilities [source]
Get the capabilities of the wrapped model.
- Return type: ModelCapabilities
- metadata() → ModelMetadata [source]
Get the metadata of the wrapped model.
- Return type: ModelMetadata
- requested_neighbors_lists() → List[NeighborsListOptions] [source]
Get the neighbor lists required by the wrapped model or any of its child modules.
- Return type: List[NeighborsListOptions]
- forward(systems: List[System], options: ModelEvaluationOptions, check_consistency: bool) → Dict[str, TensorMap] [source]
Run the wrapped model and return the corresponding outputs.
Before running the model, this will convert the systems data from the engine units to the model units, including all neighbor list distances. After running the model, this will convert all the outputs from the model units back to the engine units.
- Parameters:
systems (List[System]) – input systems on which we should run the model. The systems should already contain all the neighbor lists corresponding to the options in requested_neighbors_lists().
options (ModelEvaluationOptions) – options for this run of the model
check_consistency (bool) – Should we run additional checks that everything is consistent? This should be set to True when verifying a model, and to False once you are sure everything is running fine.
- Returns:
A dictionary containing all the model outputs
- Return type: Dict[str, TensorMap]
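As a minimal sketch of driving this function by hand: the ModelEvaluationOptions fields shown here (length_unit and outputs) are assumptions based on the capabilities declared in the export example above and may differ between versions; `wrapped` is the MetatensorAtomisticModel from that example, and `system` is assumed to be an existing System instance:

>>> from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput
>>> # hypothetical evaluation request: total energy in eV, lengths in Angstrom
>>> options = ModelEvaluationOptions(
...     length_unit="angstrom",
...     outputs={
...         "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False),
...     },
... )
>>> # run the wrapped model with consistency checks enabled while validating
>>> outputs = wrapped([system], options, check_consistency=True)  # doctest: +SKIP
>>> energy = outputs["energy"].block().values  # doctest: +SKIP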