Convenience nn modules

This example demonstrates the use of convenience modules in metatensor-learn to build simple multi-layer perceptrons.

Note

The convenience modules introduced in this tutorial are designed for prototyping new architectures for simple models. If you already have a more complex model, you can also wrap it in ModuleMap objects to make it compatible with metatensor. This is covered in the later tutorial on module maps.
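To give a flavour of what that looks like, here is a minimal sketch (not part of this tutorial's code, and assuming a ModuleMap constructor that takes the input keys and one module per block) of wrapping an arbitrary torch module:

import torch

from metatensor.torch import Labels
from metatensor.torch.learn.nn import ModuleMap

# an arbitrary pre-existing torch architecture (dimensions are illustrative)
custom_module = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 1),
)

# one module per key of the input TensorMap; here we assume a single key
wrapped = ModuleMap(in_keys=Labels.single(), modules=[custom_module])
# ``wrapped`` now accepts and returns TensorMap objects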

from typing import Union

import torch

import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap
from metatensor.torch.learn.nn import Linear, ReLU, Sequential


torch.manual_seed(42)
torch.set_default_dtype(torch.float64)

Introduction to native torch.nn modules

metatensor-learn’s neural network modules are designed as TensorMap-compatible analogues of the native torch.nn modules. Before looking at the metatensor-learn versions, it is instructive to recap how torch’s native nn modules work.

First, let’s define a random tensor that we will treat as some intermediate representation. We will build a multi-layer perceptron to transform this tensor into a prediction.

Let’s say we have 100 samples, the size of the input latent space is 128, and the target property has dimension 1. We will start with a simple linear layer that maps the latent representation to a prediction of the target.

n_samples = 100
in_features = 128
out_features = 1
feature_tensor = torch.randn(n_samples, in_features)

# define a dummy target
target_tensor = torch.randn(n_samples, 1)

# initialize the torch linear layer
linear_torch = torch.nn.Linear(in_features, out_features, bias=True)

# define a loss function
loss_fn_torch = torch.nn.MSELoss(reduction="sum")


# construct a basic training loop. For brevity we will not use datasets or dataloaders.
def training_loop(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    features: Union[torch.Tensor, TensorMap],
    targets: Union[torch.Tensor, TensorMap],
) -> None:
    """A basic training loop for a model and loss function."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(1001):
        optimizer.zero_grad()

        predictions = model(features)

        if isinstance(predictions, torch.ScriptObject):
            # assume a TensorMap and check metadata is equivalent
            assert mts.equal_metadata(predictions, targets)

        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss}")


print("with NN = [Linear]")
training_loop(linear_torch, loss_fn_torch, feature_tensor, target_tensor)
with NN = [Linear]
epoch: 0, loss: 154.65348388744997
epoch: 100, loss: 38.054702446604324
epoch: 200, loss: 15.78208526134593
epoch: 300, loss: 8.832599443098466
epoch: 400, loss: 5.785132286612015
epoch: 500, loss: 4.024997543785835
epoch: 600, loss: 2.853829558976215
epoch: 700, loss: 2.041404671831072
epoch: 800, loss: 1.4712668157705777
epoch: 900, loss: 1.0669854246909265
epoch: 1000, loss: 0.7762888847168385

Now run the training loop again, this time with a nonlinear multi-layer perceptron built with torch.nn.Sequential.

hidden_layer_width = 64
mlp_torch = torch.nn.Sequential(
    torch.nn.Linear(in_features, hidden_layer_width),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_layer_width, out_features),
)

# run training again
print("with NN = [Linear, ReLU, Linear]")
training_loop(mlp_torch, loss_fn_torch, feature_tensor, target_tensor)
with NN = [Linear, ReLU, Linear]
epoch: 0, loss: 124.86745059607944
epoch: 100, loss: 0.0033502901826742383
epoch: 200, loss: 1.3219911454824964e-07
epoch: 300, loss: 3.237401278978944e-12
epoch: 400, loss: 1.2661473760530034e-16
epoch: 500, loss: 1.6327396230735035e-21
epoch: 600, loss: 7.231597104245314e-26
epoch: 700, loss: 3.1175780326180713e-30
epoch: 800, loss: 9.021808175894339e-31
epoch: 900, loss: 1.3391743219708977e-30
epoch: 1000, loss: 1.3798114437974684e-30

Using metatensor-learn nn layers

Now we’re ready to see how the nn module in metatensor-learn works.

Create some dummy data, this time in TensorMap format. Starting simple, we will define a TensorMap with only one TensorBlock, containing the latent space features from above.

feature_tensormap = TensorMap(
    keys=Labels.single(),
    blocks=[mts.block_from_array(feature_tensor)],
)

target_tensormap = TensorMap(
    keys=Labels.single(),
    blocks=[mts.block_from_array(target_tensor)],
)

# for supervised learning, inputs and labels must have the same metadata for all axes
# except the properties dimension, as this is the dimension that is transformed by the
# neural network.
if mts.equal_metadata(
    feature_tensormap, target_tensormap, check=["samples", "components"]
):
    print("metadata check passed!")
else:
    raise ValueError(
        "input and output TensorMaps must have matching keys, samples, "
        "and components metadata"
    )

# use metatensor-learn's Linear layer. We need to pass the target properties labels so
# that the prediction TensorMap is annotated with the correct metadata.
in_keys = feature_tensormap.keys
linear_mts = Linear(
    in_keys=in_keys,
    in_features=in_features,
    out_properties=[block.properties for block in target_tensormap],
    bias=True,
)


# define a custom loss function for TensorMaps that computes the squared error and
# reduces by sum
class TensorMapLoss(torch.nn.Module):
    """
    A custom loss function for TensorMaps that computes the squared error and reduces by
    sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: TensorMap, target: TensorMap) -> torch.Tensor:
        """
        Computes the total squared error between the ``input`` and ``target``
        TensorMaps.
        """
        # input and target should have equal metadata over all axes
        assert mts.equal_metadata(input, target)

        squared_loss = 0
        for key in input.keys:
            squared_loss += torch.sum((input[key].values - target[key].values) ** 2)

        return squared_loss


loss_fn_mts = TensorMapLoss()

# run the training loop
print("with NN = [Linear]")
training_loop(linear_mts, loss_fn_mts, feature_tensormap, target_tensormap)
metadata check passed!
with NN = [Linear]
epoch: 0, loss: 162.54238216562683
epoch: 100, loss: 37.13420486842371
epoch: 200, loss: 14.928916323825199
epoch: 300, loss: 7.814822220244354
epoch: 400, loss: 4.9465831283493635
epoch: 500, loss: 3.4951382254338874
epoch: 600, loss: 2.592461940544501
epoch: 700, loss: 1.962076742407893
epoch: 800, loss: 1.4971224460199135
epoch: 900, loss: 1.1434925338675401
epoch: 1000, loss: 0.8695465460715763
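As a quick sanity check (a small aside, not in the original script), the trained layer can be called directly on the input TensorMap; the result is a TensorMap whose single block holds the predictions:

prediction = linear_mts(feature_tensormap)

# the output has the same keys, samples and components as the target, and the
# target's properties; its values are a plain torch Tensor
print(prediction.block().values.shape)  # torch.Size([100, 1])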

Now construct a nonlinear MLP instead. Here we use metatensor-learn’s Sequential module, along with some nonlinear activation modules. We only need to pass the properties metadata for the output layer; for the hidden layers, we can just pass the layer dimension.

mlp_mts = Sequential(
    in_keys,
    Linear(
        in_keys=in_keys,
        in_features=in_features,
        out_features=hidden_layer_width,
        bias=True,
    ),
    ReLU(in_keys=in_keys),  # can also use Tanh or SiLU
    Linear(
        in_keys=in_keys,
        in_features=hidden_layer_width,
        out_properties=[block.properties for block in target_tensormap],
        bias=True,
    ),
)


# run the training loop
print("with NN = [Linear, ReLU, Linear]")
training_loop(mlp_mts, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [Linear, ReLU, Linear]
epoch: 0, loss: 129.38070042737664
epoch: 100, loss: 0.002842084062553065
epoch: 200, loss: 7.457968365452231e-08
epoch: 300, loss: 2.8078539783534333e-12
epoch: 400, loss: 7.27929598306562e-17
epoch: 500, loss: 1.2412278242779927e-21
epoch: 600, loss: 5.968650993176766e-26
epoch: 700, loss: 4.33389908899583e-30
epoch: 800, loss: 1.5688173937147708e-30
epoch: 900, loss: 1.5122191274740909e-30
epoch: 1000, loss: 1.827108673381404e-30

Conclusion

This tutorial introduced the convenience modules in metatensor-learn for building simple neural networks. As we’ve seen, the API closely mirrors native torch.nn, and TensorMap objects can be swapped in for torch Tensors in existing training loops with minimal changes.

Combined with other learning utilities to construct Datasets and Dataloaders, covered in basic and advanced tutorials, metatensor-learn provides a powerful framework for building and training machine learning models based on the TensorMap data format.
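As a rough illustration of how these pieces fit together, a batched training step could look like the sketch below. This is a sketch only: ``feature_list`` and ``target_list`` are hypothetical lists of per-sample TensorMaps, and the exact Dataset and DataLoader API is covered in the dedicated tutorials.

from metatensor.torch.learn import DataLoader, Dataset

# hypothetical lists of per-sample TensorMaps (not defined in this tutorial)
dataset = Dataset(features=feature_list, targets=target_list)
dataloader = DataLoader(dataset, batch_size=10)

for batch in dataloader:
    # each batch is a named tuple; the default collate function joins the
    # per-sample TensorMaps along the samples axis
    predictions = mlp_mts(batch.features)
    loss = loss_fn_mts(predictions, batch.targets)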
