.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "examples/learn/5-nn-using-modulemap.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_examples_learn_5-nn-using-modulemap.py: .. _learn-tutorial-nn-using-modulemap: Custom architectures with ``ModuleMap`` ======================================= .. py:currentmodule:: metatensor.torch.learn.nn This tutorial demonstrates how to build custom architectures compatible with ``TensorMap`` objects by combining native ``torch.nn`` modules with metatensor-learn's ``ModuleMap``. .. note:: Prior to this tutorial, it is recommended to read the tutorial on :ref:`using convenience modules `, as this tutorial builds on the concepts introduced there. .. GENERATED FROM PYTHON SOURCE LINES 19-33 .. code-block:: Python from typing import List, Union import torch import metatensor.torch as mts from metatensor.torch import Labels, TensorMap from metatensor.torch.learn.nn import Linear, ModuleMap torch.manual_seed(42) torch.set_default_dtype(torch.float64) .. GENERATED FROM PYTHON SOURCE LINES 34-56 Introduction ------------ The previous tutorials cover how to use metatensor-learn's ``nn`` convenience modules to build simple multi-layer perceptrons and their equivariance-preserving analogs. Now we will explore a special module called ``ModuleMap`` that allows users to wrap any native torch module so that it is compatible with ``TensorMap`` objects. This is useful for building arbitrary architectures containing layers more complex than the standard ones available, namely ``Linear``, ``Tanh``, ``ReLU``, ``SiLU``, and ``LayerNorm``, and their equivariant counterparts. First we need to create some dummy data in :py:class:`TensorMap` format, with multiple :py:class:`TensorBlock` objects. Here we will focus on unconstrained architectures, as opposed to equivariance-preserving ones. The principles in the latter case are similar, as long as care is taken to build architectures from equivariance-preserving transformations. Let's start by defining a random tensor that we will treat as some intermediate representation. We will build a multi-layer perceptron to transform this tensor into a prediction. Here we will define a 3-block tensor map, with variable in and out dimensions for each block. .. GENERATED FROM PYTHON SOURCE LINES 57-79 .. code-block:: Python n_samples = 100 in_features = [64, 128, 256] out_features = [1, 2, 3] feature_tensormap = TensorMap( keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)), blocks=[ mts.block_from_array(torch.randn(n_samples, in_feats)) for in_feats in in_features ], ) target_tensormap = TensorMap( keys=Labels(["key"], torch.arange(len(out_features)).reshape(-1, 1)), blocks=[ mts.block_from_array(torch.randn(n_samples, out_feats)) for out_feats in out_features ], ) print("features:", feature_tensormap) print("target:", target_tensormap) .. rst-class:: sphx-glr-script-out .. code-block:: none features: TensorMap with 3 blocks keys: key 0 1 2 target: TensorMap with 3 blocks keys: key 0 1 2 .. GENERATED FROM PYTHON SOURCE LINES 80-87 Starting simple --------------- Let's start with a simple linear layer, but this time constructed manually using ``ModuleMap``. Here we want a linear layer for each block, with the correct in and out feature sizes. The result will be a module that is equivalent to the ``metatensor.torch.learn.nn.Linear`` module.
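For reference, such a module could also be built in a single call with the ``Linear`` convenience layer. The snippet below is a small sketch for comparison only; it uses the same constructor arguments that appear later in this tutorial's residual network example, and it is not needed for the rest of the tutorial.

.. code-block:: Python

    # equivalent convenience layer: one torch.nn.Linear per block, with the
    # output properties metadata taken from the targets
    equivalent_linear = Linear(
        in_keys=feature_tensormap.keys,
        in_features=in_features,
        out_properties=[
            target_tensormap[key].properties for key in feature_tensormap.keys
        ],
        bias=True,
    )

Below, we build the same module explicitly with ``ModuleMap``.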
.. GENERATED FROM PYTHON SOURCE LINES 88-109 .. code-block:: Python in_keys = feature_tensormap.keys modules = [] for key in in_keys: module = torch.nn.Linear( in_features=len(feature_tensormap[key].properties), out_features=len(target_tensormap[key].properties), bias=True, ) modules.append(module) # initialize the ModuleMap with the input keys, list of modules, and the output # properties labels metadata. linear_mmap = ModuleMap( in_keys, modules, out_properties=[target_tensormap[key].properties for key in in_keys], ) print(linear_mmap) .. rst-class:: sphx-glr-script-out .. code-block:: none ModuleMap( (0): Linear(in_features=64, out_features=1, bias=True) (1): Linear(in_features=128, out_features=2, bias=True) (2): Linear(in_features=256, out_features=3, bias=True) ) .. GENERATED FROM PYTHON SOURCE LINES 110-115 ``ModuleMap`` automatically handles the forward pass for each block indexed by the ``in_keys`` used to initialize it. The input does not have to contain all of these keys/blocks: the forward pass is only applied to the blocks that are present in the input. The output is a new ``TensorMap`` with the same keys as the input, with the correct output metadata. .. GENERATED FROM PYTHON SOURCE LINES 116-131 .. code-block:: Python # apply the ModuleMap to the whole feature tensor map prediction_full = linear_mmap(feature_tensormap) # filter the features to only contain one of the blocks, and pass through the ModuleMap prediction_subset = linear_mmap( mts.filter_blocks( feature_tensormap, Labels(["key"], torch.tensor([1]).reshape(-1, 1)) ) ) print(prediction_full.keys, prediction_full.blocks()) print(prediction_subset.keys, prediction_subset.blocks()) .. rst-class:: sphx-glr-script-out .. code-block:: none Labels( key 0 1 2 ) [TensorBlock samples (100): ['sample'] components (): [] properties (1): ['property'] gradients: None , TensorBlock samples (100): ['sample'] components (): [] properties (2): ['property'] gradients: None , TensorBlock samples (100): ['sample'] components (): [] properties (3): ['property'] gradients: None ] Labels( key 1 ) [TensorBlock samples (100): ['sample'] components (): [] properties (2): ['property'] gradients: None ] .. GENERATED FROM PYTHON SOURCE LINES 132-134 Now define a loss function and run a training loop. This is the same as in the previous tutorials. .. GENERATED FROM PYTHON SOURCE LINES 135-195 .. code-block:: Python # define a custom loss function for TensorMaps that computes the squared error and # reduces by sum class TensorMapLoss(torch.nn.Module): """ A custom loss function for TensorMaps that computes the squared error and reduces by sum. """ def __init__(self) -> None: super().__init__() def forward(self, input: TensorMap, target: TensorMap) -> torch.Tensor: """ Computes the total squared error between the ``input`` and ``target`` TensorMaps. """ # input and target should have equal metadata over all axes assert mts.equal_metadata(input, target) squared_loss = 0 for key in input.keys: squared_loss += torch.sum((input[key].values - target[key].values) ** 2) return squared_loss
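# As a quick usage sketch (not part of the original example), we can sanity
# check the loss defined above: the loss of a TensorMap compared with itself
# should be exactly zero
sanity_check_loss = TensorMapLoss()
assert sanity_check_loss(target_tensormap, target_tensormap) == 0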
# construct a basic training loop. For brevity we will not use datasets or dataloaders. def training_loop( model: torch.nn.Module, loss_fn: torch.nn.Module, features: Union[torch.Tensor, TensorMap], targets: Union[torch.Tensor, TensorMap], ) -> None: """A basic training loop for a model and loss function.""" optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) for epoch in range(501): optimizer.zero_grad() predictions = model(features) if isinstance(predictions, torch.ScriptObject): # assume a TensorMap and check metadata is equivalent assert mts.equal_metadata(predictions, targets) loss = loss_fn(predictions, targets) loss.backward() optimizer.step() if epoch % 100 == 0: print(f"epoch: {epoch}, loss: {loss}") loss_fn_mts = TensorMapLoss() print("with NN = [Linear]") training_loop(linear_mmap, loss_fn_mts, feature_tensormap, target_tensormap) .. rst-class:: sphx-glr-script-out .. code-block:: none with NN = [Linear] epoch: 0, loss: 769.1011737074874 epoch: 100, loss: 152.43308577101345 epoch: 200, loss: 95.21448836273512 epoch: 300, loss: 74.85845598989255 epoch: 400, loss: 64.58524404959572 epoch: 500, loss: 58.463045951762794 .. GENERATED FROM PYTHON SOURCE LINES 196-198 More complex architectures -------------------------- .. GENERATED FROM PYTHON SOURCE LINES 199-238 .. code-block:: Python # Defining more complicated architectures is a matter of building # ``torch.nn.Sequential`` objects for each block, and wrapping them into a single # ModuleMap. hidden_layer_width = 32 modules = [] for key in in_keys: module = torch.nn.Sequential( torch.nn.LayerNorm(len(feature_tensormap[key].properties)), torch.nn.Linear( in_features=len(feature_tensormap[key].properties), out_features=hidden_layer_width, bias=True, ), torch.nn.ReLU(), torch.nn.Linear( in_features=hidden_layer_width, out_features=len(target_tensormap[key].properties), bias=True, ), torch.nn.Tanh(), ) modules.append(module) # initialize the ModuleMap with the input keys, list of modules, and the output # properties labels metadata. custom_mmap = ModuleMap( in_keys, modules, out_properties=[target_tensormap[key].properties for key in in_keys], ) print(custom_mmap) print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh]") training_loop(custom_mmap, loss_fn_mts, feature_tensormap, target_tensormap) .. rst-class:: sphx-glr-script-out .. code-block:: none ModuleMap( (0): Sequential( (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (1): Linear(in_features=64, out_features=32, bias=True) (2): ReLU() (3): Linear(in_features=32, out_features=1, bias=True) (4): Tanh() ) (1): Sequential( (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (1): Linear(in_features=128, out_features=32, bias=True) (2): ReLU() (3): Linear(in_features=32, out_features=2, bias=True) (4): Tanh() ) (2): Sequential( (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True) (1): Linear(in_features=256, out_features=32, bias=True) (2): ReLU() (3): Linear(in_features=32, out_features=3, bias=True) (4): Tanh() ) ) with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] epoch: 0, loss: 628.4115434117372 epoch: 100, loss: 102.15331090664478 epoch: 200, loss: 93.55173279931788 epoch: 300, loss: 92.10713211723296 epoch: 400, loss: 91.57975604835907 epoch: 500, loss: 91.33320804749201
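The per-block modules passed to ``ModuleMap`` do not have to be stacks of the standard layers used above: any ``torch.nn.Module`` that maps a block's 2D values tensor to the requested number of output properties can be wrapped. The sketch below illustrates this with a small custom module; the ``BlockMLP`` class is a hypothetical example written for this illustration, not part of metatensor.

.. code-block:: Python

    class BlockMLP(torch.nn.Module):
        """A small custom per-block module: two linear layers with a ReLU."""

        def __init__(self, n_in: int, n_out: int) -> None:
            super().__init__()
            self.linear_1 = torch.nn.Linear(n_in, 16)
            self.linear_2 = torch.nn.Linear(16, n_out)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # each module only ever sees the plain values tensor of its block
            return self.linear_2(torch.relu(self.linear_1(x)))


    block_mlp_mmap = ModuleMap(
        in_keys,
        [BlockMLP(n_in, n_out) for n_in, n_out in zip(in_features, out_features)],
        out_properties=[target_tensormap[key].properties for key in in_keys],
    )

Such a module behaves exactly like the other ``ModuleMap`` instances and could be trained with the same ``training_loop`` as above.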
.. GENERATED FROM PYTHON SOURCE LINES 239-248 ``ModuleMap`` can also be wrapped in a ``torch.nn.Module`` to allow construction of complex architectures. For instance, we can build a "ResNet"-style neural network module that applies a ``ModuleMap`` and then sums the result with a residual connection. Wikipedia has a good summary and diagram of this architectural motif; see https://en.wikipedia.org/wiki/Residual_neural_network. To implement the residual connection, we can combine the ``ModuleMap`` with the ``Linear`` convenience layer from metatensor-learn and the sparse addition operation from ``metatensor-operations``. .. GENERATED FROM PYTHON SOURCE LINES 249-315 .. code-block:: Python class ResidualNetwork(torch.nn.Module): def __init__( self, in_keys: Labels, in_features: List[int], out_properties: List[Labels], ) -> None: super().__init__() # Build the module map as before hidden_layer_width = 32 modules = [] for in_feats, out_props in zip(in_features, out_properties): module = torch.nn.Sequential( torch.nn.LayerNorm(in_feats), torch.nn.Linear( in_features=in_feats, out_features=hidden_layer_width, bias=True, ), torch.nn.ReLU(), torch.nn.Linear( in_features=hidden_layer_width, out_features=len(out_props), bias=True, ), torch.nn.Tanh(), ) modules.append(module) self.module_map = ModuleMap( in_keys, modules, out_properties=out_properties, ) # build the input projection layer self.projection = Linear( in_keys=in_keys, in_features=in_features, out_properties=out_properties, bias=True, ) def forward(self, features: TensorMap) -> TensorMap: # apply the module map to the features prediction = self.module_map(features) # apply the projection layer to the features to form the residual residual = self.projection(features) # add the prediction and residual together using the sparse addition from # metatensor-operations return mts.add(prediction, residual) model = ResidualNetwork( in_keys=in_keys, in_features=in_features, out_properties=[block.properties for block in target_tensormap], ) print("with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections") training_loop(model, loss_fn_mts, feature_tensormap, target_tensormap) .. rst-class:: sphx-glr-script-out .. code-block:: none with NN = [LayerNorm, Linear, ReLU, Linear, Tanh] plus residual connections epoch: 0, loss: 877.4945870891794 epoch: 100, loss: 18.696672961002434 epoch: 200, loss: 3.845163971248074 epoch: 300, loss: 1.5772835196534585 epoch: 400, loss: 0.8379008238832794 epoch: 500, loss: 0.46250230046218666
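Before concluding, we can check the trained model directly. This is a small usage sketch, not part of the original example; it simply re-evaluates the loss on the training data using the objects defined above.

.. code-block:: Python

    # evaluate the trained residual model without tracking gradients
    with torch.no_grad():
        final_predictions = model(feature_tensormap)

    # the metadata of the predictions matches the targets, so the loss can be
    # computed directly; it should be close to the loss at the final epoch
    print("final loss:", loss_fn_mts(final_predictions, target_tensormap))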
.. GENERATED FROM PYTHON SOURCE LINES 316-329 Conclusion ---------- In this tutorial we have seen how to build custom architectures using ``ModuleMap``. This allows arbitrary architectures to be built, as long as the metadata is preserved. We have also seen how to build a custom module that wraps a ``ModuleMap`` and adds residual connections. The key takeaway is that ``ModuleMap`` can be used to wrap any combination of native ``torch.nn`` modules to make them compatible with ``TensorMap``. In combination with the convenience layers seen in the tutorial :ref:`nn modules basic ` and the sparse-data operations from ``metatensor-operations``, complex architectures can be built with ease. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 6.497 seconds) .. _sphx_glr_download_examples_learn_5-nn-using-modulemap.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: 5-nn-using-modulemap.ipynb <5-nn-using-modulemap.ipynb>` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: 5-nn-using-modulemap.py <5-nn-using-modulemap.py>` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: 5-nn-using-modulemap.zip <5-nn-using-modulemap.zip>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_