Source code for metatensor.learn.nn.layer_norm

"""
Module containing classes :class:`LayerNorm` and :class:`InvariantLayerNorm`, i.e.
module maps that apply layer norms in a generic and equivariant way,
respectively.
"""

from typing import List, Optional

import torch
from torch.nn import Module, init
from torch.nn.parameter import Parameter

from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap


class LayerNorm(Module):
    """
    Module similar to :py:class:`torch.nn.LayerNorm` that works with
    :py:class:`metatensor.torch.TensorMap` objects.

    Applies a layer normalization to each block of a :py:class:`TensorMap` passed to
    its :py:meth:`forward` method, indexed by :param in_keys:.

    The main difference from :py:class:`torch.nn.LayerNorm` is that there is no
    `normalized_shape` parameter. Instead, the standard deviation and mean (if
    applicable) are calculated over all dimensions except the samples (first)
    dimension of each :py:class:`TensorBlock`.

    The extra parameter :param mean: controls whether or not the mean over these
    dimensions is subtracted from the input tensor in the transformation.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for a more detailed
    description of the other parameters.

    Each parameter is passed as a single value of its expected type, which is used as
    the parameter for all blocks.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param in_features: list of int, the number of features in the input tensor for
        each block indexed by the keys in :param in_keys:. If passed as a single
        value, the same number of features is assumed for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default (if None) the output properties are relabeled
        using ``Labels.range``.
    :param mean: bool, whether or not to subtract the mean over all dimensions except
        the samples (first) dimension of each block of the input passed to
        :py:meth:`forward`.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: List[int],
        out_properties: Optional[List[Labels]] = None,
        *,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        mean: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Check input parameters, convert to lists (for each key) if necessary
        in_features = _check_module_map_parameter(
            in_features, "in_features", int, len(in_keys), "in_keys"
        )
        eps = _check_module_map_parameter(eps, "eps", float, len(in_keys), "in_keys")
        elementwise_affine = _check_module_map_parameter(
            elementwise_affine, "elementwise_affine", bool, len(in_keys), "in_keys"
        )
        bias = _check_module_map_parameter(bias, "bias", bool, len(in_keys), "in_keys")
        mean = _check_module_map_parameter(mean, "mean", bool, len(in_keys), "in_keys")

        # Build module list
        modules: List[Module] = []
        for i in range(len(in_keys)):
            module = _LayerNorm(
                in_features=in_features[i],
                eps=eps[i],
                elementwise_affine=elementwise_affine[i],
                bias=bias[i],
                mean=mean[i],
                device=device,
                dtype=dtype,
            )
            modules.append(module)

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
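
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of applying
# ``LayerNorm`` to a single-block TensorMap. It assumes the TorchScript backend,
# i.e. that ``Labels``, ``TensorBlock`` and ``TensorMap`` come from
# ``metatensor.torch`` and that this class is importable as
# ``metatensor.torch.learn.nn.LayerNorm``; adapt the imports to your backend.
#
#     import torch
#     from metatensor.torch import Labels, TensorBlock, TensorMap
#     from metatensor.torch.learn.nn import LayerNorm
#
#     in_keys = Labels(["key"], torch.tensor([[0]]))
#     block = TensorBlock(
#         values=torch.randn(10, 3),
#         samples=Labels.range("sample", 10),
#         components=[],
#         properties=Labels.range("property", 3),
#     )
#     tensor = TensorMap(in_keys, [block])
#
#     # a single ``in_features`` value is broadcast to every block in ``in_keys``
#     layer_norm = LayerNorm(in_keys=in_keys, in_features=3)
#     normalized = layer_norm(tensor)  # TensorMap with the same block structure
# ---------------------------------------------------------------------------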

class InvariantLayerNorm(Module):
    """
    Module similar to :py:class:`torch.nn.LayerNorm` that works with
    :py:class:`metatensor.torch.TensorMap` objects, applying the transformation only
    to the invariant blocks.

    Applies a layer normalization to each invariant block of a :py:class:`TensorMap`
    passed to its :py:meth:`forward` method. These are indexed by the keys in
    :param in_keys: that correspond to the selection passed in :param invariant_keys:.

    The main difference from :py:class:`torch.nn.LayerNorm` is that there is no
    `normalized_shape` parameter. Instead, the standard deviation and mean (if
    applicable) are calculated over all dimensions except the samples (first)
    dimension of each :py:class:`TensorBlock`.

    The extra parameter :param mean: controls whether or not the mean over these
    dimensions is subtracted from the input tensor in the transformation.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for a more detailed
    description of the other parameters.

    Each parameter is passed as a single value of its expected type, which is used as
    the parameter for all blocks.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param in_features: list of int, the number of features in the input tensor for
        each block indexed by the keys in :param in_keys:. If passed as a single
        value, the same number of features is assumed for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default (if None) the output properties are relabeled
        using ``Labels.range``.
    :param invariant_keys: a :py:class:`Labels` object that is used to select the
        invariant keys from ``in_keys``. If not provided, the invariant keys are
        assumed to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are
        equal to ``[0, 1]``.
    :param mean: bool, whether or not to subtract the mean over all dimensions except
        the samples (first) dimension of each block of the input passed to
        :py:meth:`forward`.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: List[int],
        out_properties: Optional[List[Labels]] = None,
        invariant_keys: Optional[Labels] = None,
        *,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        mean: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Set a default for invariant keys: a single entry with
        # ``o3_lambda = 0`` and ``o3_sigma = 1``
        if invariant_keys is None:
            invariant_keys = Labels(
                names=["o3_lambda", "o3_sigma"],
                values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
            )
        invariant_key_idxs = in_keys.select(invariant_keys)

        # Check input parameters, convert to lists (for each *invariant* key) if
        # necessary.
        in_features = _check_module_map_parameter(
            in_features,
            "in_features",
            int,
            len(invariant_key_idxs),
            "invariant_key_idxs",
        )
        eps = _check_module_map_parameter(
            eps, "eps", float, len(invariant_key_idxs), "invariant_key_idxs"
        )
        elementwise_affine = _check_module_map_parameter(
            elementwise_affine,
            "elementwise_affine",
            bool,
            len(invariant_key_idxs),
            "invariant_key_idxs",
        )
        bias = _check_module_map_parameter(
            bias, "bias", bool, len(invariant_key_idxs), "invariant_key_idxs"
        )
        mean = _check_module_map_parameter(
            mean, "mean", bool, len(invariant_key_idxs), "invariant_key_idxs"
        )

        # Build module list
        modules: List[Module] = []
        invariant_idx: int = 0
        for i in range(len(in_keys)):
            if i in invariant_key_idxs:
                # Invariant block: apply LayerNorm
                for j in range(len(invariant_key_idxs)):
                    if invariant_key_idxs[j] == i:
                        invariant_idx = j
                module = _LayerNorm(
                    in_features=in_features[invariant_idx],
                    eps=eps[invariant_idx],
                    elementwise_affine=elementwise_affine[invariant_idx],
                    bias=bias[invariant_idx],
                    mean=mean[invariant_idx],
                    device=device,
                    dtype=dtype,
                )
            else:
                # Covariant block: apply the identity operator
                module = torch.nn.Identity()
            modules.append(module)

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the layer norm to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
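
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module), assuming ``metatensor.torch``
# constructors and the import path ``metatensor.torch.learn.nn``. With the
# default ``invariant_keys``, only blocks whose key has ``o3_lambda == 0`` and
# ``o3_sigma == 1`` are normalized; all other blocks pass through unchanged.
#
#     import torch
#     from metatensor.torch import Labels
#     from metatensor.torch.learn.nn import InvariantLayerNorm
#
#     in_keys = Labels(
#         ["o3_lambda", "o3_sigma"],
#         torch.tensor([[0, 1], [1, 1], [2, 1]]),
#     )
#     # a single ``in_features`` value is broadcast to every invariant block
#     layer_norm = InvariantLayerNorm(in_keys=in_keys, in_features=7)
#
#     # or select the invariant blocks explicitly via a subset of key dimensions
#     invariant_keys = Labels(["o3_lambda"], torch.tensor([[0]]))
#     layer_norm = InvariantLayerNorm(in_keys, 7, invariant_keys=invariant_keys)
# ---------------------------------------------------------------------------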

class _LayerNorm(Module):
    """
    Custom :py:class:`Module` re-implementing :py:class:`torch.nn.LayerNorm`.

    In this case, `normalized_shape` is not provided as a parameter. Instead, the
    standard deviation and mean (if applicable) are calculated over all dimensions
    except the samples of the input tensor. Subtraction of this mean can be switched
    on or off using the extra :param mean: parameter.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for more information on
    the other parameters.

    :param mean: bool, whether or not to subtract the mean over all dimensions except
        the samples of the input tensor passed to :py:meth:`forward`.
    """

    __constants__ = ["in_features", "eps", "elementwise_affine"]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        in_features: int,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        mean: bool = True,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.mean = mean

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(in_features, device=device, dtype=dtype)
            )
            if bias:
                self.bias = Parameter(
                    torch.empty(in_features, device=device, dtype=dtype)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return _layer_norm(
            input, weight=self.weight, bias=self.bias, eps=self.eps, mean=self.mean
        )

    def extra_repr(self) -> str:
        return "eps={eps}, elementwise_affine={elementwise_affine}".format(
            **self.__dict__
        )


def _layer_norm(
    tensor: torch.Tensor,
    weight: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    eps: float,
    mean: bool,
) -> torch.Tensor:
    """
    Apply layer normalization to the input `tensor`.

    See :py:func:`torch.nn.functional.layer_norm` for more information on the other
    parameters.

    In addition to the base torch implementation, this function has the added control
    of whether or not the mean over all dimensions except the samples (first)
    dimension is subtracted from the input tensor.

    :param mean: whether or not to subtract from the input :param tensor: the mean
        over all dimensions except the samples (first) dimension.

    :return: :py:class:`torch.Tensor` with layer normalization applied.
    """
    # Contract over all dimensions except samples
    dim: List[int] = list(range(1, len(tensor.shape)))

    if mean:  # subtract mean over properties dimension
        tensor_out = tensor - torch.mean(tensor, dim=dim, keepdim=True)
    else:
        tensor_out = tensor

    # Divide by standard deviation over properties dimension. `correction=0` for a
    # biased estimator, in accordance with the torch implementation.
    tensor_out /= torch.sqrt(
        torch.var(tensor, dim=dim, correction=0, keepdim=True) + eps
    )

    if weight is not None:  # apply affine transformation
        tensor_out *= weight

    if bias is not None:  # apply bias
        tensor_out += bias

    return tensor_out
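
# ---------------------------------------------------------------------------
# Sanity-check sketch (not part of the original module): with ``mean=True`` and
# normalization taken over every dimension except the first, ``_layer_norm``
# should agree with ``torch.nn.functional.layer_norm`` applied over the trailing
# dimensions (both use the biased variance estimator and add ``eps`` inside the
# square root).
#
#     import torch
#
#     x = torch.randn(10, 3)
#     weight = torch.ones(3)
#     bias = torch.zeros(3)
#
#     ours = _layer_norm(x, weight=weight, bias=bias, eps=1e-5, mean=True)
#     ref = torch.nn.functional.layer_norm(
#         x, normalized_shape=(3,), weight=weight, bias=bias, eps=1e-5
#     )
#     assert torch.allclose(ours, ref, atol=1e-6)
# ---------------------------------------------------------------------------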