Abstract base classes¶
These are classes that define the interfaces that should be followed by metatrain components (models, trainers, etc.).
- class metatrain.utils.abc.ModelInterface(hypers: HypersType, dataset_info: DatasetInfo, metadata: ModelMetadata)[source]¶
Bases: Module, Generic[HypersType]

Abstract base class for a machine learning model in metatrain.
All architectures in metatrain must be implemented as sub-classes of this class, and implement the corresponding methods.
- Parameters:
hypers (HypersType) – A dictionary with the model’s hyper-parameters.
dataset_info (DatasetInfo) – Information containing details about the dataset, such as target quantities and atomic types.
metadata (ModelMetadata) – Metadata about the model, e.g. author, description, and references.
- __checkpoint_version__: int¶
The current version of the model’s checkpoint.
This is used to upgrade checkpoints produced with earlier versions of the code. See Checkpoint versioning for more information.
- __supported_devices__: List[device]¶
List of torch devices supported by this model architecture.
They should be sorted in order of preference, since metatrain will use this together with __supported_dtypes__ to determine, based on the user request and the machine's availability, the optimal dtype and device for training.
- __supported_dtypes__: List[dtype]¶
List of torch dtypes supported by this model architecture.
They should be sorted in order of preference, since metatrain will use this together with __supported_devices__ to determine, based on the user request and the machine's availability, the optimal dtype and device for training.
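The preference-ordered lists above can be intersected with what a machine actually offers; a minimal sketch of that selection logic, using plain strings in place of torch.device/torch.dtype objects (the helper name pick_preferred is illustrative, not part of metatrain):

```python
def pick_preferred(supported: list[str], available: list[str]) -> str:
    """Return the first entry of `supported` (preference-ordered) that the
    machine actually provides, raising if none is available."""
    for candidate in supported:
        if candidate in available:
            return candidate
    raise ValueError(f"none of {supported} is available on this machine")

# an architecture that prefers CUDA but also runs on CPU
supported_devices = ["cuda", "cpu"]
print(pick_preferred(supported_devices, ["cpu"]))          # CPU-only machine
print(pick_preferred(supported_devices, ["cpu", "cuda"]))  # machine with a GPU
```

Because the first match wins, placing the preferred device or dtype first in the class attribute is all an architecture needs to do.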
- __default_metadata__: ModelMetadata¶
Default metadata for this model architecture.
Can be used to provide references that will be stored in the exported model. The references are stored in a dictionary with keys implementation and architecture. The implementation key should contain references to the software used in the implementation of the architecture, while the architecture key should contain references about the general architecture.
- hypers¶
The model hypers passed at initialization
- dataset_info¶
The dataset info passed at initialization
- metadata¶
The metadata passed at initialization
- abstractmethod forward(systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Labels | None = None) Dict[str, TensorMap][source]¶
Execute the model for the given systems, computing the requested outputs.
- Parameters:
systems (List[System]) – List of systems to evaluate the model on.
outputs (Dict[str, ModelOutput]) – Dictionary of outputs that the model should compute.
selected_atoms (Labels | None) – Optional Labels specifying a subset of atoms to compute the outputs for. If None, the outputs are computed for all atoms in each system.
- Returns:
A dictionary mapping each requested output name to the corresponding TensorMap containing the computed values.
- Return type:
Dict[str, TensorMap]
See also
metatomic.torch.ModelInterface for more explanation about the different arguments.
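The forward contract (compute only the requested outputs, optionally restricted to a subset of atoms) can be sketched without the metatomic types, using plain dicts and lists as illustrative stand-ins for System, ModelOutput and TensorMap:

```python
from typing import Dict, List, Optional

def forward(
    systems: List[dict],
    outputs: Dict[str, dict],
    selected_atoms: Optional[List[int]] = None,
) -> Dict[str, list]:
    """Compute one placeholder per-atom value for each requested output."""
    results = {}
    for name in outputs:  # only the requested outputs are computed
        per_system = []
        for system in systems:
            atoms = system["atoms"]
            if selected_atoms is not None:
                # restrict the computation to the selected subset
                atoms = [a for a in atoms if a in selected_atoms]
            per_system.append([0.0 for _ in atoms])  # placeholder values
        results[name] = per_system
    return results

systems = [{"atoms": [0, 1, 2]}]
print(forward(systems, {"energy": {}}))                         # all atoms
print(forward(systems, {"energy": {}}, selected_atoms=[0, 2]))  # subset only
```

A real implementation would return metatensor TensorMap objects keyed on the same output names; the shape of the contract, however, is exactly this.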
- abstractmethod supported_outputs() Dict[str, ModelOutput][source]¶
Get the outputs currently supported by this model.
This will likely be the same outputs that are set as this model's capabilities in ModelInterface.export().
- Returns:
A dictionary of the supported outputs by this model.
- Return type:
Dict[str, ModelOutput]
- abstractmethod restart(dataset_info: DatasetInfo) ModelInterface[source]¶
Update a model to restart training, potentially with different dataset and/or targets.
This function is called whenever training restarts, with the same or a different dataset. It enables transfer learning (changing the targets) and fine-tuning (same targets, different datasets).
- Parameters:
dataset_info (DatasetInfo) – Information about the new dataset, including the targets that will be used for training.
- Returns:
The updated model, or a new instance of the model, that is able to handle the new dataset.
- Return type:
ModelInterface
- abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], context: Literal['restart', 'finetune', 'export']) ModelInterface[source]¶
Create a model from a checkpoint (i.e. state dictionary).
- Parameters:
checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.
context (Literal['restart', 'finetune', 'export']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, "finetune" when loading a model for further fine-tuning or transfer learning, and "export" when loading a model for final export. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.
- Returns:
An instance of the model.
- Return type:
ModelInterface
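When a checkpoint bundles several state dictionaries, load_checkpoint can use the context argument to pick the right one. A sketch of that dispatch with plain dicts (the checkpoint layout and key names here are hypothetical, not metatrain's actual format):

```python
from typing import Any, Dict, Literal

def select_state(
    checkpoint: Dict[str, Any],
    context: Literal["restart", "finetune", "export"],
) -> Dict[str, Any]:
    """Pick the state matching the loading context, falling back to a
    generic entry when no context-specific one exists."""
    key = f"model_state_{context}"  # hypothetical per-context entry
    if key in checkpoint:
        return checkpoint[key]
    return checkpoint["model_state"]  # hypothetical generic fallback

ckpt = {"model_state": {"w": 1.0}, "model_state_finetune": {"w": 0.5}}
print(select_state(ckpt, "finetune"))  # context-specific weights
print(select_state(ckpt, "export"))    # generic fallback
```

The actual implementation is free to store whatever it wants in the checkpoint; the only requirement is that each context yields a usable model instance.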
- abstractmethod export(metadata: ModelMetadata | None = None) AtomisticModel[source]¶
Turn this model into an instance of metatomic.torch.AtomisticModel, containing the model itself, a definition of the model's capabilities, and some metadata about the model.
- Parameters:
metadata (ModelMetadata | None) – additional metadata to add in the model as specified by the user.
- Returns:
An instance of metatomic.torch.AtomisticModel.
- Return type:
AtomisticModel
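The sub-classing contract above can be mimicked with Python's own abc module; this standalone sketch (no metatrain imports, all names illustrative) shows that a sub-class cannot be instantiated until every abstract method is implemented:

```python
from abc import ABC, abstractmethod

class ModelLike(ABC):
    """Stand-in for ModelInterface, keeping only the abstract-method pattern."""

    __checkpoint_version__ = 1

    def __init__(self, hypers: dict):
        self.hypers = hypers  # stored as-is, like ModelInterface.hypers

    @abstractmethod
    def forward(self, systems, outputs): ...

    @abstractmethod
    def supported_outputs(self) -> dict: ...

class IncompleteModel(ModelLike):
    def forward(self, systems, outputs):
        return {}
    # supported_outputs is missing, so instantiation must fail

try:
    IncompleteModel({"cutoff": 5.0})
except TypeError as err:
    print("abstract method missing:", err)
```

metatrain relies on the same mechanism: forgetting to implement one of the abstract methods listed above is caught at instantiation time, not at call time.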
- class metatrain.utils.abc.TrainerInterface(hypers: HypersType)[source]¶
Bases: Generic[HypersType]

Abstract base class for a model trainer in metatrain.
All architectures in metatrain must implement such a trainer, which is responsible for training the model. The trainer must be a sub-class of this class, and implement the corresponding methods.
- Parameters:
hypers (HypersType) – A dictionary with the trainer’s hyper-parameters.
- __checkpoint_version__: int¶
The current version of the trainer’s checkpoint.
This is used to upgrade checkpoints produced with earlier versions of the code. See Checkpoint versioning for more information.
- hypers¶
The trainer hypers passed at initialization
- abstractmethod train(model: ModelInterface, dtype: dtype, devices: List[device], train_datasets: List[Dataset | Subset], val_datasets: List[Dataset | Subset], checkpoint_dir: str) None[source]¶
Train the model using the train_datasets. How to train the model is left to this class, using the hyper-parameters given in __init__.
- Parameters:
model (ModelInterface) – the model to train
dtype (dtype) – torch.dtype used by the data in the datasets
devices (List[device]) – torch.device to use for training the model. When training with more than one device (e.g. multi-GPU training), this can contain multiple devices.
train_datasets (List[Dataset | Subset]) – datasets to use to train the model
val_datasets (List[Dataset | Subset]) – datasets to use for model validation
checkpoint_dir (str) – directory where checkpoints should be saved
- Return type:
None
- abstractmethod save_checkpoint(model: ModelInterface, path: str | Path) None[source]¶
Save a checkpoint of both the model and trainer state to the given path.
- Parameters:
model (ModelInterface) – The model to save in the checkpoint.
path (str | Path) – The path where the checkpoint should be saved.
- Return type:
None
- abstractmethod classmethod upgrade_checkpoint(checkpoint: Dict) Dict[source]¶
Upgrade the checkpoint to the current version of the trainer.
- Parameters:
checkpoint (Dict) – Checkpoint’s state dictionary.
- Raises:
RuntimeError – if the checkpoint cannot be upgraded to the current version of the trainer.
- Returns:
The upgraded checkpoint.
- Return type:
Dict
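Checkpoint upgrading typically applies one migration per version step until the stored version matches __checkpoint_version__. A pure-dict sketch of that pattern (the key names and the specific migrations are hypothetical):

```python
from typing import Any, Dict

CURRENT_VERSION = 3  # plays the role of __checkpoint_version__

def upgrade_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Step the checkpoint forward one version at a time, raising if the
    stored version is newer than the code understands."""
    version = checkpoint.get("version", 1)
    if version > CURRENT_VERSION:
        raise RuntimeError(f"cannot downgrade checkpoint from v{version}")
    while version < CURRENT_VERSION:
        if version == 1:
            # hypothetical migration: a renamed hyper-parameter key
            checkpoint["learning_rate"] = checkpoint.pop("lr", 1e-3)
        elif version == 2:
            # hypothetical migration: a new field with a default value
            checkpoint.setdefault("scheduler", "none")
        version += 1
    checkpoint["version"] = CURRENT_VERSION
    return checkpoint

print(upgrade_checkpoint({"version": 1, "lr": 0.01}))
```

Chaining single-step migrations this way means each code release only has to add one new branch, and any old checkpoint is upgraded through every intermediate version.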
- abstractmethod classmethod load_checkpoint(checkpoint: Dict[str, Any], hypers: HypersType, context: Literal['restart', 'finetune']) TrainerInterface[source]¶
Create a trainer instance from data stored in the checkpoint.
- Parameters:
checkpoint (Dict[str, Any]) – Checkpoint’s state dictionary.
hypers (HypersType) – Hyper-parameters for the trainer, as specified by the user.
context (Literal['restart', 'finetune']) – Context in which to load the model. Possible values are "restart" when restarting a stopped training run, and "finetune" when loading a model for further fine-tuning or transfer learning. When multiple checkpoints are stored together, this can be used to pick one of them depending on the context.
- Returns:
The loaded trainer instance.
- Return type:
TrainerInterface
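Putting the trainer contract together, a self-contained sketch (no metatrain imports; the checkpoint layout and class name are illustrative) of the save/load round-trip a trainer is expected to support:

```python
from typing import Any, Dict

class TinyTrainer:
    """Minimal stand-in for a TrainerInterface implementation."""

    __checkpoint_version__ = 1

    def __init__(self, hypers: Dict[str, Any]):
        self.hypers = hypers
        self.epoch = 0  # trainer state that should survive a restart

    def train(self, n_epochs: int) -> None:
        # a real trainer would iterate over train_datasets here
        self.epoch += n_epochs

    def get_checkpoint(self) -> Dict[str, Any]:
        return {
            "trainer_ckpt_version": self.__checkpoint_version__,
            "epoch": self.epoch,
        }

    @classmethod
    def load_checkpoint(cls, checkpoint: Dict[str, Any], hypers, context: str):
        trainer = cls(hypers)
        if context == "restart":
            # resume exactly where training stopped
            trainer.epoch = checkpoint["epoch"]
        # for "finetune", keep a fresh epoch counter
        return trainer

trainer = TinyTrainer({"learning_rate": 1e-3})
trainer.train(5)
ckpt = trainer.get_checkpoint()
restored = TinyTrainer.load_checkpoint(ckpt, trainer.hypers, "restart")
print(restored.epoch)
```

The key point of the interface is that "restart" must restore the trainer's internal state (epoch counters, optimizer state, …) exactly, while "finetune" only needs the model and may reset that state.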