Source code for metatensor.learn.nn.layer_norm

"""
Module containing classes :class:`LayerNorm` and :class:`InvariantLayerNorm`, i.e.
module maps that apply layer norms in a generic and equivariant way,
respectively.
"""

from typing import List, Optional, Union

import torch
from torch.nn import Module, init
from torch.nn.parameter import Parameter

from .._backend import Labels, TensorMap
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap


class LayerNorm(Module):
    """
    Module similar to :py:class:`torch.nn.LayerNorm` that works with
    :py:class:`metatensor.torch.TensorMap` objects.

    Applies a layer normalization to each block of a :py:class:`TensorMap` passed to
    its :py:meth:`forward` method, indexed by :param in_keys:.

    The main difference from :py:class:`torch.nn.LayerNorm` is that there is no
    `normalized_shape` parameter. Instead, the standard deviation and mean (if
    applicable) are calculated over all dimensions except the samples (first)
    dimension of each :py:class:`TensorBlock`.

    The extra parameter :param mean: controls whether or not the mean over these
    dimensions is subtracted from the input tensor in the transformation.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for a more detailed
    description of the other parameters.

    Each parameter can be passed as a single value of its expected type, which is
    used as the parameter for all blocks. Alternatively, they can be passed as a
    list to control the parameters applied to each block indexed by the keys in
    :param in_keys:.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param in_features: list of int, the number of features in the input tensor for
        each block indexed by the keys in :param in_keys:. If passed as a single
        value, the same number of features is assumed for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default (if none) the output properties are
        relabeled using Labels.range.
    :param mean: bool, whether or not to subtract the mean over all dimensions
        except the samples (first) dimension of each block of the input passed to
        :py:meth:`forward`.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: List[int],
        out_properties: Optional[List[Labels]] = None,
        *,
        eps: Union[float, List[float]] = 1e-5,
        elementwise_affine: Union[bool, List[bool]] = True,
        bias: Union[bool, List[bool]] = True,
        mean: Union[bool, List[bool]] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Check input parameters, convert to lists (for each key) if necessary
        in_features = _check_module_map_parameter(
            in_features, "in_features", int, len(in_keys), "in_keys"
        )
        eps = _check_module_map_parameter(eps, "eps", float, len(in_keys), "in_keys")
        elementwise_affine = _check_module_map_parameter(
            elementwise_affine, "elementwise_affine", bool, len(in_keys), "in_keys"
        )
        bias = _check_module_map_parameter(
            bias, "bias", bool, len(in_keys), "in_keys"
        )
        mean = _check_module_map_parameter(
            mean, "mean", bool, len(in_keys), "in_keys"
        )

        # Build module list
        modules: List[Module] = []
        for i in range(len(in_keys)):
            module = _LayerNorm(
                in_features=in_features[i],
                eps=eps[i],
                elementwise_affine=elementwise_affine[i],
                bias=bias[i],
                mean=mean[i],
                device=device,
                dtype=dtype,
            )
            modules.append(module)

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
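

# A hedged usage sketch (illustration only, not part of the upstream module): it
# assumes the TorchScript backend, i.e. that ``Labels``, ``TensorBlock`` and
# ``TensorMap`` come from ``metatensor.torch``, and builds a toy two-block tensor
# map. With the pure Python backend the same classes come from ``metatensor``.
# The function name and data shapes below are invented for the example.
def _example_layer_norm_usage():
    from metatensor.torch import Labels, TensorBlock, TensorMap

    in_keys = Labels(["key"], torch.tensor([[0], [1]], dtype=torch.int32))
    blocks = [
        TensorBlock(
            values=torch.randn(4, 3),
            samples=Labels.range("sample", 4),
            components=[],
            properties=Labels.range("property", 3),
        )
        for _ in range(2)
    ]
    tensor = TensorMap(in_keys, blocks)

    # one normalization module per key; each block is normalized over all
    # dimensions except the samples (first) dimension
    norm = LayerNorm(in_keys=in_keys, in_features=3)
    return norm(tensor)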


class InvariantLayerNorm(Module):
    """
    Module similar to :py:class:`torch.nn.LayerNorm` that works with
    :py:class:`metatensor.torch.TensorMap` objects, applying the transformation only
    to the invariant blocks.

    Applies a layer normalization to each invariant block of a :py:class:`TensorMap`
    passed to its :py:meth:`forward` method. These correspond to the blocks indexed
    by :param in_keys: at the numeric positions passed in :param invariant_key_idxs:.

    The main difference from :py:class:`torch.nn.LayerNorm` is that there is no
    `normalized_shape` parameter. Instead, the standard deviation and mean (if
    applicable) are calculated over all dimensions except the samples (first)
    dimension of each :py:class:`TensorBlock`.

    The extra parameter :param mean: controls whether or not the mean over these
    dimensions is subtracted from the input tensor in the transformation.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for a more detailed
    description of the other parameters.

    Each parameter can be passed as a single value of its expected type, which is
    used as the parameter for all blocks. Alternatively, they can be passed as a
    list to control the parameters applied to each block indexed by the invariant
    keys in :param in_keys: at the numeric positions :param invariant_key_idxs:.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param invariant_key_idxs: list of int, the indices of the invariant keys present
        in `in_keys` in the input :py:class:`TensorMap`. Only blocks for these keys
        will have layer norm applied. The other blocks will have the identity
        operator applied.
    :param in_features: list of int, the number of features in the input tensor for
        each block indexed by the keys in :param in_keys:. If passed as a single
        value, the same number of features is assumed for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default (if none) the output properties are
        relabeled using Labels.range.
    :param mean: bool, whether or not to subtract the mean over all dimensions
        except the samples (first) dimension of each block of the input passed to
        :py:meth:`forward`. If passed as a list, must have the same length as
        :param invariant_key_idxs:.
    """

    def __init__(
        self,
        in_keys: Labels,
        invariant_key_idxs: List[int],
        in_features: List[int],
        out_properties: Optional[List[Labels]] = None,
        *,
        eps: Union[float, List[float]] = 1e-5,
        elementwise_affine: Union[bool, List[bool]] = True,
        bias: Union[bool, List[bool]] = True,
        mean: Union[bool, List[bool]] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Check input parameters, convert to lists (for each *invariant* key) if
        # necessary.
        in_features = _check_module_map_parameter(
            in_features,
            "in_features",
            int,
            len(invariant_key_idxs),
            "invariant_key_idxs",
        )
        eps = _check_module_map_parameter(
            eps, "eps", float, len(invariant_key_idxs), "invariant_key_idxs"
        )
        elementwise_affine = _check_module_map_parameter(
            elementwise_affine,
            "elementwise_affine",
            bool,
            len(invariant_key_idxs),
            "invariant_key_idxs",
        )
        bias = _check_module_map_parameter(
            bias, "bias", bool, len(invariant_key_idxs), "invariant_key_idxs"
        )
        mean = _check_module_map_parameter(
            mean, "mean", bool, len(invariant_key_idxs), "invariant_key_idxs"
        )

        # Build module list
        modules: List[Module] = []
        invariant_idx: int = 0
        for i in range(len(in_keys)):
            if i in invariant_key_idxs:
                # Invariant block: apply LayerNorm. Find the position of this key
                # within `invariant_key_idxs` to pick the matching parameters.
                for j in range(len(invariant_key_idxs)):
                    if invariant_key_idxs[j] == i:
                        invariant_idx = j
                module = _LayerNorm(
                    in_features=in_features[invariant_idx],
                    eps=eps[invariant_idx],
                    elementwise_affine=elementwise_affine[invariant_idx],
                    bias=bias[invariant_idx],
                    mean=mean[invariant_idx],
                    device=device,
                    dtype=dtype,
                )
            else:
                # Covariant block: apply the identity operator
                module = torch.nn.Identity()
            modules.append(module)

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the layer norm to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        # Currently not supporting gradients
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
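

# A hedged sketch (illustration only, not part of the upstream module) of the
# invariant selection: given three keys, only the block at key position 0 (e.g.
# the ``o3_lambda = 0`` invariant) is normalized, while the other blocks pass
# through ``torch.nn.Identity``. The backend (``metatensor.torch``), the key and
# component names, and the shapes below are assumptions made for the example.
def _example_invariant_layer_norm_usage():
    from metatensor.torch import Labels, TensorBlock, TensorMap

    in_keys = Labels(
        ["o3_lambda"], torch.tensor([[0], [1], [2]], dtype=torch.int32)
    )
    blocks = [
        TensorBlock(
            values=torch.randn(4, 2 * o3_lambda + 1, 3),
            samples=Labels.range("sample", 4),
            components=[Labels.range("o3_mu", 2 * o3_lambda + 1)],
            properties=Labels.range("property", 3),
        )
        for o3_lambda in range(3)
    ]
    tensor = TensorMap(in_keys, blocks)

    norm = InvariantLayerNorm(
        in_keys=in_keys,
        invariant_key_idxs=[0],  # only the first key is invariant
        in_features=3,
    )
    return norm(tensor)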


class _LayerNorm(Module):
    """
    Custom :py:class:`Module` re-implementing :py:class:`torch.nn.LayerNorm`.

    In this case, `normalized_shape` is not provided as a parameter. Instead, the
    standard deviation and mean (if applicable) are calculated over all dimensions
    except the samples of the input tensor. Subtraction of this mean can be switched
    on or off using the extra :param mean: parameter.

    Refer to the :py:class:`torch.nn.LayerNorm` documentation for more information
    on the other parameters.

    :param mean: bool, whether or not to subtract the mean over all dimensions
        except the samples of the input tensor passed to :py:meth:`forward`.
    """

    __constants__ = ["in_features", "eps", "elementwise_affine"]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        in_features: int,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        mean: bool = True,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.mean = mean

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(in_features, device=device, dtype=dtype)
            )
            if bias:
                self.bias = Parameter(
                    torch.empty(in_features, device=device, dtype=dtype)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return _layer_norm(
            input, weight=self.weight, bias=self.bias, eps=self.eps, mean=self.mean
        )

    def extra_repr(self) -> str:
        return "eps={eps}, elementwise_affine={elementwise_affine}".format(
            **self.__dict__
        )


def _layer_norm(
    tensor: torch.Tensor,
    weight: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    eps: float,
    mean: bool,
) -> torch.Tensor:
    """
    Apply layer normalization to the input `tensor`.

    See :py:func:`torch.nn.functional.layer_norm` for more information on the other
    parameters.

    In addition to the base torch implementation, this function has the added
    control of whether or not the mean over all dimensions except the samples
    (first) dimension is subtracted from the input tensor.

    :param mean: whether or not to subtract from the input `tensor` the mean over
        all dimensions except the samples (first) dimension.

    :return: :py:class:`torch.Tensor` with layer normalization applied.
    """
    # Contract over all dimensions except samples
    dim: List[int] = list(range(1, len(tensor.shape)))

    if mean:
        # subtract the mean over all dimensions except samples
        tensor_out = tensor - torch.mean(tensor, dim=dim, keepdim=True)
    else:
        tensor_out = tensor

    # Divide by the standard deviation over all dimensions except samples.
    # `correction=0` for the biased estimator, in accordance with the torch
    # implementation. Done out of place so the input `tensor` is never modified.
    tensor_out = tensor_out / torch.sqrt(
        torch.var(tensor, dim=dim, correction=0, keepdim=True) + eps
    )

    if weight is not None:
        # apply affine transformation
        tensor_out *= weight

    if bias is not None:
        # apply bias
        tensor_out += bias

    return tensor_out
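

# A minimal numerical check sketch (added for illustration, not part of the
# upstream module): with ``mean=True`` and no affine parameters, ``_layer_norm``
# should agree with ``torch.nn.functional.layer_norm`` normalizing over every
# dimension except the samples (first) dimension, since both use the biased
# variance estimator. The function name is invented for the example.
def _check_layer_norm_against_torch():
    x = torch.randn(4, 3, 5)
    eps = 1e-5

    ours = _layer_norm(x, weight=None, bias=None, eps=eps, mean=True)
    reference = torch.nn.functional.layer_norm(
        x, normalized_shape=list(x.shape[1:]), eps=eps
    )

    assert torch.allclose(ours, reference, atol=1e-6)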