Adding a new architecture

This page describes the classes and files required to add a new architecture to metatrain, either as an experimental or as a stable architecture, as described on the Life Cycle of an Architecture page.

What is a metatrain architecture?

To work with metatrain, any architecture has to follow the same public API so that it can be called correctly within the metatrain.cli.train() function to process the user’s options. In brief, the core of the train function looks similar to these lines:

import torch

from architecture import __model__ as Model
from architecture import __trainer__ as Trainer
from metatrain.utils.data import DatasetInfo

hypers = {...}
dataset_info = DatasetInfo()

if checkpoint_path is not None:
    checkpoint = torch.load(checkpoint_path)

    trainer = Trainer.load_checkpoint(
        checkpoint, hypers=hypers["training"], context="restart")
    model = Model.load_checkpoint(checkpoint, context="restart")
    model = model.restart(dataset_info)
else:
    trainer = Trainer(hypers["training"])

    if "finetune" in hypers["training"]:
        checkpoint = torch.load(hypers["training"]["finetune"]["read_from"])
        model = Model.load_checkpoint(checkpoint, context="finetune")
    else:
        model = Model(hypers["model"], dataset_info)

trainer.train(
    model=model,
    dtype=dtype,
    devices=[],
    train_datasets=[],
    val_datasets=[],
    checkpoint_dir="path",
)

model.save_checkpoint("model.ckpt")

mts_atomistic_model = model.export()
mts_atomistic_model.export("model.pt", collect_extensions="extensions/")

General code structure

To follow this, a new architecture has to define two classes: a Model class and a Trainer class.

Note

metatrain does not know the types and numbers of targets/datasets an architecture can handle. As a result, it cannot generate useful error messages when a user attempts to train an architecture with unsupported target and dataset combinations. It is therefore the responsibility of the architecture developer to verify that the model and the trainer support the provided train_datasets and val_datasets passed to the Trainer, as well as the dataset_info passed to the model.
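
For example, such a check could run when the model is created. The sketch below is illustrative only: it assumes that dataset_info.targets maps target names to their metadata, and the energy-only condition is a hypothetical stand-in for whatever your architecture actually supports.

def check_supported_targets(dataset_info):
    # Minimal sketch: raise a helpful error for unsupported targets.
    # ``dataset_info.targets`` is assumed to map target names to metadata;
    # the "energy" condition below is a hypothetical example.
    for target_name in dataset_info.targets:
        if not target_name.startswith("energy"):
            raise ValueError(
                f"Target {target_name!r} is not supported by this "
                "architecture; only energy-like targets can be used."
            )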

The architecture must also define a documentation file which contains the default hyperparameters, along with their types and descriptions.

To comply with this design, each architecture has to implement four files inside a new architecture directory, either inside the experimental subdirectory or at the root of the Python source if the new architecture already complies with all requirements to be stable. The usual structure of an architecture looks like this:

myarchitecture
    ├── __init__.py
    ├── documentation.py
    ├── model.py
    └── trainer.py

Note

Because architectures can live in either src/metatrain/<architecture>, src/metatrain/experimental/<architecture>, or src/metatrain/deprecated/<architecture>, the code inside should use absolute imports to access the tools provided by metatrain.

# Do not do this
from ..utils.dtype import dtype_to_str

# Do this instead
from metatrain.utils.dtype import dtype_to_str

Model class (model.py)

A model class has to follow the interface defined in ModelInterface. That is, all the methods that are marked as abstract in the interface must be implemented with the indicated API (same arguments and same return types). At first sight, the interface might feel overwhelming, so here is a summary of the steps to take to implement a new model class:

  • Implement the __init__ method, which takes as input the model hyperparameters and the dataset information. This should initialize your model.

  • Implement the forward method, which defines the forward pass of the model.

  • Add some class attributes with __names_like_this__ that will help metatrain understand how to treat your model. They are listed and described in the ModelInterface documentation.

  • Implement the rest of abstract methods, which in general deal with handling checkpoints, exporting the model, and restarting training from a checkpoint.

Here is an incomplete example of what a model implementation looks like:

import torch
from metatomic.torch import ModelMetadata

from metatrain.utils.abc import ModelInterface
from metatrain.utils.data import DatasetInfo

class MyModel(ModelInterface):
    __checkpoint_version__ = 1
    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata(
        references={"implementation": ["ref1"], "architecture": ["ref2"]}
    )

    def __init__(self, hypers: dict, dataset_info: DatasetInfo):
        super().__init__(hypers, dataset_info)

        # To access hyperparameters, one can use self.hypers, whose
        # defaults are defined in the documentation.py file.
        self.hypers["size"]
        ...

    # Here one would implement the rest of the abstract methods
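
To give an idea of what “the rest of the abstract methods” involves, here is a rough sketch based purely on how these methods are called in the train() excerpt at the top of this page; the bodies are placeholders, so check ModelInterface for the exact signatures:

class MyModel(ModelInterface):
    ...

    def restart(self, dataset_info: DatasetInfo) -> "MyModel":
        # Update the model to account for new dataset information
        # before restarting training
        ...

    def save_checkpoint(self, path: str) -> None:
        # Write the model state to ``path`` (e.g. with torch.save)
        ...

    @classmethod
    def load_checkpoint(cls, checkpoint, context: str) -> "MyModel":
        # Rebuild the model from a loaded checkpoint dictionary;
        # ``context`` is "restart" or "finetune" in the train() excerpt
        ...

    def export(self):
        # Return the model wrapped as an exportable atomistic model
        ...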

Trainer class (trainer.py)

A trainer class has to follow the interface defined in TrainerInterface. That is, all the methods that are marked as abstract in the interface must be implemented with the indicated API (same arguments and same return types). We recommend looking at existing implementations of trainers for inspiration. They will look something like this:

from metatrain.utils.abc import TrainerInterface

class MyTrainer(TrainerInterface):
    __checkpoint_version__ = 1

    def __init__(self, hypers: dict):
        super().__init__(hypers)
        # To access hyperparameters, one can use self.hypers, whose
        # defaults are defined in the documentation.py file.
        self.hypers["learning_rate"]
        ...

    # Here one would implement the rest of the abstract methods
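
Again as a rough sketch, the trainer methods called in the train() excerpt at the top of this page would look something like the following; check TrainerInterface for the exact signatures:

class MyTrainer(TrainerInterface):
    ...

    def train(
        self, model, dtype, devices, train_datasets, val_datasets, checkpoint_dir
    ):
        # Run the training loop, writing checkpoints into ``checkpoint_dir``
        ...

    @classmethod
    def load_checkpoint(cls, checkpoint, hypers, context):
        # Restore the trainer state (optimizer, scheduler, ...)
        # from a loaded checkpoint dictionary
        ...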

Init file (__init__.py)

You are free to name the Model and Trainer classes as you want. These classes should then be made available in the __init__.py under the names __model__ and __trainer__, so that metatrain knows where to find them. The __init__.py must also contain definitions for the original __authors__ and the current __maintainers__ of the architecture.

from .model import MyModel
from .trainer import MyTrainer

# class to use as the architecture's model
__model__ = MyModel
# class to use as the architecture's trainer
__trainer__ = MyTrainer

# List of the original authors of the architecture, each with an email
# address and GitHub handle.
#
# These authors are not necessarily currently in charge of maintaining the code
__authors__ = [
    ("Jane Roe <[email protected]>", "@janeroe"),
    ("John Doe <[email protected]>", "@johndoe"),
]

# Current maintainers of the architecture code, using the same
# style as ``__authors__``
__maintainers__ = [("Joe Bloggs <[email protected]>", "@joebloggs")]

Documentation (documentation.py)

The documentation file is used to define:

  • The hyperparameters for the model class.

  • The hyperparameters for the trainer class.

  • The text that will go to the online documentation for the architecture.

Warning

This file is meant to be imported separately to generate the documentation page for the architecture without needing the extra dependencies that the architecture might require.

Therefore, all imports in this file should be absolute and this file should not import the rest of the architecture code unless the architecture has no extra dependencies.
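
In the spirit of the import note above, this means (the metatrain.myarchitecture path and MyModel name are hypothetical):

# Do not do this inside documentation.py: it drags in the architecture
# code and therefore its optional dependencies
from metatrain.myarchitecture.model import MyModel

# Standard-library and typing imports are fine
from typing import Literal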

Bare minimum

We understand that, during the development of a new architecture, expecting full documentation for all hyperparameters is unreasonable. Therefore, metatrain will work with a very minimal documentation.py file containing only the default hyperparameters for both the model and the trainer. One just needs to define a ModelHypers class and a TrainerHypers class, holding the hypers of the model and the trainer respectively.

# This is the most minimal documentation.py file possible.
# Something like this should only be used during development.

# Default hyperparameters for the model
class ModelHypers:
    size = 150
    mode = "strict"

# Default hyperparameters for the trainer
class TrainerHypers:
    learning_rate = 1e-3
    lr_scheduler = "CosineAnnealing"

Note

The name of these classes (ModelHypers and TrainerHypers), as well as the file they are in (documentation.py) are mandatory. metatrain will look for these specific names when loading the architecture.

This rigidity allows metatrain to easily generate documentation pages and maintain a consistent experience across all architectures.

For an experimental architecture

For an architecture to be accepted as “experimental” into the main metatrain distribution, documentation.py should at least contain:

  • A minimal docstring at the top of the file with at least a short description of the architecture. It should contain as a title the name of the architecture, underlined with equal signs (=).

  • Some documentation for each hyperparameter.

For example, this would be a valid documentation.py file for an experimental architecture:

"""
My architecture
===============

This is an architecture that does amazing things.
"""

class ModelHypers:

    size = 150
    """Size of the model's hidden layers."""
    mode = "strict"
    """Mode of operation for the model."""

class TrainerHypers:
    learning_rate = 1e-3
    """Initial learning rate for the optimizer."""
    lr_scheduler = "CosineAnnealing"
    """Type of learning rate scheduler to use."""

You can check this section to understand how the module docstring will be used to generate the documentation page for the architecture.

For a stable architecture

Going from an experimental to a stable architecture requires one last step: documenting the hyperparameter types. This is done using TypedDict and Python’s type hinting system, and it allows metatrain to automatically validate user inputs. By doing validation, metatrain can give users meaningful error messages when the provided hyperparameters are invalid, avoiding errors deep inside the architecture that would be harder to understand.

Here is the example of the previous documentation.py file, now ready for the architecture to be considered stable:

"""
My architecture
===============

This is an architecture that does amazing things.
"""
from typing_extensions import TypedDict
from typing import Literal

class ModelHypers(TypedDict):

    size: int = 150
    """Size of the model's hidden layers."""
    mode: Literal["strict", "lenient"] = "strict"
    """Mode of operation for the model."""

class TrainerHypers(TypedDict):
    learning_rate: float = 1e-3
    """Initial learning rate for the optimizer."""
    lr_scheduler: Literal["CosineAnnealing", "StepLR"] = "CosineAnnealing"
    """Type of learning rate scheduler to use."""

Note

It is important to use typing_extensions.TypedDict instead of typing.TypedDict for compatibility with Python < 3.12 in pydantic’s validation system.
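
As an illustration of the validation this enables, here is a minimal sketch using pydantic’s TypeAdapter; the import path and the invalid user input are hypothetical:

from pydantic import TypeAdapter, ValidationError

from metatrain.myarchitecture.documentation import ModelHypers  # hypothetical path

# A user-provided value for ``mode`` that is not in the Literal above
user_hypers = {"size": 200, "mode": "fast"}

try:
    TypeAdapter(ModelHypers).validate_python(user_hypers)
except ValidationError as error:
    print(error)  # points the user directly at the invalid "mode" value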

With this, you will be almost ready to have your architecture accepted as stable. The last step is to update the Model and Trainer classes so that they are aware of the hyperparameter types. This will help static type checkers like mypy catch bugs in your code, as well as improve the development experience in IDEs like VSCode or PyCharm. To do this, you just have to:

  • Make your model and trainer classes inherit from ModelInterface[ModelHypers] and TrainerInterface[TrainerHypers] respectively, instead of just ModelInterface and TrainerInterface.

  • Add the hypers type annotation to the hypers argument of the __init__ method of both classes, as well as any other method that takes hyperparameters as input (like Trainer.load_checkpoint).

For example, for the model:

import torch
from metatomic.torch import ModelMetadata

from metatrain.utils.abc import ModelInterface
from metatrain.utils.data import DatasetInfo

# New import to get the ModelHypers type
from .documentation import ModelHypers

class MyModel(ModelInterface[ModelHypers]): # Add the hypers type here
    __checkpoint_version__ = 1
    __supported_devices__ = ["cuda", "cpu"]
    __supported_dtypes__ = [torch.float64, torch.float32]
    __default_metadata__ = ModelMetadata(
        references={"implementation": ["ref1"], "architecture": ["ref2"]}
    )

    # Type hint the hypers argument of __init__
    def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo):
        super().__init__(hypers, dataset_info)
        ...
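
And similarly for the trainer, as a sketch following the same pattern:

from metatrain.utils.abc import TrainerInterface

# New import to get the TrainerHypers type
from .documentation import TrainerHypers

class MyTrainer(TrainerInterface[TrainerHypers]):  # Add the hypers type here
    __checkpoint_version__ = 1

    # Type hint the hypers argument of __init__
    def __init__(self, hypers: TrainerHypers):
        super().__init__(hypers)
        ...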

Documentation page

By following the guidelines for documenting hyperparameters, metatrain will automatically generate a documentation page for the new architecture. This documentation page will contain information about how to install your architecture, the default hyperparameters, and the descriptions of all the hyperparameters for both the model and the trainer.

The documentation page will be generated from the docstring at the top of the documentation.py file, as well as the ModelHypers and TrainerHypers classes defined there. Here is the description of how the docstring will be generated:

class src.architectures.generate.ArchitectureDocVariables

Variables to use inside the architecture documentation.

The docstring of the architecture will be processed as a Jinja template. You can find documentation about Jinja templates here, but the simplest functionality consists of using variables enclosed in double curly braces, {{variable_name}}, which will be replaced by their corresponding values.

For example, a file with the following content:

This is the documentation for {{architecture}}.

generates a documentation file that, for the pet architecture, would read:

This is the documentation for pet.

There are some special variables that start with SECTION_. These contain the content of different sections of the documentation, and they will be appended to the docstring if they are not already present. For example, given the docstring:

"""
My architecture
===============

This is my architecture.

{{SECTION_DEFAULT_HYPERS}}

Some important section
======================

Explain something important here.
"""

The final documentation will append to the docstring all the sections except SECTION_DEFAULT_HYPERS, since it is already present.

Below you can find a description of all the available variables. The sections are appended in the order documented here.

SECTION_INSTALLATION: str

Section containing installation instructions for this architecture.

SECTION_DEFAULT_HYPERS: str

Section containing a yaml file with the default hyperparameters for this architecture.

SECTION_MODEL_HYPERS: str

Section containing the description of the model hyperparameters for this architecture.

SECTION_TRAINER_HYPERS: str

Section containing the description of the trainer hyperparameters for this architecture.

SECTION_REFERENCES: str

Section containing references for this architecture. It will render the references that have been cited with :footcite:p: throughout the architecture documentation.

architecture: str

The name of the architecture.

This excludes any ‘experimental.’ or ‘deprecated.’ prefix.

architecture_path: str

The full python import path to the architecture.

E.g.: "metatrain.experimental.my_architecture"

default_hypers_path: str

Path to the yaml file with the default hyperparameters for this architecture.

This is a path relative to the docs/src/architectures/generated directory.

model_hypers_path: str

The full python import path to the model’s hypers class of this architecture.

E.g.: "metatrain.pet.documentation.ModelHypers"

trainer_hypers_path: str

The full python import path to the trainer’s hypers class of this architecture.

E.g.: "metatrain.pet.documentation.TrainerHypers"

model_hypers: list[str]

List of model hyperparameter names for this architecture.

trainer_hypers: list[str]

List of trainer hyperparameter names for this architecture.

Checkpoint versioning

Checkpoints are used to save the weights of a model and the state of the trainer to disk, making it possible to restart interrupted training runs, to fine-tune existing models on new datasets, and to export standalone models based on TorchScript.

A checkpoint created by one version of an architecture might need to be read again by a later version, where the internal structure might have changed. To enable this, all Model classes are required to have a __checkpoint_version__ class attribute containing the version of the checkpoint, as a strictly increasing integer. Additionally, architectures should provide an upgrade_checkpoint(checkpoint: Dict) -> Dict function, which will be called when a user tries to load an outdated checkpoint. This function is responsible for updating the checkpoint data and returning a checkpoint compatible with the current version.

Similarly, the Trainer state is also saved in the checkpoint and used to restart training. All trainers must thus have a __checkpoint_version__ class attribute, as well as an upgrade_checkpoint(checkpoint: Dict) -> Dict function to upgrade from previous checkpoints.
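
As an illustration, an upgrade_checkpoint function could look like the sketch below; the checkpoint layout (a "version" entry, and a "weights" entry renamed to "state_dict" in version 2) is entirely hypothetical:

from typing import Dict

def upgrade_checkpoint(checkpoint: Dict) -> Dict:
    # Hypothetical migration: version 2 renamed "weights" to "state_dict"
    if checkpoint.get("version", 1) == 1:
        checkpoint["state_dict"] = checkpoint.pop("weights")
        checkpoint["version"] = 2
    return checkpoint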

Testing (tests/)

Metatrain aims to provide users with a consistent experience across architectures. To ensure this, we must test that all architectures behave in the “metatrain way”.

The good news is: you don’t have to write any tests! Since we know that writing tests is not an enjoyable experience, we provide the tests, you just have to make sure your architecture passes them. This approach has several advantages:

  • It saves you time and effort, since you don’t have to write tests.

  • It makes you confident that the architecture is well integrated into metatrain.

  • New architectures have many lines of new code and can be hard to review, so the shared test suite helps us understand whether the architecture is compliant and ready to be merged.

  • Users benefit from it, since they are guaranteed a consistent experience across architectures.

To make the tests run for your architecture, you should follow these steps:

Step 1: Create a tests/ subdirectory inside your architecture directory.

Step 2: Inside the tests/ directory, create a new file called test_basic.py.

Step 3: The test_basic.py file should contain the relevant classes from metatrain.utils.testing. Each *Tests class tests a different kind of functionality, and can be tuned to enable or disable certain tests for your architecture. You can take inspiration from existing architectures’ test_basic.py files, but here is an example for an architecture called experimental.myarchitecture:

from metatrain.utils.testing import (
    AutogradTests,
    CheckpointTests,
    ExportedTests,
    InputTests,
    OutputTests,
    TorchscriptTests,
    TrainingTests,
)

class TestInput(InputTests):
    architecture = "experimental.myarchitecture"

class TestAutograd(AutogradTests):
    architecture = "experimental.myarchitecture"

class TestTorchscript(TorchscriptTests):
    architecture = "experimental.myarchitecture"

class TestExported(ExportedTests):
    architecture = "experimental.myarchitecture"

class TestTraining(TrainingTests):
    architecture = "experimental.myarchitecture"

class TestCheckpoints(CheckpointTests):
    architecture = "experimental.myarchitecture"

Some test suites might not apply to your architecture, e.g. if your model does not support autograd. In that case, simply explain this in your PR and the maintainers will help you decide whether it is OK to omit them. You can of course add more tests that you find relevant for your architecture, but passing metatrain’s shared test suite is a sufficient condition for merging a new architecture.

Step 4: Add your architecture tests to the tox.ini file. For this, you have to add a section [testenv:myarchitecture-tests]. You can take inspiration from existing architectures, e.g. the [testenv:pet-tests] section. You will also need to add your tests to the envlist variable at the top of the tox.ini file.

Step 5: Run your tests. For this, you will need to install tox. You can do this with pip install tox. Then, from the root of the repository, run tox -e myarchitecture-tests. See the contributing page for more details on how to run tests.

Step 6: Add your architecture tests to the continuous integration (CI) system. This is done by adding myarchitecture-tests to the file .github/workflows/architecture-tests.yml.