# Source code for metatensor.learn.nn.equivariant_transformation

from typing import List, Optional, Union

import torch

from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap


class EquivariantTransformation(torch.nn.Module):
    """
    A custom :py:class:`torch.nn.Module` that applies an arbitrary shape- and
    equivariance-preserving transformation to an input :py:class:`TensorMap`.

    For invariant blocks (specified with ``invariant_keys``), the respective
    transformation contained in ``modules`` is applied as is. For covariant blocks,
    an invariant multiplier is created, applying the transformation to the norm of
    the block over the component dimensions.

    :param modules: a :py:class:`list` of :py:class:`torch.nn.Module` containing the
        transformations to be applied to each block indexed by ``in_keys``. Exactly
        one module must be given per key. Transformations for invariant and covariant
        blocks differ. See above.
    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        :py:class:`TensorMap` in the :py:meth:`forward` method.
    :param in_features: :py:class:`list` of :py:class:`int`, the number of features
        in the input tensor for each block indexed by the keys in ``in_keys``. If
        passed as a single value, the same number of features is assumed for all
        blocks.
    :param out_features: :py:class:`list` of :py:class:`int` (optional), the number
        of features in the output tensor for each block indexed by the keys in
        ``in_keys``. If passed as a single value, the same number of features is
        assumed for all blocks. If not provided, it is inferred from the lengths of
        ``out_properties``, which must then be given.
    :param out_properties: :py:class:`list` of :py:class:`Labels` (optional), the
        properties labels of the output. Required when ``out_features`` is not
        provided. By default the output properties are relabeled using Labels.range.
    :param invariant_keys: a :py:class:`Labels` object that is used to select the
        invariant keys from ``in_keys``. If not provided, the invariant keys are
        assumed to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are
        equal to ``[0, 1]``.

    :raises ValueError: if ``modules`` does not contain one module per key in
        ``in_keys``, or if neither ``out_features`` nor ``out_properties`` is given.

    >>> import torch
    >>> import numpy as np
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> from metatensor.learn.nn import EquivariantTransformation

    Define a dummy invariant TensorBlock

    >>> block_1 = TensorBlock(
    ...     values=torch.randn(2, 1, 3),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[Labels(["o3_mu"], np.array([[0]]))],
    ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
    ... )

    Define a dummy covariant TensorBlock

    >>> block_2 = TensorBlock(
    ...     values=torch.randn(2, 3, 3),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[Labels(["o3_mu"], np.array([[-1], [0], [1]]))],
    ...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
    ... )

    Create a TensorMap containing the dummy TensorBlocks

    >>> keys = Labels(names=["o3_lambda"], values=np.array([[0], [1]]))
    >>> tensor = TensorMap(keys, [block_1, block_2])

    Define the transformation to apply to the TensorMap

    >>> modules = [torch.nn.Tanh(), torch.nn.Tanh()]
    >>> in_features = [len(tensor.block(key).properties) for key in tensor.keys]

    Define the EquivariantTransformation module

    >>> transformation = EquivariantTransformation(
    ...     modules,
    ...     tensor.keys,
    ...     in_features,
    ...     out_properties=[tensor.block(key).properties for key in tensor.keys],
    ...     invariant_keys=Labels(
    ...         ["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1)
    ...     ),
    ... )

    The output metadata are the same as the input

    >>> transformation(tensor)
    TensorMap with 2 blocks
    keys: o3_lambda
              0
              1
    >>> transformation(tensor)[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (1): ['o3_mu']
        properties (3): ['properties']
        gradients: None
    """

    def __init__(
        self,
        modules: List[torch.nn.Module],
        in_keys: Labels,
        in_features: Union[int, List[int]],
        out_features: Optional[Union[int, List[int]]] = None,
        out_properties: Optional[List[Labels]] = None,
        invariant_keys: Optional[Labels] = None,
    ) -> None:
        super().__init__()

        # One module per key is required; without this check a short list fails
        # with a bare IndexError and a long list is silently truncated.
        if len(modules) != len(in_keys):
            raise ValueError(
                f"`modules` must contain one module per key in `in_keys`, got "
                f"{len(modules)} modules for {len(in_keys)} keys"
            )

        # Set a default for invariant keys
        if invariant_keys is None:
            invariant_keys = Labels(
                names=["o3_lambda", "o3_sigma"],
                values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
            )
        # positions (in `in_keys`) of the keys matching `invariant_keys`
        invariant_key_idxs = in_keys.select(invariant_keys)

        # Infer `out_features` if not provided
        if out_features is None:
            if out_properties is None:
                raise ValueError(
                    "If `out_features` is not provided,"
                    " `out_properties` must be provided."
                )
            out_features = [len(p) for p in out_properties]

        # Check input parameters (one entry per key). The converted lists are not
        # needed afterwards; these calls exist to validate user input.
        _check_module_map_parameter(
            in_features, "in_features", int, len(in_keys), "in_keys"
        )
        _check_module_map_parameter(
            out_features, "out_features", int, len(in_keys), "in_keys"
        )

        modules_for_map: List[torch.nn.Module] = []
        for i, module in enumerate(modules):
            if i in invariant_key_idxs:
                # invariant block: apply the user transformation directly
                modules_for_map.append(module)
            else:
                # covariant block: transform the norm over the components and use
                # it as a multiplier, which keeps the block covariant
                modules_for_map.append(_CovariantTransform(module=module))

        self.module_map = ModuleMap(in_keys, modules_for_map, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap` corresponding to the transformed input
            ``tensor``.
        """
        return self.module_map(tensor)
class _CovariantTransform(torch.nn.Module):
    """
    Applies an arbitrary shape-preserving transformation defined in ``module`` to a
    3-dimensional tensor in a way that preserves equivariance.

    The transformation is applied to the norm of the :py:class:`torch.Tensor` over
    the component dimension (axis 1). The resulting :py:class:`torch.Tensor` is
    elementwise multiplied back to the original one, thus preserving covariance.

    :param module: :py:class:`torch.nn.Module` containing the transformation to be
        applied to the invariants constructed from the norms over the component
        dimension of the input :py:class:`torch.Tensor` passed to the
        :py:meth:`forward` method. Its output is multiplied elementwise with the
        input, so it must broadcast against it. A shape-preserving ``module``
        therefore yields an output with the same shape as the input.
    """

    def __init__(
        self,
        module: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.module = module

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Creates an invariant block from the ``input`` covariant, and transforms it by
        applying the torch ``module`` passed to the class constructor. Then uses the
        transformed invariant as an elementwise multiplier for the ``input`` block.

        Transformations are applied consistently to components (axis 1) to preserve
        equivariance.

        :param input: three-dimensional tensor whose axis 1 holds the components.
        :return: ``module(norm) * input``, where ``norm`` is the norm of ``input``
            over axis 1.
        :raises ValueError: if ``input`` is not three-dimensional.
        """
        # An explicit raise (rather than `assert`) keeps the check active under
        # `python -O`.
        if input.dim() != 3:
            raise ValueError("``input`` must be a three-dimensional tensor")

        # keepdim=True turns an (n, c, m) input into an (n, 1, m) invariant, so the
        # transformed multiplier broadcasts over all c components of `input`.
        invariant = input.norm(dim=1, keepdim=True)
        multiplier = self.module(invariant)
        return multiplier * input