Equivariance-preserving nn modules

This example demonstrates the use of convenience modules in metatensor-learn to build simple equivariance-preserving multi-layer perceptrons.

Note

Prior to this tutorial, it is recommended to read the tutorial on using convenience modules, as this tutorial builds on the concepts introduced there.

import torch

import metatensor.torch as mts
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.learn.nn import (
    EquivariantLinear,
    InvariantReLU,
    Sequential,
)


torch.manual_seed(42)
torch.set_default_dtype(torch.float64)

Introduction

Often the targets of machine learning models are physical observables with certain symmetries, such as invariance with respect to translation or equivariance with respect to rotation (rotating the input structure means that the target should be rotated in the same way).
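Concretely, in the spherical basis a block with angular momentum lambda transforms under a rotation R via the Wigner D-matrix D^lambda(R). In standard notation (not taken from the tutorial itself), an equivariant model f must therefore satisfy

    f(D(R) x) = D(R) f(x)    for every rotation R,

where D(R) acts block-wise on the angular components of the features and of the predictions: rotating the input and then applying the model gives the same result as applying the model and then rotating its output.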

Many successful approaches to these learning tasks use equivariance-preserving architectures to map equivariant features onto predictions of an equivariant target.

In this example we will demonstrate how to build an equivariance-preserving multi-layer perceptron (MLP) on top of some equivariant features.

Let’s load the spherical expansion from the first steps tutorial.

spherical_expansion = mts.load("../core/spherical-expansion.mts")

# metatensor-learn modules currently do not support TensorMaps with gradients
spherical_expansion = mts.remove_gradients(spherical_expansion)
print(spherical_expansion)
print("\nNumber of blocks in the spherical expansion:", len(spherical_expansion))
TensorMap with 12 blocks
keys: o3_lambda  o3_sigma  center_type  neighbor_type
          0         1           6             6
          1         1           6             6
                           ...
          1         1           8             8
          2         1           8             8

Number of blocks in the spherical expansion: 12

As a reminder, these are the coefficients of the spherical-basis decomposition of a smooth Gaussian density representation of a 3D point cloud. In this case, the point cloud is a set of decorated atomic positions.
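Schematically (a standard way of writing this expansion, not taken verbatim from this tutorial), the neighbor density around atom i is

    rho_i(r) = sum_{j in neighbors(i)} g(r - r_ij) ~= sum_{nlm} c^i_{nlm} R_n(r) Y_l^m(r_hat),

and it is the coefficients c^i_{nlm} that are stored in the TensorMap: l corresponds to the key dimension "o3_lambda", m to the component dimension "o3_mu", and n to the properties.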

The important part here is that these features are block sparse in the angular momentum channel (key dimension "o3_lambda"), with each block behaving differently under rigid rotations (the action of the SO(3) group).

In general, blocks that are invariant under rotation (where o3_lambda == 0) can be transformed in arbitrary (i.e. nonlinear) ways in the mapping from features to target, while covariant blocks (where o3_lambda > 0) must be transformed in a way that preserves the equivariance of the features. The simplest way to do this is to use only linear transformations for the latter.
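To see why bias-free linear layers are safe for covariant blocks, note that a weight matrix acting only on the property dimension commutes with a rotation acting on the component ("o3_mu") dimension. A minimal numerical sketch of this (not part of the original tutorial; a random orthogonal matrix stands in for a Wigner D-matrix, but any matrix on the component axis would do):

x = torch.randn(4, 5, 3)  # (samples, 2 * lambda + 1 components, properties), lambda = 2
W = torch.randn(3, 1)  # weights mixing the properties only
D = torch.linalg.qr(torch.randn(5, 5))[0]  # stand-in for a Wigner D-matrix

# rotating the components and mixing the properties commute
rotate_then_mix = torch.einsum("mn,snp->smp", D, x) @ W
mix_then_rotate = torch.einsum("mn,snp->smp", D, x @ W)
assert torch.allclose(rotate_then_mix, mix_then_rotate)

Adding a bias (or an elementwise nonlinearity) would break this, since a constant offset does not rotate together with the block.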

Define equivariant target data

Let’s build some dummy target data: we will predict a global (i.e. per-system) rank-2 symmetric tensor, which decomposes into o3_lambda = [0, 2] angular momenta channels when expressed in the spherical basis. An example of such a target in atomistic machine learning is the electronic polarizability of a molecule.
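As a quick check on the component count: a symmetric 3x3 tensor has 6 independent entries, which split as

    6 = 1 (trace, o3_lambda = 0) + 5 (symmetric traceless part, o3_lambda = 2),

while the antisymmetric o3_lambda = 1 part of a general rank-2 tensor vanishes here by symmetry.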

Our target will be block sparse with "o3_lambda" key dimensions equal to [0, 2], and as this is a real- (not pseudo-) tensor, the inversion sigma ("o3_sigma") will be +1.

target_tensormap = TensorMap(
    keys=Labels(["o3_lambda", "o3_sigma"], torch.tensor([[0, 1], [2, 1]])),
    blocks=[
        TensorBlock(
            values=torch.randn((1, 1, 1), dtype=torch.float64),
            # only one system
            samples=Labels(["system"], torch.tensor([[0]])),
            # o3_mu = [0]
            components=[Labels(["o3_mu"], torch.tensor([[0]]))],
            # only one 'property' (the L=0 part of the polarizability)
            properties=Labels(["_"], torch.tensor([[0]])),
        ),
        TensorBlock(
            values=torch.randn((1, 5, 1), dtype=torch.float64),
            # only one system
            samples=Labels(["system"], torch.tensor([[0]])),
            # o3_mu = [-2, -1, 0, +1, +2]
            components=[Labels(["o3_mu"], torch.tensor([[-2], [-1], [0], [1], [2]]))],
            # only one 'property' (the L=2 part of the polarizability)
            properties=Labels(["_"], torch.tensor([[0]])),
        ),
    ],
)
print(target_tensormap, target_tensormap[0])
TensorMap with 2 blocks
keys: o3_lambda  o3_sigma
          0         1
          2         1 TensorBlock
    samples (1): ['system']
    components (1): ['o3_mu']
    properties (1): ['_']
    gradients: None

Filter the feature blocks to only keep the blocks with symmetries that match the target: as our target only contains o3_lambda = [0, 2] channels, we only need these!

spherical_expansion = mts.filter_blocks(spherical_expansion, target_tensormap.keys)
print(spherical_expansion)
print(
    "\nNumber of blocks in the filtered spherical expansion:", len(spherical_expansion)
)
TensorMap with 8 blocks
keys: o3_lambda  o3_sigma  center_type  neighbor_type
          0         1           6             6
          2         1           6             6
                           ...
          0         1           8             8
          2         1           8             8

Number of blocks in the filtered spherical expansion: 8

Using equivariant convenience layers

Now we can build our neural network. Our architecture will consist of separate “block models”, i.e. transformations with separate learnable weights for each block in the spherical expansion. This is in contrast to the previous tutorial nn modules basic, where we only had a single block in our features and targets.

Furthermore, as the features are a per-atom quantity, we will use some sparse tensor operations to sum the contributions of all atoms in a system and obtain a per-system prediction. For this we will use metatensor-operations.

Starting simple, let’s define the neural network as just a single linear layer. As stated before, covariant blocks (here those with o3_lambda = 2) may only be transformed linearly, while nonlinear transformations can be applied to the invariant blocks where o3_lambda = 0. We will use the EquivariantLinear module for this.

in_keys = spherical_expansion.keys
equi_linear = EquivariantLinear(
    in_keys=in_keys,
    in_features=[len(spherical_expansion[key].properties) for key in in_keys],
    out_features=1,  # for all blocks
)
print(in_keys)
print(equi_linear)
Labels(
    o3_lambda  o3_sigma  center_type  neighbor_type
        0         1           6             6
        2         1           6             6
                         ...
        0         1           8             8
        2         1           8             8
)
EquivariantLinear(
  (module_map): ModuleMap(
    (0): Linear(in_features=5, out_features=1, bias=True)
    (1): Linear(in_features=5, out_features=1, bias=False)
    (2): Linear(in_features=5, out_features=1, bias=True)
    (3): Linear(in_features=5, out_features=1, bias=False)
    (4): Linear(in_features=5, out_features=1, bias=True)
    (5): Linear(in_features=5, out_features=1, bias=False)
    (6): Linear(in_features=5, out_features=1, bias=True)
    (7): Linear(in_features=5, out_features=1, bias=False)
  )
)

By printing the architecture of the EquivariantLinear module, we can see that there are 8 ‘Linear’ layers, one for each block. In order to preserve equivariance, bias is always turned off for the covariant blocks. For invariant blocks, bias can be switched on or off by passing the boolean parameter bias when initializing EquivariantLinear.
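For instance, bias could also be disabled for the invariant blocks. A short sketch reusing the constructor arguments from above (only the bias keyword mentioned in the previous paragraph is new here; this alternative is not used further in the tutorial):

equi_linear_no_bias = EquivariantLinear(
    in_keys=in_keys,
    in_features=[len(spherical_expansion[key].properties) for key in in_keys],
    out_features=1,
    bias=False,  # no bias for the invariant blocks; covariant blocks never have one
)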

# Let's see what happens when we pass the features through the network.
per_atom_predictions = equi_linear(spherical_expansion)
print(per_atom_predictions)
print(per_atom_predictions[0])
TensorMap with 8 blocks
keys: o3_lambda  o3_sigma  center_type  neighbor_type
          0         1           6             6
          2         1           6             6
                           ...
          0         1           8             8
          2         1           8             8
TensorBlock
    samples (1): ['system', 'atom']
    components (1): ['o3_mu']
    properties (1): ['_']
    gradients: None

The output of the EquivariantLinear module is still per-atom and block sparse in both “center_type” and “neighbor_type”. To get the per-system prediction, we can densify the prediction in these key dimensions by moving them to samples, then sum over all sample dimensions except “system”.

per_atom_predictions = per_atom_predictions.keys_to_samples(
    ["center_type", "neighbor_type"]
)
per_system_predictions = mts.sum_over_samples(
    per_atom_predictions, ["atom", "center_type", "neighbor_type"]
)
assert mts.equal_metadata(per_system_predictions, target_tensormap)
print(per_system_predictions, per_system_predictions[0])
TensorMap with 2 blocks
keys: o3_lambda  o3_sigma
          0         1
          2         1 TensorBlock
    samples (1): ['system']
    components (1): ['o3_mu']
    properties (1): ['_']
    gradients: None

The overall ‘model’ that maps features to targets involves both the application of a neural network and some extra transformations, so we can wrap it all in a single torch module.

class EquivariantMLP(torch.nn.Module):
    """
    A simple equivariant MLP that maps per-atom features to per-structure targets.
    """

    def __init__(self, mlp: torch.nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, features: TensorMap) -> TensorMap:
        # apply the multi-layer perceptron to the features
        per_atom_predictions = self.mlp(features)

        # densify the predictions in the "center_type" and "neighbor_type" key
        # dimensions
        per_atom_predictions = per_atom_predictions.keys_to_samples(
            ["center_type", "neighbor_type"]
        )

        # sum over all sample dimensions except "system"
        per_system_predictions = mts.sum_over_samples(
            per_atom_predictions, ["atom", "center_type", "neighbor_type"]
        )

        return per_system_predictions

Now we will construct the loss function and run the training loop as we did in the previous tutorial, nn modules basic.

# define a custom loss function for TensorMaps that computes the squared error and
# reduces by sum
class TensorMapLoss(torch.nn.Module):
    """
    A custom loss function for TensorMaps that computes the squared error and reduces by
    sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: TensorMap, target: TensorMap) -> torch.Tensor:
        """
        Computes the total squared error between the ``input`` and ``target``
        TensorMaps.
        """
        # input and target should have equal metadata over all axes
        assert mts.equal_metadata(input, target)

        squared_loss = 0
        for key in input.keys:
            squared_loss += torch.sum((input[key].values - target[key].values) ** 2)

        return squared_loss


# construct a basic training loop. For brevity we will not use datasets or dataloaders.
def training_loop(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    features: TensorMap,
    targets: TensorMap,
) -> None:
    """A basic training loop for a model and loss function."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(301):
        optimizer.zero_grad()

        predictions = model(features)

        assert mts.equal_metadata(predictions, targets)

        loss = loss_fn(predictions, targets)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f"epoch: {epoch}, loss: {loss}")


loss_fn_mts = TensorMapLoss()
model = EquivariantMLP(equi_linear)
print("with NN = [EquivariantLinear]")
training_loop(model, loss_fn_mts, spherical_expansion, target_tensormap)
with NN = [EquivariantLinear]
epoch: 0, loss: 1.809741514725238
epoch: 100, loss: 1.4297721240098198
epoch: 200, loss: 1.369045943496478
epoch: 300, loss: 1.3525927125547264

Let’s inspect the per-block losses using predictions from the trained model. Note that the model is able to perfectly fit the invariant target blocks, but not the covariant blocks. This is to be expected, as the target data was generated with random numbers and is not itself equivariant, making the learning task impossible.

See also the atomistic cookbook example on rotational equivariance for a more detailed discussion of this topic: https://atomistic-cookbook.org/examples/rotate-equivariants/rotate-equivariants.html

print("per-block loss:")
prediction = model(spherical_expansion)
for key, block in prediction.items():
    print(key, torch.sum((block.values - target_tensormap[key].values) ** 2).item())
per-block loss:
LabelsEntry(o3_lambda=0, o3_sigma=1) 9.874580335918655e-15
LabelsEntry(o3_lambda=2, o3_sigma=1) 1.3525269172807264

Now let’s consider a more complex nonlinear architecture. In the simplest case we are restricted to linear layers for covariant blocks, but we can use nonlinear layers for invariant blocks.

We will use the InvariantReLU activation function. It has the prefix “Invariant” as it only applies the activation function to invariant blocks where o3_lambda = 0, and leaves the covariant blocks unchanged.

# Let's build a new MLP with two linear layers and one activation function.
hidden_layer_width = 64
equi_mlp = Sequential(
    in_keys,
    EquivariantLinear(
        in_keys=in_keys,
        in_features=[len(spherical_expansion[key].properties) for key in in_keys],
        out_features=hidden_layer_width,
    ),
    InvariantReLU(in_keys),  # could also use InvariantTanh, InvariantSiLU
    EquivariantLinear(
        in_keys=in_keys,
        in_features=[hidden_layer_width for _ in in_keys],
        out_features=1,  # for all blocks
    ),
)
print(in_keys)
print(equi_mlp)
Labels(
    o3_lambda  o3_sigma  center_type  neighbor_type
        0         1           6             6
        2         1           6             6
                         ...
        0         1           8             8
        2         1           8             8
)
Sequential(
  (module_map): ModuleMap(
    (0): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
    (1): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=False)
      (1): Identity()
      (2): Linear(in_features=64, out_features=1, bias=False)
    )
    (2): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
    (3): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=False)
      (1): Identity()
      (2): Linear(in_features=64, out_features=1, bias=False)
    )
    (4): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
    (5): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=False)
      (1): Identity()
      (2): Linear(in_features=64, out_features=1, bias=False)
    )
    (6): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
    (7): Sequential(
      (0): Linear(in_features=5, out_features=64, bias=False)
      (1): Identity()
      (2): Linear(in_features=64, out_features=1, bias=False)
    )
  )
)

Notice now that for invariant blocks, the ‘block model’ is a nonlinear MLP, whereas for covariant blocks it is the sequential application of two linear layers, without bias; a short aside on why this keeps those branches strictly linear follows below, after which we re-run the training loop with this new architecture.
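A minimal numerical sketch (not part of the original tutorial): the composition of two bias-free linear maps is itself a single linear map, so stacking them adds depth but no extra expressivity for the covariant blocks.

W1 = torch.randn(5, hidden_layer_width)
W2 = torch.randn(hidden_layer_width, 1)
x = torch.randn(7, 5)

# applying the two weight matrices in sequence equals applying their product once
assert torch.allclose((x @ W1) @ W2, x @ (W1 @ W2))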

model = EquivariantMLP(equi_mlp)
print("with NN = [EquivariantLinear, InvariantReLU, EquivariantLinear]")
training_loop(model, loss_fn_mts, spherical_expansion, target_tensormap)
with NN = [EquivariantLinear, InvariantReLU, EquivariantLinear]
epoch: 0, loss: 1.4347149592840838
epoch: 100, loss: 1.3492204184722987
epoch: 200, loss: 1.3492187581772856
epoch: 300, loss: 1.3492187581534367

With the trained model, let’s see the per-block decomposition of the loss. As before, the model can perfectly fit the invariants, but not the covariants, as expected.

print("per-block loss:")
prediction = model(spherical_expansion)
for key, block in prediction.items():
    print(key, torch.sum((block.values - target_tensormap[key].values) ** 2).item())
per-block loss:
LabelsEntry(o3_lambda=0, o3_sigma=1) 5.353718304873162e-16
LabelsEntry(o3_lambda=2, o3_sigma=1) 1.3492187581534363

Conclusion

This tutorial has demonstrated how to build equivariance-preserving architectures using the metatensor-learn convenience neural network modules. These modules, such as EquivariantLinear and InvariantReLU, are modified analogs of the standard convenience layers, such as Linear and ReLU.

The key difference is that the invariant or covariant nature of the input blocks (encoded in the “o3_lambda” key dimension) is taken into account and used to determine the transformation applied to each block. For the convenience modules shown above, this means that biases and nonlinear activation functions are only applied to the invariant blocks, while covariant blocks are transformed strictly linearly.

Other examples

See the atomistic cookbook for an example on learning the polarizability using EquivariantLinear applied to higher body order features:

https://atomistic-cookbook.org/examples/polarizability/polarizability.html

and those for checking the rotational equivariance of quantities in TensorMap format:

https://atomistic-cookbook.org/examples/rotate-equivariants/rotate-equivariants.html
