Abstract base classes

These are classes that define the interfaces that should be followed by metatrain components (models, trainers, etc.).

class metatrain.utils.abc.ModelInterface(hypers: HypersType, dataset_info: DatasetInfo, metadata: ModelMetadata)[source]

Bases: Module, Generic[HypersType]

Abstract base class for a machine learning model in metatrain.

All architectures in metatrain must implement their model as a sub-class of this class, and implement the corresponding methods.

Parameters:
  • hypers (HypersType) – A dictionary with the model’s hyper-parameters.

  • dataset_info (DatasetInfo) – Information containing details about the dataset, such as target quantities and atomic types.

  • metadata (ModelMetadata) – Metadata about the model, e.g. author, description, and references.

__checkpoint_version__: int

The current version of the model’s checkpoint.

This is used to upgrade checkpoints produced with earlier versions of the code. See Checkpoint versioning for more information.

__supported_devices__: List[device]

List of torch devices supported by this model architecture.

They should be sorted in order of preference, since metatrain will use this together with __supported_dtypes__ to determine the optimal dtype and device for training, based on the user's request and the machine's capabilities.

__supported_dtypes__: List[dtype]

List of torch dtypes supported by this model architecture.

They should be sorted in order of preference, since metatrain will use this together with __supported_devices__ to determine the optimal dtype and device for training, based on the user's request and the machine's capabilities.
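The preference-ordered selection described above can be sketched as follows. This is an illustrative example, not metatrain's actual selection code, and it uses plain strings in place of torch.device objects for brevity:

```python
# Illustrative sketch (not metatrain's actual selection logic): pick the
# first device the architecture prefers that is also available on the
# machine. Devices are represented as strings for simplicity.
def pick_device(supported: list, available: list) -> str:
    """Return the most-preferred supported device that is available."""
    for device in supported:  # ordered by architecture preference
        if device in available:
            return device
    raise ValueError(f"none of {supported} is available on this machine")

# An architecture preferring GPUs, running on a CPU-only machine,
# falls back to the CPU:
print(pick_device(["cuda", "mps", "cpu"], ["cpu"]))  # -> cpu
```

The same intersection logic applies to dtypes, which is why both lists must be sorted by preference.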

__default_metadata__: ModelMetadata

Default metadata for this model architecture.

Can be used to provide references that will be stored in the exported model. The references are stored in a dictionary with keys implementation and architecture. The implementation key should contain references to the software used in the implementation of the architecture, while the architecture key should contain references about the general architecture.

hypers

The model hypers passed at initialization

dataset_info

The dataset info passed at initialization

metadata

The metadata passed at initialization

abstractmethod forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) Dict[str, TensorMap][source]

Execute the model for the given systems, computing the requested outputs.

Parameters:
  • systems (List[System]) – List of systems to evaluate the model on.

  • outputs (Dict[str, ModelOutput]) – Dictionary of outputs that the model should compute.

  • selected_atoms (Labels | None) – Optional Labels specifying a subset of atoms to compute the outputs for. If None, the outputs are computed for all atoms in each system.

Returns:

A dictionary mapping each requested output name to the corresponding TensorMap containing the computed values.

Return type:

Dict[str, TensorMap]

See also

metatomic.torch.ModelInterface for more explanation about the different arguments.

abstractmethod supported_outputs() Dict[str, ModelOutput][source]

Get the outputs currently supported by this model.

These will typically be the same outputs that are set as this model's capabilities in ModelInterface.export().

Returns:

A dictionary of the supported outputs by this model.

Return type:

Dict[str, ModelOutput]

abstractmethod restart(dataset_info: DatasetInfo) ModelInterface[source]

Update a model to restart training, potentially with different dataset and/or targets.

This function is called whenever training restarts, with the same or a different dataset. It enables both transfer learning (changing the targets) and fine-tuning (same targets, different datasets).

Parameters:

dataset_info (DatasetInfo) – Information about the new dataset, including the targets that will be used for training.

Returns:

The updated model, or a new instance of the model, that is able to handle the new dataset.

Return type:

ModelInterface
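The restart() pattern of preserving existing state while accommodating new targets can be sketched with a toy class. The class and its dict-of-heads representation are hypothetical, chosen only to illustrate the contract:

```python
# Hypothetical sketch of the restart() pattern: keep the existing output
# heads (and their trained weights) and add fresh heads for any new
# targets found in the updated dataset info.
class ToyModel:
    def __init__(self, targets):
        # one "head" (represented here by a label) per predictable target
        self.heads = {name: f"head-for-{name}" for name in targets}

    def restart(self, new_targets):
        for name in new_targets:
            if name not in self.heads:  # new target: add a fresh head
                self.heads[name] = f"head-for-{name}"
        return self  # existing heads are preserved for fine-tuning

model = ToyModel(["energy"])
model.restart(["energy", "forces"])  # transfer learning: one new target
print(sorted(model.heads))  # -> ['energy', 'forces']
```

A real implementation would also update any internal bookkeeping derived from the dataset info, such as the set of known atomic types.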

abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) ModelInterface[source]

Create a model from a checkpoint (i.e. state dictionary).

Parameters:
  • checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

  • context (Literal['restart', 'finetune', 'export']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, "finetune" when loading a model for further fine-tuning or transfer learning, and "export" when loading a model for final export. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.

Returns:

An instance of the model.

Return type:

ModelInterface
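Picking between multiple stored checkpoints based on the context can be sketched as below. The "last_state"/"best_state" keys are hypothetical, not metatrain's actual checkpoint layout:

```python
# Illustrative sketch: a checkpoint bundling two model states, with the
# loading context deciding which one to use. The key names here are
# hypothetical, not metatrain's actual checkpoint format.
def select_state(checkpoint: dict, context: str) -> dict:
    if context == "restart":
        # resume training exactly where it stopped
        return checkpoint["last_state"]
    elif context in ("finetune", "export"):
        # start from (or ship) the best-validation-loss model
        return checkpoint["best_state"]
    raise ValueError(f"unknown context: {context!r}")

ckpt = {"last_state": {"step": 900}, "best_state": {"step": 750}}
print(select_state(ckpt, "export"))  # -> {'step': 750}
```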

abstractmethod export(metadata: ModelMetadata | None = None) AtomisticModel[source]

Turn this model into an instance of metatomic.torch.AtomisticModel, containing the model itself, a definition of the model capabilities, and some metadata about the model.

Parameters:

metadata (ModelMetadata | None) – Additional metadata to add to the model, as specified by the user.

Returns:

An instance of metatomic.torch.AtomisticModel.

Return type:

AtomisticModel

abstractmethod classmethod upgrade_checkpoint(checkpoint: Dict[str, Any]) Dict[str, Any][source]

Upgrade the checkpoint to the current version of the model.

Parameters:

checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

Raises:

RuntimeError – if the checkpoint cannot be upgraded to the current version of the model.

Returns:

The upgraded checkpoint.

Return type:

Dict[str, Any]
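The usual way to implement upgrade_checkpoint is a chain of single-step migrations applied from the checkpoint's version up to __checkpoint_version__. The sketch below is illustrative; the key names and migration steps are hypothetical:

```python
# Illustrative sketch of version-by-version checkpoint migration; the
# key names and individual migration steps are hypothetical.
CURRENT_VERSION = 3

def _v1_to_v2(ckpt):
    ckpt["model_state"] = ckpt.pop("state")  # key renamed in v2
    return ckpt

def _v2_to_v3(ckpt):
    ckpt.setdefault("metadata", {})  # metadata block introduced in v3
    return ckpt

MIGRATIONS = {1: _v1_to_v2, 2: _v2_to_v3}

def upgrade_checkpoint(ckpt: dict) -> dict:
    version = ckpt.get("version", 1)
    if version > CURRENT_VERSION:
        raise RuntimeError(f"cannot handle checkpoint v{version}")
    while version < CURRENT_VERSION:
        ckpt = MIGRATIONS[version](ckpt)
        version += 1
    ckpt["version"] = version
    return ckpt

old = {"version": 1, "state": {"w": [0.1]}}
print(upgrade_checkpoint(old)["version"])  # -> 3
```

Applying one migration per version step keeps each upgrade small and testable, and a checkpoint from any past version can always be brought up to date.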

abstractmethod get_checkpoint() Dict[str, Any][source]

Get the checkpoint of the model. This should contain all the information needed by load_checkpoint to recreate the same model instance.

Returns:

The model’s checkpoint.

Return type:

Dict[str, Any]

class metatrain.utils.abc.TrainerInterface(hypers: HypersType)[source]

Bases: Generic[HypersType]

Abstract base class for a model trainer in metatrain.

All architectures in metatrain must implement such a trainer, which is responsible for training the model. The trainer must be a sub-class of this class, and implement the corresponding methods.

Parameters:

hypers (HypersType) – A dictionary with the trainer’s hyper-parameters.

__checkpoint_version__: int

The current version of the trainer’s checkpoint.

This is used to upgrade checkpoints produced with earlier versions of the code. See Checkpoint versioning for more information.

hypers

The trainer hypers passed at initialization

abstractmethod train(model: ModelInterface, dtype: dtype, devices: List[device], train_datasets: List[Dataset | Subset], val_datasets: List[Dataset | Subset], checkpoint_dir: str) None[source]

Train the model using the train_datasets. How to train the model is left to this class, using the hyper-parameters given to __init__.

Parameters:
  • model (ModelInterface) – the model to train

  • dtype (dtype) – torch.dtype used by the data in the datasets

  • devices (List[device]) – torch.device to use for training the model. When training with more than one device (e.g. multi-GPU training), this can contain multiple devices.

  • train_datasets (List[Dataset | Subset]) – datasets to use to train the model

  • val_datasets (List[Dataset | Subset]) – datasets to use for model validation

  • checkpoint_dir (str) – directory where checkpoints should be saved

Return type:

None

abstractmethod save_checkpoint(model: ModelInterface, path: str | Path) None[source]

Save a checkpoint of both the model and trainer state to the given path.

Parameters:
  • model (ModelInterface) – The model to save in the checkpoint.

  • path (str | Path) – The path where to save the checkpoint.

Return type:

None

abstractmethod classmethod upgrade_checkpoint(checkpoint: Dict) Dict[source]

Upgrade the checkpoint to the current version of the trainer.

Parameters:

checkpoint (Dict) – Checkpoint’s state dictionary.

Raises:

RuntimeError – if the checkpoint cannot be upgraded to the current version of the trainer.

Returns:

The upgraded checkpoint.

Return type:

Dict

abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], hypers: HypersType, context: Literal['restart', 'finetune']) TrainerInterface[source]

Create a trainer instance from data stored in the checkpoint.

Parameters:
  • checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.

  • hypers (HypersType) – Hyper-parameters for the trainer, as specified by the user.

  • context (Literal['restart', 'finetune']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, and "finetune" when loading a model for further fine-tuning or transfer learning. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.

Returns:

The loaded trainer instance.

Return type:

TrainerInterface
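The save_checkpoint / load_checkpoint round trip can be sketched with plain dictionaries. This is a simplified stand-in (a real implementation would use torch.save and torch.load on the model and trainer state dicts); the file layout and key names are hypothetical:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

# Simplified sketch of the checkpoint round trip: bundle model and
# trainer state into one file, so training can resume from either side.
# Real code would use torch.save/torch.load; JSON keeps this runnable.
def save_checkpoint(model_state: dict, trainer_state: dict, path: Path) -> None:
    path.write_text(json.dumps({"model": model_state, "trainer": trainer_state}))

def load_checkpoint(path: Path) -> dict:
    return json.loads(path.read_text())

with TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.ckpt"
    save_checkpoint({"weights": [0.5]}, {"epoch": 12}, path)
    ckpt = load_checkpoint(path)
    print(ckpt["trainer"]["epoch"])  # -> 12
```

Storing both states together is what lets TrainerInterface.load_checkpoint restore the optimizer, scheduler, and epoch counter alongside the model when a run restarts.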