Custom architectures with ModuleMap

This tutorial demonstrates how to build custom architectures compatible with TensorMap objects by combining native torch.nn modules with metatensor-learn’s ModuleMap.

Note

Prior to this tutorial, it is recommended to read the tutorial on using convenience modules, as this tutorial builds on the concepts introduced there.

from typing import List, Union

import torch

import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap
from metatensor.torch.learn.nn import Linear, ModuleMap


torch.manual_seed(42)
torch.set_default_dtype(torch.float64)

Introduction

The previous tutorials cover how to use metatensor-learn’s nn convenience modules to build simple multi-layer perceptrons and their equivariance-preserving analogs. Here we explore ModuleMap, a special module that lets users wrap any native torch module in a TensorMap-compatible manner.

This is useful for building arbitrary architectures that contain layers more complex than the standard convenience layers, namely Linear, Tanh, ReLU, SiLU and LayerNorm, and their equivariant counterparts.

First we need to create some dummy data in the TensorMap format, with multiple TensorBlock objects. Here we will focus on unconstrained architectures, as opposed to equivariance-preserving ones. The principles in the latter case are similar, as long as care is taken to build the architecture only from equivariance-preserving transformations.

Let’s start by defining a random tensor that we will treat as some intermediate representation. We will build a multi-layer perceptron to transform this tensor into a prediction. Here we define a 3-block TensorMap, along with variables holding the input and output feature dimensions of each block.

n_samples = 100
in_features = [64, 128, 256]
out_features = [1, 2, 3]

feature_tensormap = TensorMap(
    keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)),
    blocks=[
        mts.block_from_array(torch.randn(n_samples, in_feats))
        for in_feats in in_features
    ],
)

target_tensormap = TensorMap(
    keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)),
    blocks=[
        mts.block_from_array(torch.randn(n_samples, out_feats))
        for out_feats in out_features
    ],
)
print("features:", feature_tensormap)
print("target:", target_tensormap)
features: TensorMap with 3 blocks
keys: key
       0
       1
       2
target: TensorMap with 3 blocks
keys: key
       0
       1
       2

Starting simple

Let’s start with a simple linear layer, but this time constructed manually using ModuleMap. Here we want a linear layer for each block, with the correct in and out feature shapes. The result will be a module that is equivalent to the metatensor.torch.learn.nn.Linear module.

in_keys = feature_tensormap.keys

modules = []
for key in in_keys:
    module = torch.nn.Linear(
        in_features=len(feature_tensormap[key].properties),
        out_features=len(target_tensormap[key].properties),
        bias=True,
    )
    modules.append(module)

# initialize the ModuleMap with the input keys, list of modules, and the output
# property labels' metadata.
linear_mmap = ModuleMap(
    in_keys,
    modules,
    out_properties=[target_tensormap[key].properties for key in in_keys],
)
print(linear_mmap)
ModuleMap(
  (0): Linear(in_features=64, out_features=1, bias=True)
  (1): Linear(in_features=128, out_features=2, bias=True)
  (2): Linear(in_features=256, out_features=3, bias=True)
)

ModuleMap automatically handles the forward pass for each block indexed by the in_keys used to initialize it. In cases where the input contains only a subset of the keys/blocks listed in in_keys, the forward pass is applied only to the blocks that are present in the input. The output is a new TensorMap with the same keys as the input, now with the correct output metadata.

# apply the ModuleMap to the whole tensor map of features
prediction_full = linear_mmap(feature_tensormap)

# filter the features to only contain one of the blocks,
# and pass it through the ModuleMap
prediction_subset = linear_mmap(
    mts.filter_blocks(
        feature_tensormap, Labels(["key"], torch.tensor([1]).reshape(-1, 1))
    )
)

print(prediction_full.keys, prediction_full.blocks())
print(prediction_subset.keys, prediction_subset.blocks())
Labels(
    key
     0
     1
     2
) [TensorBlock
    samples (100): ['sample']
    components (): []
    properties (1): ['property']
    gradients: None
, TensorBlock
    samples (100): ['sample']
    components (): []
    properties (2): ['property']
    gradients: None
, TensorBlock
    samples (100): ['sample']
    components (): []
    properties (3): ['property']
    gradients: None
]
Labels(
    key
     1
) [TensorBlock
    samples (100): ['sample']
    components (): []
    properties (2): ['property']
    gradients: None
]
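The per-key dispatch described above can be pictured with plain torch objects. The following is a minimal sketch, not the ModuleMap implementation: the dictionary `blocks` and the helper `apply_per_key` are hypothetical stand-ins for a TensorMap and for the forward pass, and metadata handling is left out entirely.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a 3-block TensorMap: integer key -> 2D tensor of values
in_features = [64, 128, 256]
out_features = [1, 2, 3]
blocks = {key: torch.randn(100, n_in) for key, n_in in enumerate(in_features)}

# One independent Linear layer per key (ModuleDict requires string keys)
layers = torch.nn.ModuleDict(
    {
        str(key): torch.nn.Linear(n_in, n_out)
        for key, (n_in, n_out) in enumerate(zip(in_features, out_features))
    }
)


def apply_per_key(inputs):
    """Apply the layer registered for each key that is present in ``inputs``."""
    return {key: layers[str(key)](values) for key, values in inputs.items()}


# Full input: every key is transformed by its own layer
full = apply_per_key(blocks)

# Subset input: only the block with key 1 is transformed
subset = apply_per_key({1: blocks[1]})

full_shapes = {key: tuple(values.shape) for key, values in full.items()}
subset_shapes = {key: tuple(values.shape) for key, values in subset.items()}
print(full_shapes)
print(subset_shapes)
```

In this sketch each key owns its own parameters, so passing a subset of the blocks gives the same values for those blocks as passing the full input.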

Now we define a loss function and run a training loop. This is the same as in the previous tutorials.

# define a custom loss function for TensorMaps that computes the squared error and
# reduces by a summation operation
class TensorMapLoss(torch.nn.Module):
    """
    A custom loss function for TensorMaps that computes the squared error and reduces by
    sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, _input: TensorMap, target: TensorMap) -> torch.Tensor:
        """
        Computes the total squared error between the ``_input`` and ``target``
        TensorMaps.
        """
        # inputs and targets should have the same metadata over all axes
        assert mts.equal_metadata(_input, target)

        squared_loss = 0
        for key in _input.keys:
            squared_loss += torch.sum((_input[key].values - target[key].values) ** 2)

        return squared_loss


# construct a basic training loop. For brevity we will not use datasets or dataloaders.
def training_loop(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    features: Union[torch.Tensor, TensorMap],
    targets: Union[torch.Tensor, TensorMap],
) -> None:
    """A basic training loop for a model and loss function."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(501):
        optimizer.zero_grad()

        predictions = model(features)

        if isinstance(predictions, torch.ScriptObject):
            # assume a TensorMap and check metadata is equivalent
            assert mts.equal_metadata(predictions, targets)

        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss}")


loss_fn_mts = TensorMapLoss()

print("with NN = [Linear]")
training_loop(linear_mmap, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [Linear]
epoch: 0, loss: 769.1011737074874
epoch: 100, loss: 152.43308577101345
epoch: 200, loss: 95.21448836273512
epoch: 300, loss: 74.85845598989255
epoch: 400, loss: 64.58524404959572
epoch: 500, loss: 58.463045951762794
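As a sanity check on the loss used above, the block-by-block sum of squared errors can be compared with torch's built-in `MSELoss` using `reduction="sum"`. This is a small sketch on plain tensors; the lists `predictions` and `targets` are hypothetical stand-ins for the block values of two TensorMaps.

```python
import torch

torch.manual_seed(0)

# Hypothetical per-block predictions and targets, stored as plain tensors
out_features = [1, 2, 3]
predictions = [torch.randn(100, n) for n in out_features]
targets = [torch.randn(100, n) for n in out_features]

# Sum of squared errors accumulated block by block, as in TensorMapLoss.forward
manual = sum(torch.sum((p - t) ** 2) for p, t in zip(predictions, targets))

# The same quantity computed with the built-in loss and a sum reduction
mse_sum = torch.nn.MSELoss(reduction="sum")
builtin = sum(mse_sum(p, t) for p, t in zip(predictions, targets))

print(manual.item(), builtin.item())
```

Because the reduction is a sum rather than a mean, the loss value grows with the number of samples and properties, so the numbers printed during training are only comparable between models trained on the same data.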

More complex architectures

Defining more complicated architectures is a matter of building torch.nn.Sequential objects for each block, and wrapping them into a single ModuleMap.

hidden_layer_width = 32

modules = []
for key in in_keys:
    module = torch.nn.Sequential(
        torch.nn.LayerNorm(len(feature_tensormap[key].properties)),
        torch.nn.Linear(
            in_features=len(feature_tensormap[key].properties),
            out_features=hidden_layer_width,
            bias=True,
        ),
        torch.nn.ReLU(),
        torch.nn.Linear(
            in_features=hidden_layer_width,
            out_features=len(target_tensormap[key].properties),
            bias=True,
        ),
        torch.nn.Tanh(),
    )
    modules.append(module)

# initialize the ModuleMap as in the previous section.
custom_mmap = ModuleMap(
    in_keys,
    modules,
    out_properties=[target_tensormap[key].properties for key in in_keys],
)
print(custom_mmap)

print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh]")
training_loop(custom_mmap, loss_fn_mts, feature_tensormap, target_tensormap)
ModuleMap(
  (0): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=64, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=1, bias=True)
    (4): Tanh()
  )
  (1): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=2, bias=True)
    (4): Tanh()
  )
  (2): Sequential(
    (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=256, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=3, bias=True)
    (4): Tanh()
  )
)
with NN = [LayerNorm, Linear, ReLU, Linear, Tanh]
epoch: 0, loss: 628.4115434117372
epoch: 100, loss: 102.15331090664478
epoch: 200, loss: 93.55173279931788
epoch: 300, loss: 92.10713211723296
epoch: 400, loss: 91.57975604835907
epoch: 500, loss: 91.33320804749201
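One possible reason the loss above levels off instead of continuing to fall is the final Tanh layer: its outputs lie in [-1, 1], while the targets are drawn from a standard normal distribution and can fall outside that range. The sketch below computes the resulting lower bound on the summed squared error for hypothetical targets of the same total shape. It does not reproduce the tutorial's exact target values, so the number it prints is only indicative.

```python
import torch

torch.manual_seed(0)

# Hypothetical standard-normal targets with the same total shape as the three
# target blocks of the tutorial (100 samples, 1 + 2 + 3 properties)
targets = torch.randn(100, 6)

# A Tanh output cannot leave [-1, 1], so the closest reachable value for each
# target is the target clamped to that interval
closest_reachable = targets.clamp(min=-1.0, max=1.0)
lower_bound = torch.sum((targets - closest_reachable) ** 2)

# Any Tanh-terminated prediction has a summed squared error at least this large
pre_activation = torch.randn(100, 6)
tanh_loss = torch.sum((torch.tanh(pre_activation) - targets) ** 2)

print(lower_bound.item(), tanh_loss.item())
```

If this bound is of the same order as the plateau seen in training, the limiting factor is likely the output activation rather than the optimizer or the width of the hidden layer.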

ModuleMap objects can also be wrapped in a torch.nn.Module to allow construction of complex architectures. For instance, we can write a “ResNet”-style neural network module that takes a ModuleMap, applies it, and then adds a residual connection to the result. Wikipedia has a good summary and diagram of this architectural motif, see: https://en.wikipedia.org/wiki/Residual_neural_network .

To build the residual connection, we combine the application of the ModuleMap with a Linear convenience layer from metatensor-learn, which projects the input features onto the output properties, and the sparse addition operation from metatensor-operations, which sums the two results block by block.

class ResidualNetwork(torch.nn.Module):
    def __init__(
        self,
        in_keys: Labels,
        in_features: List[int],
        out_properties: List[Labels],
    ) -> None:
        super().__init__()

        # Build the module map as before
        hidden_layer_width = 32
        modules = []
        for in_feats, out_props in zip(in_features, out_properties, strict=True):
            module = torch.nn.Sequential(
                torch.nn.LayerNorm(in_feats),
                torch.nn.Linear(
                    in_features=in_feats,
                    out_features=hidden_layer_width,
                    bias=True,
                ),
                torch.nn.ReLU(),
                torch.nn.Linear(
                    in_features=hidden_layer_width,
                    out_features=len(out_props),
                    bias=True,
                ),
                torch.nn.Tanh(),
            )
            modules.append(module)

        self.module_map = ModuleMap(
            in_keys,
            modules,
            out_properties=out_properties,
        )

        # build the input projection layer
        self.projection = Linear(
            in_keys=in_keys,
            in_features=in_features,
            out_properties=out_properties,
            bias=True,
        )

    def forward(self, features: TensorMap) -> TensorMap:
        # apply the module map to the features
        prediction = self.module_map(features)

        # apply the projection layer to the features
        residual = self.projection(features)

        # add the prediction and residual together using the sparse addition
        # from metatensor-operations
        return mts.add(prediction, residual)


model = ResidualNetwork(
    in_keys=in_keys,
    in_features=in_features,
    out_properties=[block.properties for block in target_tensormap],
)
print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections")
training_loop(model, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections
epoch: 0, loss: 877.4945870891794
epoch: 100, loss: 18.696672961002434
epoch: 200, loss: 3.845163971248074
epoch: 300, loss: 1.5772835196534585
epoch: 400, loss: 0.8379008238832794
epoch: 500, loss: 0.46250230046218666
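To see the residual pattern in isolation, here is a minimal single-block sketch in plain torch, assuming the same layer sequence and sizes as the first block above. The names `branch` and `projection` are illustrative only; in the ResidualNetwork class the same roles are played by the ModuleMap and the Linear convenience layer, applied to every block.

```python
import torch

torch.manual_seed(0)

n_samples, in_feats, hidden, out_feats = 100, 64, 32, 1
x = torch.randn(n_samples, in_feats)

# Non-linear branch: the same layer sequence as each ModuleMap entry
branch = torch.nn.Sequential(
    torch.nn.LayerNorm(in_feats),
    torch.nn.Linear(in_feats, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, out_feats),
    torch.nn.Tanh(),
)

# Linear projection so that the skip path has the same width as the branch output
projection = torch.nn.Linear(in_feats, out_feats)

# Residual-style output: bounded non-linear term plus unbounded linear term
output = branch(x) + projection(x)
print(output.shape)
```

The projection is needed because the input and output widths differ, so the input cannot be added to the branch output directly. Since the linear path is not passed through Tanh, the summed output is no longer restricted to [-1, 1].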

Conclusion

In this tutorial we have seen how to build custom architectures using ModuleMap. This allows for arbitrary architectures to be built, as long as the metadata is preserved. We have also seen how to build a custom module that wraps a ModuleMap and adds residual connections.

The key takeaway is that ModuleMap can be used to wrap any combination of native torch.nn modules to make them compatible with TensorMap. In combination with the convenience layers seen in the previous tutorial on nn modules, and the sparse-data operations from metatensor-operations, complex architectures can be built with ease.
