Source code for metatensor.learn.nn.equivariant_transformation
from typing import List, Optional, Union
import torch
from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap
[docs]
class EquivariantTransformation(torch.nn.Module):
    """
    A custom :py:class:`torch.nn.Module` that applies an arbitrary shape- and
    equivariance-preserving transformation to an input :py:class:`TensorMap`.

    For invariant blocks (specified with ``invariant_keys``), the respective
    transformation contained in ``modules`` is applied as is. For covariant blocks,
    an invariant multiplier is created by applying the transformation to the norm of
    the block over the component dimension; the block is then multiplied elementwise
    by this multiplier, which preserves its covariance.

    :param modules: a :py:class:`list` of :py:class:`torch.nn.Module` containing the
        transformations to be applied to each block indexed by ``in_keys``. It must
        contain exactly one module per key. Transformations for invariant and
        covariant blocks differ, see above.
    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        :py:class:`TensorMap` in the :py:meth:`forward` method.
    :param in_features: :py:class:`list` of :py:class:`int`, the number of features in
        the input tensor for each block indexed by the keys in ``in_keys``. If
        passed as a single value, the same number of features is assumed for all blocks.
    :param out_features: :py:class:`int` or :py:class:`list` of :py:class:`int`
        (optional). Accepted for backward compatibility only; it is not used to build
        the transformation.
    :param out_properties: :py:class:`list` of :py:class:`Labels` (optional), the
        properties labels of the output. By default the output properties are
        relabeled using ``Labels.range``.
    :param invariant_keys: a :py:class:`Labels` object that is used to select the
        invariant keys from ``in_keys``. If not provided, the invariant keys are assumed
        to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are equal to
        ``[0, 1]``.
    :raises ValueError: if ``modules`` does not contain exactly one module per key in
        ``in_keys``.

    >>> import torch
    >>> import numpy as np
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> from metatensor.learn.nn import EquivariantTransformation

    Define a dummy invariant TensorBlock

    >>> block_1 = TensorBlock(
    ...     values=torch.randn(2, 1, 3),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[Labels(["o3_mu"], np.array([[0]]))],
    ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
    ... )

    Define a dummy covariant TensorBlock

    >>> block_2 = TensorBlock(
    ...     values=torch.randn(2, 3, 3),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[Labels(["o3_mu"], np.array([[-1], [0], [1]]))],
    ...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
    ... )

    Create a TensorMap containing the dummy TensorBlocks

    >>> keys = Labels(names=["o3_lambda"], values=np.array([[0], [1]]))
    >>> tensor = TensorMap(keys, [block_1, block_2])

    Define the transformation to apply to the TensorMap

    >>> modules = [torch.nn.Tanh(), torch.nn.Tanh()]
    >>> in_features = [len(tensor.block(key).properties) for key in tensor.keys]

    Define the EquivariantTransformation module

    >>> transformation = EquivariantTransformation(
    ...     modules,
    ...     tensor.keys,
    ...     in_features,
    ...     out_properties=[tensor.block(key).properties for key in tensor.keys],
    ...     invariant_keys=Labels(
    ...         ["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1)
    ...     ),
    ... )

    The output metadata are the same as the input

    >>> transformation(tensor)
    TensorMap with 2 blocks
    keys: o3_lambda
              0
              1
    >>> transformation(tensor)[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (1): ['o3_mu']
        properties (3): ['properties']
        gradients: None
    """

    def __init__(
        self,
        modules: List[torch.nn.Module],
        in_keys: Labels,
        in_features: Union[int, List[int]],
        out_features: Optional[Union[int, List[int]]] = None,
        out_properties: Optional[List[Labels]] = None,
        invariant_keys: Optional[Labels] = None,
    ) -> None:
        super().__init__()

        # One module per key: without this check, surplus modules would be silently
        # ignored and missing ones would surface as an opaque IndexError.
        if len(modules) != len(in_keys):
            raise ValueError(
                "`modules` must contain exactly one module per key in `in_keys`: "
                f"got {len(modules)} modules for {len(in_keys)} keys"
            )

        # Default invariant keys: o3_lambda == 0 and o3_sigma == 1
        if invariant_keys is None:
            invariant_keys = Labels(
                names=["o3_lambda", "o3_sigma"],
                values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
            )
        # Positions of the invariant keys within `in_keys`, as plain ints so the
        # membership test below does not depend on the array backend.
        invariant_idxs = {int(idx) for idx in in_keys.select(invariant_keys)}

        # Validate `in_features` (single int or one entry per key).
        _check_module_map_parameter(
            in_features, "in_features", int, len(in_keys), "in_keys"
        )

        # NOTE: `out_features` is deliberately not used. When `out_properties` is
        # None, ModuleMap applies its own default relabeling, so requiring
        # `out_properties` (as a previous version effectively did) contradicted the
        # documented default.

        # Invariant blocks use their module directly; covariant blocks are wrapped so
        # the module only acts on the (invariant) norm over the components.
        modules_for_map: List[torch.nn.Module] = [
            module if i in invariant_idxs else _CovariantTransform(module=module)
            for i, module in enumerate(modules)
        ]
        self.module_map = ModuleMap(in_keys, modules_for_map, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap` corresponding to the transformed input
            ``tensor``.
        """
        return self.module_map(tensor)
class _CovariantTransform(torch.nn.Module):
    """
    Applies an arbitrary shape-preserving transformation defined in ``module`` to a
    3-dimensional tensor in a way that preserves equivariance.

    The transformation is applied to the norm of the :py:class:`torch.Tensor` over the
    component dimension (axis 1). The resulting :py:class:`torch.Tensor` is then
    multiplied elementwise with the original one, thus preserving covariance.

    :param module: :py:class:`torch.nn.Module` containing the transformation to be
        applied to the invariants constructed from the norms over the component
        dimension of the input :py:class:`torch.Tensor` passed to the
        :py:meth:`forward` method. Its output must broadcast against that input, for
        example by preserving the shape of its own input.
    """

    def __init__(
        self,
        module: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.module = module

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Creates an invariant block from the ``input`` covariant, and transforms it by
        applying the torch ``module`` passed to the class constructor. Then uses the
        transformed invariant as an elementwise multiplier for the ``input`` block.
        Transformations are applied consistently to components (axis 1) to preserve
        equivariance.

        :param input: three-dimensional :py:class:`torch.Tensor`; axis 1 is treated as
            the component axis.
        :return: ``input`` multiplied elementwise by the transformed norm.
        :raises ValueError: if ``input`` is not three-dimensional.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if input.dim() != 3:
            raise ValueError("``input`` must be a three-dimensional tensor")

        # Norm over the components, keeping axis 1 so the result has shape
        # (input.shape[0], 1, input.shape[2]) and broadcasts back onto ``input``.
        invariant = input.norm(dim=1, keepdim=True)
        multiplier = self.module(invariant)

        # Each component of a given (axis 0, axis 2) entry is scaled by the same
        # factor, so the block transforms exactly like the input did.
        return multiplier * input