Custom architectures with ModuleMap

This tutorial demonstrates how to build custom architectures compatible with
TensorMap objects by combining native torch.nn modules with metatensor-learn's
ModuleMap.
Note
Prior to this tutorial, it is recommended to read the tutorial on using convenience modules, as this tutorial builds on the concepts introduced there.
from typing import List, Union
import torch
import metatensor.torch as mts
from metatensor.torch import Labels, TensorMap
from metatensor.torch.learn.nn import Linear, ModuleMap
torch.manual_seed(42)
torch.set_default_dtype(torch.float64)
Introduction
The previous tutorials cover how to use metatensor-learn's nn convenience modules
to build simple multi-layer perceptrons and their equivariance-preserving analogs.
Now we will explore the use of a special module called ModuleMap that allows users
to wrap any native torch module to be compatible with a TensorMap.

This is useful for building arbitrary architectures containing layers more complex
than the standard available layers, namely Linear, Tanh, ReLU, SiLU, and LayerNorm,
and their equivariant counterparts.
First we need to create some dummy data in TensorMap format, with multiple
TensorBlocks. Here we will focus on unconstrained architectures, as opposed to
equivariance-preserving ones. The principles in the latter case are similar, as
long as care is taken to build architectures from equivariance-preserving
transformations.
Let’s start by defining a random tensor that we will treat as some intermediate
representation. We will build a multi-layer perceptron to transform this tensor into
a prediction. Here we will define a 3-block tensor map, with different input and
output dimensions for each block.
n_samples = 100
in_features = [64, 128, 256]
out_features = [1, 2, 3]
feature_tensormap = TensorMap(
    keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)),
    blocks=[
        mts.block_from_array(torch.randn(n_samples, in_feats))
        for in_feats in in_features
    ],
)

target_tensormap = TensorMap(
    keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)),
    blocks=[
        mts.block_from_array(torch.randn(n_samples, out_feats))
        for out_feats in out_features
    ],
)
print("features:", feature_tensormap)
print("target:", target_tensormap)
features: TensorMap with 3 blocks
keys: key
0
1
2
target: TensorMap with 3 blocks
keys: key
0
1
2
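Before moving on, it can help to see the block structure concretely. The following
short check is not part of the original script, but it only uses the standard
TensorMap.block indexing; it prints the per-block shapes we just constructed.

# illustrative check: each block carries its own (samples, properties) array,
# with the per-block input and output sizes defined above
for i in range(len(feature_tensormap.keys)):
    feat_shape = tuple(feature_tensormap.block(i).values.shape)
    targ_shape = tuple(target_tensormap.block(i).values.shape)
    print(f"block {i}: features {feat_shape} -> targets {targ_shape}")
# expected: (100, 64) -> (100, 1), (100, 128) -> (100, 2), (100, 256) -> (100, 3)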
Starting simple
Let’s start with a simple linear layer, but this time constructed manually using
ModuleMap. Here we want a linear layer for each block, with the correct input and
output feature sizes. The result will be a module that is equivalent to the
metatensor.torch.learn.nn.Linear module.
in_keys = feature_tensormap.keys

modules = []
for key in in_keys:
    module = torch.nn.Linear(
        in_features=len(feature_tensormap[key].properties),
        out_features=len(target_tensormap[key].properties),
        bias=True,
    )
    modules.append(module)

# initialize the ModuleMap with the input keys, list of modules, and the output
# properties labels metadata
linear_mmap = ModuleMap(
    in_keys,
    modules,
    out_properties=[target_tensormap[key].properties for key in in_keys],
)
print(linear_mmap)
ModuleMap(
  (0): Linear(in_features=64, out_features=1, bias=True)
  (1): Linear(in_features=128, out_features=2, bias=True)
  (2): Linear(in_features=256, out_features=3, bias=True)
)
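As a point of comparison, the same per-block architecture can be built in a single
call with the Linear convenience module imported at the top of this tutorial. This
is only a sketch of the equivalence: the convenience module is initialized with its
own random weights, so the two modules agree in architecture and metadata handling,
not in parameter values.

# equivalent construction using the metatensor-learn convenience module;
# note that this creates freshly initialized (different) weights
linear_convenience = Linear(
    in_keys=in_keys,
    in_features=in_features,
    out_properties=[target_tensormap[key].properties for key in in_keys],
    bias=True,
)
print(linear_convenience)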
ModuleMap automatically handles the forward pass for each block indexed by the
in_keys used to initialize it. If the input contains only a subset of the
keys/blocks given as in_keys, the forward pass is applied only to the blocks present
in the input. The output will be a new TensorMap with the same keys as the input,
and with the correct output metadata.
# apply the ModuleMap to the whole feature tensor map
prediction_full = linear_mmap(feature_tensormap)
# filter the features to only contain one of the blocks, and pass through the ModuleMap
prediction_subset = linear_mmap(
    mts.filter_blocks(
        feature_tensormap, Labels(["key"], torch.tensor([1]).reshape(-1, 1))
    )
)
print(prediction_full.keys, prediction_full.blocks())
print(prediction_subset.keys, prediction_subset.blocks())
Labels(
key
0
1
2
) [TensorBlock
samples (100): ['sample']
components (): []
properties (1): ['property']
gradients: None
, TensorBlock
samples (100): ['sample']
components (): []
properties (2): ['property']
gradients: None
, TensorBlock
samples (100): ['sample']
components (): []
properties (3): ['property']
gradients: None
]
Labels(
key
1
) [TensorBlock
samples (100): ['sample']
components (): []
properties (2): ['property']
gradients: None
]
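As a quick sanity check (not part of the original script), we can verify that the
full prediction already carries the same metadata as the targets, since the
ModuleMap was constructed with the target properties as out_properties.

# the prediction should share keys, samples and properties with the targets
assert mts.equal_metadata(prediction_full, target_tensormap)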
Now define a loss function and run a training loop. This is the same as done in the previous tutorials.
# define a custom loss function for TensorMaps that computes the squared error and
# reduces by sum
class TensorMapLoss(torch.nn.Module):
    """
    A custom loss function for TensorMaps that computes the squared error and
    reduces by sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: TensorMap, target: TensorMap) -> torch.Tensor:
        """
        Computes the total squared error between the ``input`` and ``target``
        TensorMaps.
        """
        # input and target should have equal metadata over all axes
        assert mts.equal_metadata(input, target)

        squared_loss = 0
        for key in input.keys:
            squared_loss += torch.sum((input[key].values - target[key].values) ** 2)

        return squared_loss
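For reference, here is a minimal sketch of what this loss computes, written with
plain torch only: summing torch's squared-error loss with reduction="sum" over
blocks gives the same total. The function name below is purely illustrative and is
not used elsewhere in this tutorial.

# equivalent block-by-block computation using plain torch
def sum_squared_error(input: TensorMap, target: TensorMap) -> torch.Tensor:
    loss = torch.zeros(())
    for key in input.keys:
        loss = loss + torch.nn.functional.mse_loss(
            input[key].values, target[key].values, reduction="sum"
        )
    return loss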
# construct a basic training loop. For brevity we will not use datasets or dataloaders.
def training_loop(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    features: Union[torch.Tensor, TensorMap],
    targets: Union[torch.Tensor, TensorMap],
) -> None:
    """A basic training loop for a model and loss function."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(501):
        optimizer.zero_grad()

        predictions = model(features)

        if isinstance(predictions, torch.ScriptObject):
            # assume a TensorMap and check metadata is equivalent
            assert mts.equal_metadata(predictions, targets)

        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss}")


loss_fn_mts = TensorMapLoss()

print("with NN = [Linear]")
training_loop(linear_mmap, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [Linear]
epoch: 0, loss: 769.1011737074874
epoch: 100, loss: 152.43308577101345
epoch: 200, loss: 95.21448836273512
epoch: 300, loss: 74.85845598989255
epoch: 400, loss: 64.58524404959572
epoch: 500, loss: 58.463045951762794
More complex architectures
# Defining more complicated architectures is a matter of building
# ``torch.nn.Sequential`` objects for each block, and wrapping them into a single
# ModuleMap.
hidden_layer_width = 32
modules = []
for key in in_keys:
    module = torch.nn.Sequential(
        torch.nn.LayerNorm(len(feature_tensormap[key].properties)),
        torch.nn.Linear(
            in_features=len(feature_tensormap[key].properties),
            out_features=hidden_layer_width,
            bias=True,
        ),
        torch.nn.ReLU(),
        torch.nn.Linear(
            in_features=hidden_layer_width,
            out_features=len(target_tensormap[key].properties),
            bias=True,
        ),
        torch.nn.Tanh(),
    )
    modules.append(module)

# initialize the ModuleMap with the input keys, list of modules, and the output
# properties labels metadata
custom_mmap = ModuleMap(
    in_keys,
    modules,
    out_properties=[target_tensormap[key].properties for key in in_keys],
)
print(custom_mmap)

print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh]")
training_loop(custom_mmap, loss_fn_mts, feature_tensormap, target_tensormap)
ModuleMap(
  (0): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=64, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=1, bias=True)
    (4): Tanh()
  )
  (1): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=2, bias=True)
    (4): Tanh()
  )
  (2): Sequential(
    (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=256, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=3, bias=True)
    (4): Tanh()
  )
)
with NN = [LayerNorm, Linear, ReLU, Linear, Tanh]
epoch: 0, loss: 628.4115434117372
epoch: 100, loss: 102.15331090664478
epoch: 200, loss: 93.55173279931788
epoch: 300, loss: 92.10713211723296
epoch: 400, loss: 91.57975604835907
epoch: 500, loss: 91.33320804749201
A ModuleMap can also be wrapped in a torch.nn.Module to allow construction of
complex architectures. For instance, we can build a “ResNet”-style neural network
module that takes a ModuleMap, applies it, then sums the result with some residual
connections. Wikipedia has a good summary and diagram of this architectural motif:
https://en.wikipedia.org/wiki/Residual_neural_network .

To do this, we combine application of the ModuleMap with the Linear convenience
layer from metatensor-learn and the sparse addition operation from
metatensor-operations to build a complex architecture.
class ResidualNetwork(torch.nn.Module):
    def __init__(
        self,
        in_keys: Labels,
        in_features: List[int],
        out_properties: List[Labels],
    ) -> None:
        super().__init__()

        # Build the module map as before
        hidden_layer_width = 32
        modules = []
        for in_feats, out_props in zip(in_features, out_properties):
            module = torch.nn.Sequential(
                torch.nn.LayerNorm(in_feats),
                torch.nn.Linear(
                    in_features=in_feats,
                    out_features=hidden_layer_width,
                    bias=True,
                ),
                torch.nn.ReLU(),
                torch.nn.Linear(
                    in_features=hidden_layer_width,
                    out_features=len(out_props),
                    bias=True,
                ),
                torch.nn.Tanh(),
            )
            modules.append(module)

        self.module_map = ModuleMap(
            in_keys,
            modules,
            out_properties=out_properties,
        )

        # build the input projection layer
        self.projection = Linear(
            in_keys=in_keys,
            in_features=in_features,
            out_properties=out_properties,
            bias=True,
        )

    def forward(self, features: TensorMap) -> TensorMap:
        # apply the module map to the features
        prediction = self.module_map(features)

        # apply the projection layer to the features
        residual = self.projection(features)

        # add the prediction and residual together using the sparse addition from
        # metatensor-operations
        return mts.add(prediction, residual)


model = ResidualNetwork(
    in_keys=in_keys,
    in_features=in_features,
    out_properties=[block.properties for block in target_tensormap],
)
print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections")
training_loop(model, loss_fn_mts, feature_tensormap, target_tensormap)
with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections
epoch: 0, loss: 877.4945870891794
epoch: 100, loss: 18.696672961002434
epoch: 200, loss: 3.845163971248074
epoch: 300, loss: 1.5772835196534585
epoch: 400, loss: 0.8379008238832794
epoch: 500, loss: 0.46250230046218666
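As a final, purely illustrative step (not part of the original script), we can
evaluate the trained model once more without tracking gradients. Since
ResidualNetwork is an ordinary torch.nn.Module, its parameters can also be saved and
restored with the standard torch.save / load_state_dict workflow.

# evaluate the trained residual model without building a computational graph
with torch.no_grad():
    final_prediction = model(feature_tensormap)
print("final loss:", loss_fn_mts(final_prediction, target_tensormap))

# the parameters can be saved with the usual torch mechanisms, e.g.
# torch.save(model.state_dict(), "residual_model.pt")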
Conclusion
In this tutorial we have seen how to build custom architectures using ModuleMap.
This allows arbitrary architectures to be built, as long as the metadata is
preserved. We have also seen how to build a custom module that wraps a ModuleMap and
adds residual connections.

The key takeaway is that ModuleMap can be used to wrap any combination of native
torch.nn modules to make them compatible with TensorMap. In combination with the
convenience layers seen in the nn modules basic tutorial, and the sparse-data
operations from metatensor-operations, complex architectures can be built with ease.