Convenience nn modules
This example demonstrates the use of convenience modules in metatensor-learn to build simple multi-layer perceptrons.
Note
The convenience modules introduced in this tutorial are designed for prototyping new architectures for simple models. If you already have a more complex model, you can also wrap it in ModuleMap objects to make it compatible with metatensor. This is covered in the later tutorial on module maps.
from typing import Union
import torch
import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap
from metatensor.torch.learn.nn import Linear, ReLU, Sequential
torch.manual_seed(42)
torch.set_default_dtype(torch.float64)
Introduction to native torch.nn modules
metatensor-learn’s neural network modules are designed as TensorMap-compatible analogues to the torch API. Before looking into the metatensor-learn version, it is instructive to recap torch’s native nn modules to see how they work.
First, let’s define a random tensor that we will treat as some intermediate representation. We will build a multi-layer perceptron to transform this tensor into a prediction.
Let’s say we have 100 samples, the size of the input latent space is 128, and the target property is of dimension 1. We will start with a simple linear layer to map the latent representation to a prediction of the target.
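For reference, torch.nn.Linear(in_features, out_features) stores a weight matrix of shape (out_features, in_features) and a bias of shape (out_features,), and maps a batch x to x @ weight.T + bias. The standalone sketch below checks this with its own random data (independent of the tutorial's tensors), using the same sizes as above.

```python
import torch

torch.manual_seed(0)

n_samples, in_features, out_features = 100, 128, 1
x = torch.randn(n_samples, in_features, dtype=torch.float64)

layer = torch.nn.Linear(in_features, out_features, bias=True, dtype=torch.float64)

# weight has shape (out_features, in_features), bias has shape (out_features,)
print(layer.weight.shape, layer.bias.shape)

# the forward pass is an affine map applied to every sample (row) of x
y = layer(x)
manual = x @ layer.weight.T + layer.bias
print(y.shape, torch.allclose(y, manual))
```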
n_samples = 100
in_features = 128
out_features = 1
feature_tensor = torch.randn(n_samples, in_features)
# define a dummy target
target_tensor = torch.randn(n_samples, 1)
# initialize the torch linear layer
linear_torch = torch.nn.Linear(in_features, out_features, bias=True)
# define a loss function
loss_fn_torch = torch.nn.MSELoss(reduction="sum")
# construct a basic training loop. For brevity we will not use datasets or dataloaders.
def training_loop(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    features: Union[torch.Tensor, TensorMap],
    targets: Union[torch.Tensor, TensorMap],
) -> None:
    """A basic training loop for a model and loss function."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(1001):
        optimizer.zero_grad()
        predictions = model(features)
        if isinstance(predictions, torch.ScriptObject):
            # assume a TensorMap and check metadata is equivalent
            assert mts.equal_metadata(predictions, targets)
        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss}")
print("with NN = [Linear]")
training_loop(linear_torch, loss_fn_torch, feature_tensor, target_tensor)
with NN = [Linear]
epoch: 0, loss: 154.65348388744997
epoch: 100, loss: 38.054702446604324
epoch: 200, loss: 15.78208526134593
epoch: 300, loss: 8.832599443098466
epoch: 400, loss: 5.785132286612015
epoch: 500, loss: 4.024997543785835
epoch: 600, loss: 2.853829558976215
epoch: 700, loss: 2.041404671831072
epoch: 800, loss: 1.4712668157705777
epoch: 900, loss: 1.0669854246909265
epoch: 1000, loss: 0.7762888847168385
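Because loss_fn_torch uses reduction="sum", the printed values are total squared errors over all 100 samples rather than averages, which is why the first epoch reports a loss well above 1 even though the targets are standard normal. A quick check of how the "sum" and default "mean" reductions relate, on standalone random data:

```python
import torch

torch.manual_seed(0)
predictions = torch.randn(100, 1, dtype=torch.float64)
targets = torch.randn(100, 1, dtype=torch.float64)

sum_loss = torch.nn.MSELoss(reduction="sum")(predictions, targets)
mean_loss = torch.nn.MSELoss()(predictions, targets)  # default reduction="mean"

# "sum" adds up the squared error of every element...
manual_sum = ((predictions - targets) ** 2).sum()
# ...so it equals the mean loss multiplied by the number of elements
print(sum_loss.item(), manual_sum.item(), mean_loss.item() * predictions.numel())
```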
Now run the training loop, this time with a nonlinear multi-layer perceptron using torch.nn.Sequential:
hidden_layer_width = 64
mlp_torch = torch.nn.Sequential(
    torch.nn.Linear(in_features, hidden_layer_width),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_layer_width, out_features),
)
# run training again
print("with NN = [Linear, ReLU, Linear]")
training_loop(mlp_torch, loss_fn_torch, feature_tensor, target_tensor)
with NN = [Linear, ReLU, Linear]
epoch: 0, loss: 124.86745059607944
epoch: 100, loss: 0.0033502901826742383
epoch: 200, loss: 1.3219911454824964e-07
epoch: 300, loss: 3.237401278978944e-12
epoch: 400, loss: 1.2661473760530034e-16
epoch: 500, loss: 1.6327396230735035e-21
epoch: 600, loss: 7.231597104245314e-26
epoch: 700, loss: 3.1175780326180713e-30
epoch: 800, loss: 9.021808175894339e-31
epoch: 900, loss: 1.3391743219708977e-30
epoch: 1000, loss: 1.3798114437974684e-30
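A parameter count puts the two runs in context. The linear model has 128 + 1 = 129 trainable parameters and the MLP has (128 * 64 + 64) + (64 * 1 + 1) = 8321. Both have at least as many parameters as there are samples (100), so both can in principle fit the targets exactly; the MLP simply gets there much faster in these runs. Its loss levels off around 1e-30, roughly the double-precision round-off floor (squared errors of order (1e-16)^2 summed over 100 samples). Since the dummy targets are pure noise, this near-zero training loss reflects memorization rather than learning. The counts can be checked with a small helper (count_parameters is a name introduced here, not a torch function):

```python
import torch

in_features, hidden_layer_width, out_features = 128, 64, 1

linear = torch.nn.Linear(in_features, out_features)
mlp = torch.nn.Sequential(
    torch.nn.Linear(in_features, hidden_layer_width),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_layer_width, out_features),
)


def count_parameters(model: torch.nn.Module) -> int:
    """Total number of trainable scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# linear: 128 weights + 1 bias
# mlp: (128 * 64 weights + 64 biases) + (64 * 1 weights + 1 bias)
print(count_parameters(linear), count_parameters(mlp))
```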
Using metatensor-learn nn layers
Now we’re ready to see how the nn module in metatensor-learn works.
Create some dummy data, this time in TensorMap format. Starting simple, we will define a TensorMap with only one TensorBlock, containing the latent space features from above.
feature_tensormap = TensorMap(
    keys=Labels.single(),
    blocks=[mts.block_from_array(feature_tensor)],
)
target_tensormap = TensorMap(
    keys=Labels.single(),
    blocks=[mts.block_from_array(target_tensor)],
)
# for supervised learning, inputs and labels must have the same metadata for all axes
# except the properties dimension, as this is the dimension that is transformed by the
# neural network.
if mts.equal_metadata(
    feature_tensormap, target_tensormap, check=["samples", "components"]
):
    print("metadata check passed!")
else:
    raise ValueError(
        "input and output TensorMaps must have matching keys, samples, "
        "and components metadata"
    )
# use metatensor-learn's Linear layer. We need to pass the target properties labels so
# that the prediction TensorMap is annotated with the correct metadata.
in_keys = feature_tensormap.keys
linear_mts = Linear(
    in_keys=in_keys,
    in_features=in_features,
    out_properties=[block.properties for block in target_tensormap],
    bias=True,
)
# define a custom loss function for TensorMaps that computes the squared error and
# reduces by sum
class TensorMapLoss(torch.nn.Module):
    """
    A custom loss function for TensorMaps that computes the squared error and reduces by
    sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: TensorMap, target: TensorMap) -> torch.Tensor:
        """
        Computes the total squared error between the ``input`` and ``target``
        TensorMaps.
        """
        # input and target should have equal metadata over all axes
        assert mts.equal_metadata(input, target)
        squared_loss = 0
        for key in input.keys:
            squared_loss += torch.sum((input[key].values - target[key].values) ** 2)
        return squared_loss


loss_fn_mts = TensorMapLoss()
# run the training loop
print("with NN = [Linear]")
training_loop(linear_mts, loss_fn_mts, feature_tensormap, target_tensormap)
metadata check passed!
with NN = [Linear]
epoch: 0, loss: 162.54238216562683
epoch: 100, loss: 37.13420486842371
epoch: 200, loss: 14.928916323825199
epoch: 300, loss: 7.814822220244354
epoch: 400, loss: 4.9465831283493635
epoch: 500, loss: 3.4951382254338874
epoch: 600, loss: 2.592461940544501
epoch: 700, loss: 1.962076742407893
epoch: 800, loss: 1.4971224460199135
epoch: 900, loss: 1.1434925338675401
epoch: 1000, loss: 0.8695465460715763
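With a single block, TensorMapLoss computes the same quantity as torch.nn.MSELoss(reduction="sum") in the first half of the tutorial, so the two sets of loss curves are directly comparable. With several blocks, the per-block squared errors simply add up. The plain-torch sketch below checks both cases, using lists of tensors as a stand-in for the values of each block; blockwise_sse is a hypothetical helper, not part of metatensor.

```python
from typing import List

import torch

torch.manual_seed(0)


def blockwise_sse(
    predictions: List[torch.Tensor], targets: List[torch.Tensor]
) -> torch.Tensor:
    """Sum of squared errors over several blocks.

    Hypothetical stand-in for TensorMapLoss: each list entry plays the role
    of the ``values`` of one TensorBlock.
    """
    assert len(predictions) == len(targets), "need one target per block"
    loss = torch.zeros((), dtype=predictions[0].dtype)
    for pred, target in zip(predictions, targets):
        assert pred.shape == target.shape, "blocks must have matching shapes"
        loss = loss + torch.sum((pred - target) ** 2)
    return loss


mse_sum = torch.nn.MSELoss(reduction="sum")

# one block: identical to the torch loss used earlier
pred = torch.randn(100, 1, dtype=torch.float64)
target = torch.randn(100, 1, dtype=torch.float64)
single = blockwise_sse([pred], [target])
print(single.item(), mse_sum(pred, target).item())

# two blocks with different numbers of samples: the block losses add up
preds_2 = [torch.randn(30, 1, dtype=torch.float64), torch.randn(70, 1, dtype=torch.float64)]
targets_2 = [torch.randn(30, 1, dtype=torch.float64), torch.randn(70, 1, dtype=torch.float64)]
two = blockwise_sse(preds_2, targets_2)
print(two.item(), sum(mse_sum(p, t) for p, t in zip(preds_2, targets_2)).item())
```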
Now construct a nonlinear MLP instead. Here we use metatensor-learn’s Sequential module, along with some nonlinear activation modules. We only need to pass the properties metadata for the output layer; for the hidden layers, we can just pass the layer dimension.
mlp_mts = Sequential(
    in_keys,
    Linear(
        in_keys=in_keys,
        in_features=in_features,
        out_features=hidden_layer_width,
        bias=True,
    ),
    ReLU(in_keys=in_keys),  # can also use Tanh or SiLU
    Linear(
        in_keys=in_keys,
        in_features=hidden_layer_width,
        out_properties=[block.properties for block in target_tensormap],
        bias=True,
    ),
)
# run the training loop
print("with NN = [Linear, ReLU, Linear]")
training_loop(mlp_mts, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [Linear, ReLU, Linear]
epoch: 0, loss: 129.38070042737664
epoch: 100, loss: 0.002842084062553065
epoch: 200, loss: 7.457968365452231e-08
epoch: 300, loss: 2.8078539783534333e-12
epoch: 400, loss: 7.27929598306562e-17
epoch: 500, loss: 1.2412278242779927e-21
epoch: 600, loss: 5.968650993176766e-26
epoch: 700, loss: 4.33389908899583e-30
epoch: 800, loss: 1.5688173937147708e-30
epoch: 900, loss: 1.5122191274740909e-30
epoch: 1000, loss: 1.827108673381404e-30
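Each Linear and ReLU layer above takes in_keys because, as a ModuleMap, it holds one torch module per key of the input TensorMap and applies the module for a given key to the block with that key. With the single key used here, mlp_mts therefore has the same architecture as mlp_torch; with several blocks, each block would get its own independent weights. The plain-torch sketch below mimics that behaviour with torch.nn.ModuleDict, using a dictionary of tensors as a stand-in for blocks. The helper names and string keys are hypothetical, and this is not metatensor-learn’s internal implementation.

```python
from typing import Dict, List

import torch


def build_per_key_mlp(
    keys: List[str], in_features: int, hidden: int, out_features: int
) -> torch.nn.ModuleDict:
    """Hypothetical helper: one independent Linear-ReLU-Linear MLP per block key."""
    return torch.nn.ModuleDict(
        {
            key: torch.nn.Sequential(
                torch.nn.Linear(in_features, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, out_features),
            )
            for key in keys
        }
    )


def apply_per_key(
    modules: torch.nn.ModuleDict, blocks: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Apply the module registered under each key to the block with that key."""
    return {key: modules[key](values) for key, values in blocks.items()}


torch.manual_seed(0)
keys = ["block_a", "block_b"]  # stand-ins for the entries of in_keys
per_key_mlp = build_per_key_mlp(keys, in_features=128, hidden=64, out_features=1)

# blocks may have different numbers of samples but share the feature size
blocks = {
    "block_a": torch.randn(30, 128),
    "block_b": torch.randn(70, 128),
}
outputs = apply_per_key(per_key_mlp, blocks)
print({key: tuple(value.shape) for key, value in outputs.items()})

# each key owns its own weights: 2 * 8321 parameters in total
print(sum(p.numel() for p in per_key_mlp.parameters()))
```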
Conclusion
This tutorial introduced the convenience modules in metatensor-learn for building simple neural networks. As we’ve seen, the API is similar to native torch.nn, and the TensorMap data type can easily be swapped in for torch Tensors in existing training loops with minimal changes.
Combined with other learning utilities to construct Datasets and Dataloaders, covered in basic and advanced tutorials, metatensor-learn provides a powerful framework for building and training machine learning models based on the TensorMap data format.