Source code for metatensor.learn.nn.linear

from typing import List, Optional, Union

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap


class Linear(Module):
    """
    Module similar to :py:class:`torch.nn.Linear` that works with
    :py:class:`metatensor.torch.TensorMap`.

    Applies a linear transformation to each block of a :py:class:`TensorMap` passed to
    its forward method, indexed by ``in_keys``.

    Refer to the :py:class:`torch.nn.Linear` documentation for a more detailed
    description of the other parameters (``bias``, ``device``, ``dtype``). Each of
    these is passed as a single value of its expected type, which is used as the
    parameter for all blocks.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param in_features: :py:class:`int` or :py:class:`list` of :py:class:`int`, the
        number of input features for each block. If passed as a single value, the same
        feature size is taken for all blocks.
    :param out_features: :py:class:`int` or :py:class:`list` of :py:class:`int`, the
        number of output features for each block. If passed as a single value, the
        same feature size is taken for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default the output properties are relabeled using
        Labels.range. If provided, ``out_features`` can be inferred and need not be
        provided.
    :param bias: :py:class:`bool`, forwarded as the ``bias`` argument of the
        :py:class:`torch.nn.Linear` created for each block.
    :param device: :py:class:`torch.device` (optional), forwarded to every
        :py:class:`torch.nn.Linear` created by this module.
    :param dtype: :py:class:`torch.dtype` (optional), forwarded to every
        :py:class:`torch.nn.Linear` created by this module.

    :raises ValueError: if neither ``out_features`` nor ``out_properties`` is
        provided.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: Union[int, List[int]],
        out_features: Optional[Union[int, List[int]]] = None,
        out_properties: Optional[List[Labels]] = None,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Infer `out_features` if not provided
        if out_features is None:
            if out_properties is None:
                raise ValueError(
                    "If `out_features` is not provided,"
                    " `out_properties` must be provided."
                )
            out_features = [len(p) for p in out_properties]

        # Check input parameters, convert to lists (for each key) if necessary
        n_keys = len(in_keys)
        in_features = _check_module_map_parameter(
            in_features, "in_features", int, n_keys, "in_keys"
        )
        out_features = _check_module_map_parameter(
            out_features, "out_features", int, n_keys, "in_keys"
        )
        bias = _check_module_map_parameter(bias, "bias", bool, n_keys, "in_keys")

        # One independent linear layer per key. The three lists are expected to hold
        # one entry per key (``n_keys`` is passed to the helper as expected length).
        modules: List[Module] = [
            torch.nn.Linear(
                in_features=n_in,
                out_features=n_out,
                bias=use_bias,
                device=device,
                dtype=dtype,
            )
            for n_in, n_out, use_bias in zip(in_features, out_features, bias)
        ]

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        return self.module_map(tensor)
class EquivariantLinear(Module):
    """
    Module similar to :py:class:`torch.nn.Linear` that works with equivariant
    :py:class:`metatensor.torch.TensorMap` objects.

    Applies a linear transformation to each block of a :py:class:`TensorMap` passed to
    its forward method, indexed by ``in_keys``.

    Refer to the :py:class:`torch.nn.Linear` documentation for a more detailed
    description of the other parameters.

    For :py:class:`EquivariantLinear`, by contrast to :py:class:`Linear`, the
    parameter ``bias`` is only applied to modules corresponding to invariant blocks,
    i.e. keys in ``in_keys`` that correspond to the selection in ``invariant_keys``.
    All other blocks get a linear layer without bias.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param in_features: :py:class:`int` or :py:class:`list` of :py:class:`int`, the
        number of input features for each block. If passed as a single value, the same
        feature size is taken for all blocks.
    :param out_features: :py:class:`int` or :py:class:`list` of :py:class:`int`, the
        number of output features for each block. If passed as a single value, the
        same feature size is taken for all blocks.
    :param out_properties: list of :py:class:`Labels` (optional), the properties
        labels of the output. By default the output properties are relabeled using
        Labels.range. If provided, ``out_features`` can be inferred and need not be
        provided.
    :param invariant_keys: a :py:class:`Labels` object that is used to select the
        invariant keys from ``in_keys``. If not provided, the invariant keys are
        assumed to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are
        equal to ``[0, 1]``.
    :param bias: :py:class:`bool`, whether the linear layers of the invariant blocks
        have a bias. It is validated against the number of selected invariant keys.
    :param device: :py:class:`torch.device` (optional), forwarded to every
        :py:class:`torch.nn.Linear` created by this module.
    :param dtype: :py:class:`torch.dtype` (optional), forwarded to every
        :py:class:`torch.nn.Linear` created by this module.

    :raises ValueError: if neither ``out_features`` nor ``out_properties`` is
        provided.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: Union[int, List[int]],
        out_features: Optional[Union[int, List[int]]] = None,
        out_properties: Optional[List[Labels]] = None,
        invariant_keys: Optional[Labels] = None,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        # Set a default for invariant keys
        if invariant_keys is None:
            # A single entry (o3_lambda=0, o3_sigma=1): the values must have one
            # column per name, i.e. shape (1, 2), not (2, 1).
            invariant_keys = Labels(
                names=["o3_lambda", "o3_sigma"],
                values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
            )
        invariant_key_idxs = in_keys.select(invariant_keys)

        # Infer `out_features` if not provided
        if out_features is None:
            if out_properties is None:
                raise ValueError(
                    "If `out_features` is not provided,"
                    " `out_properties` must be provided."
                )
            out_features = [len(p) for p in out_properties]

        # Check input parameters, convert to lists (for each key) if necessary
        n_keys = len(in_keys)
        in_features = _check_module_map_parameter(
            in_features, "in_features", int, n_keys, "in_keys"
        )
        out_features = _check_module_map_parameter(
            out_features, "out_features", int, n_keys, "in_keys"
        )
        bias = _check_module_map_parameter(
            bias, "bias", bool, len(invariant_key_idxs), "invariant_key_idxs"
        )

        # Map the position of each invariant key in `in_keys` to its bias setting,
        # built once instead of re-scanning `invariant_key_idxs` for every key.
        bias_by_key_idx = {
            int(key_idx): use_bias
            for key_idx, use_bias in zip(invariant_key_idxs, bias)
        }

        modules: List[Module] = []
        for i, (n_in, n_out) in enumerate(zip(in_features, out_features)):
            # Invariant block: apply bias according to user choice.
            # Covariant block (not in the mapping): do not apply bias.
            bias_block = bias_by_key_idx.get(i, False)
            modules.append(
                torch.nn.Linear(
                    in_features=n_in,
                    out_features=n_out,
                    bias=bias_block,
                    device=device,
                    dtype=dtype,
                )
            )

        self.module_map = ModuleMap(in_keys, modules, out_properties)

    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map `tensor`.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
        :return: :py:class:`TensorMap`
        """
        return self.module_map(tensor)