Neural Network


class metatensor.learn.nn.ModuleMap(in_keys: Labels, modules: List[Module], out_properties: List[Labels] | None = None)[source]

A class that imitates torch.nn.ModuleDict. In its forward() function, the module at position i in modules (given on construction) is applied to the tensor block corresponding to the i-th key in in_keys.

  • in_keys (Labels) – A metatensor.Labels object with the keys of the module map that are assumed to be in the input tensor map in the forward() function.

  • modules (List[Module]) – A sequence of modules applied in the forward() function on the input TensorMap. Each module corresponds to one LabelsEntry in in_keys, which determines the TensorBlock the module is applied to. modules and in_keys must have the same length.

  • out_properties (List[Labels] | None) –

    A list of labels that is used to determine the properties labels of the output. Because a module could change the number of properties, the labels of the properties cannot be preserved. By default the output properties are relabeled using Labels.range with “_” as the dimension name.

    >>> import torch
    >>> import numpy as np
    >>> from copy import deepcopy
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> from metatensor.learn.nn import ModuleMap

    Create simple blocks

    >>> block_1 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [1.0, 2.0, 4.0],
    ...             [3.0, 5.0, 6.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
    ... )
    >>> block_2 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [5.0, 8.0, 2.0],
    ...             [1.0, 2.0, 8.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
    ... )
    >>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
    >>> tensor = TensorMap(keys, [block_1, block_2])

    Create modules

    >>> linear = torch.nn.Linear(3, 1, bias=False)
    >>> with torch.no_grad():
    ...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
    >>> modules = [linear, deepcopy(linear)]
    >>> # you could also extend the module by some nonlinear activation function

    Create a ModuleMap from these modules and apply it

    >>> module_map = ModuleMap(tensor.keys, modules)
    >>> out = module_map(tensor)
    >>> out
    TensorMap with 2 blocks
    keys: key
    >>> out[0].values
    tensor([[ 7.],
            [14.]], grad_fn=<MmBackward0>)
    >>> out[1].values
    tensor([[15.],
            [11.]], grad_fn=<MmBackward0>)

    Let’s look at the metadata

    >>> tensor[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (): []
        properties (3): ['properties']
        gradients: None
    >>> out[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (): []
        properties (1): ['_']
        gradients: None

    The original property labels are lost, because in general the properties of the output cannot be inferred from the module. To keep meaningful metadata, pass the intended output property Labels through out_properties when initializing the ModuleMap.
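
The per-key dispatch described above can be illustrated without metatensor or torch. The following is a minimal toy model, not the library's implementation: a tensor map is assumed to be a plain dict from key tuples to blocks, a block is a list of sample rows (components are ignored), and toy_module_map and sum_features are hypothetical names. It applies the i-th callable to the block of the i-th key and, when no out_properties are given, relabels the output properties as a single dimension "_" with values 0 to n-1, mirroring the default described above.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Key = Tuple[int, ...]
Block = List[List[float]]  # rows are samples, columns are properties
Properties = List[Tuple[str, int]]


def toy_module_map(
    in_keys: Sequence[Key],
    modules: Sequence[Callable[[Block], Block]],
    tensor: Dict[Key, Block],
    out_properties: Optional[Sequence[Properties]] = None,
) -> Dict[Key, Tuple[Block, Properties]]:
    """Toy model of ModuleMap: the i-th module is applied to the block of the i-th key."""
    if len(modules) != len(in_keys):
        raise ValueError("modules and in_keys must match in length")
    if set(tensor) != set(in_keys):
        raise ValueError("the input keys must match in_keys")

    out: Dict[Key, Tuple[Block, Properties]] = {}
    for i, (key, module) in enumerate(zip(in_keys, modules)):
        values = module(tensor[key])
        if out_properties is None:
            # default: one dimension named "_" with values 0 .. n_properties - 1
            n_properties = len(values[0]) if values else 0
            properties = [("_", p) for p in range(n_properties)]
        else:
            properties = list(out_properties[i])
        out[key] = (values, properties)
    return out


def sum_features(block: Block) -> Block:
    # same effect as torch.nn.Linear(3, 1, bias=False) with all weights set to 1
    return [[sum(row)] for row in block]


tensor = {
    (0,): [[1.0, 2.0, 4.0], [3.0, 5.0, 6.0]],
    (1,): [[5.0, 8.0, 2.0], [1.0, 2.0, 8.0]],
}
out = toy_module_map([(0,), (1,)], [sum_features, sum_features], tensor)
print(out[(0,)])  # ([[7.0], [14.0]], [('_', 0)])
print(out[(1,)])  # ([[15.0], [11.0]], [('_', 0)])
```

Passing out_properties, one entry per key, replaces the generic "_" labels; this is the toy counterpart of restoring meaningful metadata on the real ModuleMap.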

classmethod from_module(in_keys: Labels, module: Module, out_properties: List[Labels] | None = None)[source]

A wrapper around one torch.nn.Module, applying the same type of module to each tensor block.

  • in_keys (Labels) – A metatensor.Labels object that determines the keys of the module map that are assumed to be in the input tensor map in the forward() function.

  • module (Module) – The module that is applied on each block.

  • out_properties (List[Labels] | None) – A list of labels that is used to determine the properties labels of the output. Because a module could change the number of properties, the labels of the properties cannot be preserved. By default the output properties are relabeled using Labels.range with “_” as the dimension name.

>>> import torch
>>> import numpy as np
>>> from metatensor import Labels, TensorBlock, TensorMap
>>> block_1 = TensorBlock(
...     values=torch.tensor(
...         [
...             [1.0, 2.0, 4.0],
...             [3.0, 5.0, 6.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
>>> block_2 = TensorBlock(
...     values=torch.tensor(
...         [
...             [5.0, 8.0, 2.0],
...             [1.0, 2.0, 8.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
>>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
>>> tensor = TensorMap(keys, [block_1, block_2])
>>> linear = torch.nn.Linear(3, 1, bias=False)
>>> with torch.no_grad():
...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
>>> # you could also extend the module by some nonlinear activation function
>>> from metatensor.learn.nn import ModuleMap
>>> module_map = ModuleMap.from_module(tensor.keys, linear)
>>> out = module_map(tensor)
>>> out[0].values
tensor([[ 7.],
        [14.]], grad_fn=<MmBackward0>)
>>> out[1].values
tensor([[15.],
        [11.]], grad_fn=<MmBackward0>)
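
One point worth checking in your own code is whether the blocks share parameters. The sketch below assumes, as the phrase "the same type of module" suggests, that from_module gives every key its own deep copy of module, so updating the module of one block does not change the others; verify this against the source before relying on it. ToyLinear and toy_from_module are hypothetical stand-ins that use only the standard library.

```python
import copy
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, ...]
Block = List[List[float]]


class ToyLinear:
    """Hypothetical stand-in for torch.nn.Linear(n, 1, bias=False)."""

    def __init__(self, weights: List[float]):
        self.weights = weights

    def __call__(self, block: Block) -> Block:
        return [[sum(w * x for w, x in zip(self.weights, row))] for row in block]


def toy_from_module(in_keys: Sequence[Key], module: ToyLinear) -> Dict[Key, ToyLinear]:
    # Assumption: one independent deep copy per key, so no parameters are shared
    return {key: copy.deepcopy(module) for key in in_keys}


modules = toy_from_module([(0,), (1,)], ToyLinear([1.0, 1.0, 1.0]))
modules[(1,)].weights[0] = 2.0  # update only the copy used for key (1,)

block = [[5.0, 8.0, 2.0]]
print(modules[(0,)](block))  # [[15.0]]
print(modules[(1,)](block))  # [[20.0]]
```
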
forward(tensor: TensorMap) TensorMap[source]

Apply the modules to each block in tensor. tensor must have the same set of keys as the in_keys used to initialize this ModuleMap.


tensor (TensorMap) – input tensor map

get_module(key: LabelsEntry)[source]

key (LabelsEntry) – key of the module that should be returned

Returns:

the torch.nn.Module corresponding to key

property in_keys: Labels

The Labels defining the keys for which this module map holds a module, as given at initialization.

property out_properties: None | List[Labels]

A list of labels that is used to determine the properties labels of the output of the forward() function.

repr_as_module_dict() str[source]

Returns a string that is easier to read than the standard __repr__, showing the mapping from each key LabelsEntry to its module.

class metatensor.learn.nn.Sequential(in_keys: Labels, *args: List[ModuleMap])[source]

A sequential model that applies a list of ModuleMaps to the input in order.

  • in_keys (Labels) – The keys that are assumed to be in the input tensor map in the forward() function.

  • args (List[ModuleMap]) – A list of ModuleMap objects that will be applied in order to the input tensor map in the forward() function.

forward(tensor: TensorMap) TensorMap[source]

Apply the transformations to the input tensor map tensor.


tensor (TensorMap)
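
Sequential applies its ModuleMaps one after the other, so the output of one map is the input of the next. As a minimal sketch of that ordering, with plain dicts standing in for TensorMaps and the hypothetical helpers toy_sequential and blockwise in place of the real classes:

```python
from functools import reduce
from typing import Callable, Dict, List, Tuple

Key = Tuple[int, ...]
Block = List[List[float]]
ToyTensorMap = Dict[Key, Block]
Step = Callable[[ToyTensorMap], ToyTensorMap]


def blockwise(fn: Callable[[float], float]) -> Step:
    """Hypothetical helper: apply fn elementwise to every block of a toy tensor map."""
    return lambda tensor: {
        key: [[fn(x) for x in row] for row in block] for key, block in tensor.items()
    }


def toy_sequential(*steps: Step) -> Step:
    """Compose the steps left to right, like Sequential.forward."""
    return lambda tensor: reduce(lambda acc, step: step(acc), steps, tensor)


model = toy_sequential(blockwise(lambda x: 2.0 * x), blockwise(lambda x: x - 1.0))
print(model({(0,): [[1.0, 2.0]], (1,): [[3.0]]}))
# {(0,): [[1.0, 3.0]], (1,): [[5.0]]}
```

Swapping the two steps gives a different result, which is why the order of the ModuleMap arguments matters.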

class metatensor.learn.nn.Linear(in_keys: Labels, in_features: int | List[int], out_features: List[int] | int | None = None, out_properties: List[Labels] | None = None, *, bias: bool = True, device: device | None = None, dtype: dtype | None = None)[source]

Module similar to torch.nn.Linear that works with metatensor.torch.TensorMap objects.

Applies a linear transformation to each block of a TensorMap passed to its forward() method, indexed by in_keys.

Refer to the torch.nn.Linear documentation for a more detailed description of the other parameters.

Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • in_features (int | List[int]) – int or list of int, the number of input features for each block. If passed as a single value, the same feature size is taken for all blocks.

  • out_features (List[int] | int | None) – int or list of int, the number of output features for each block. If passed as a single value, the same feature size is taken for all blocks.

  • out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range. If provided, out_features can be inferred and need not be provided.

  • bias (bool)

  • device (device | None)

  • dtype (dtype | None)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
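
The rule that a single in_features or out_features value is reused for all blocks, while a list gives one value per key, can be sketched in plain Python. This is an illustration under stated assumptions, not the metatensor implementation: as_per_block and linear_block are hypothetical helpers, blocks have no components, and the weight layout (out_features rows of in_features entries, computing x Wᵀ + b) follows the torch.nn.Linear convention.

```python
from typing import List, Sequence, Union

Block = List[List[float]]


def as_per_block(value: Union[int, Sequence[int]], n_blocks: int) -> List[int]:
    """Expand a single int to one value per block; check the length of a list."""
    if isinstance(value, int):
        return [value] * n_blocks
    if len(value) != n_blocks:
        raise ValueError("expected one value per key in in_keys")
    return list(value)


def linear_block(block: Block, weight: List[List[float]], bias: List[float]) -> Block:
    # y[s][o] = sum_i x[s][i] * weight[o][i] + bias[o], as in torch.nn.Linear
    return [
        [sum(w * x for w, x in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in block
    ]


blocks = [[[1.0, 2.0, 4.0]], [[5.0, 8.0, 2.0]]]
in_features = as_per_block(3, len(blocks))        # [3, 3]
out_features = as_per_block([1, 2], len(blocks))  # [1, 2]

outputs = []
for block, n_in, n_out in zip(blocks, in_features, out_features):
    weight = [[1.0] * n_in for _ in range(n_out)]  # all-ones weights for readability
    bias = [0.5] * n_out
    outputs.append(linear_block(block, weight, bias))

print(outputs)  # [[[7.5]], [[15.5, 15.5]]]
```
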



class metatensor.learn.nn.EquivariantLinear(in_keys: Labels, in_features: int | List[int], out_features: List[int] | int | None = None, out_properties: List[Labels] | None = None, invariant_keys: Labels | None = None, *, bias: bool = True, device: device | None = None, dtype: dtype | None = None)[source]

Module similar to torch.nn.Linear that works with equivariant metatensor.torch.TensorMap objects.

Applies a linear transformation to each block of a TensorMap passed to its forward() method, indexed by in_keys.

Refer to the torch.nn.Linear documentation for a more detailed description of the other parameters.

For EquivariantLinear, in contrast to Linear, the bias parameter is only applied to modules corresponding to invariant blocks, i.e. keys in in_keys that correspond to the selection in invariant_keys.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • in_features (int | List[int]) – int or list of int, the number of input features for each block. If passed as a single value, the same feature size is taken for all blocks.

  • out_features (List[int] | int | None) – int or list of int, the number of output features for each block. If passed as a single value, the same feature size is taken for all blocks.

  • out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range. If provided, out_features can be inferred and need not be provided.

  • invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].

  • bias (bool)

  • device (device | None)

  • dtype (dtype | None)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
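
The default selection of invariant keys, o3_lambda equal to 0 and o3_sigma equal to 1, can be sketched with plain tuples. In this toy model, keys are tuples with the dimension names listed separately, select_invariant is a hypothetical helper, and its result marks the blocks whose linear module would receive a bias. Only those blocks get one because adding a constant offset to the components of a covariant block (o3_lambda > 0), or to a block with o3_sigma = -1, does not transform correctly under rotations or inversion, so it would break equivariance.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def select_invariant(
    names: Sequence[str],
    keys: Sequence[Tuple[int, ...]],
    selection: Optional[Dict[str, int]] = None,
) -> List[bool]:
    """Flag the keys that match every (name, value) pair in selection."""
    if selection is None:
        # default described above: o3_lambda == 0 and o3_sigma == 1
        selection = {"o3_lambda": 0, "o3_sigma": 1}
    position = {name: list(names).index(name) for name in selection}
    return [
        all(key[position[name]] == value for name, value in selection.items())
        for key in keys
    ]


names = ["o3_lambda", "o3_sigma", "center_type"]
keys = [(0, 1, 1), (1, 1, 1), (0, -1, 1), (0, 1, 8)]
use_bias = select_invariant(names, keys)
print(use_bias)  # [True, False, False, True]
```
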



class metatensor.learn.nn.Tanh(in_keys: Labels, out_properties: Labels | None = None)[source]

Module similar to torch.nn.Tanh that works with metatensor.torch.TensorMap objects.

Applies a hyperbolic tangent transformation to each block of a TensorMap passed to its forward() method, indexed by in_keys.

Refer to the torch.nn.Tanh documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
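
Because tanh acts elementwise, the number of properties never changes, yet by default the output properties are still relabeled with Labels.range; if the original property names matter, the parameter description above suggests passing them through out_properties. A minimal plain-Python sketch of the value transformation follows; the hypothetical toy_tanh works on dicts of nested lists rather than TensorMaps.

```python
import math
from typing import Dict, List, Tuple

Key = Tuple[int, ...]
Block = List[List[float]]


def toy_tanh(tensor: Dict[Key, Block]) -> Dict[Key, Block]:
    """Apply tanh to every entry of every block; block shapes are unchanged."""
    return {
        key: [[math.tanh(x) for x in row] for row in block]
        for key, block in tensor.items()
    }


tensor = {(0,): [[0.0, 1.0]], (1,): [[-2.0, 0.5, 3.0]]}
out = toy_tanh(tensor)
print(round(out[(0,)][0][1], 4))  # 0.7616
```
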



class metatensor.learn.nn.InvariantTanh(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None)[source]

Module similar to torch.nn.Tanh that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.

Applies a hyperbolic tangent transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.

Refer to the torch.nn.Tanh documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

  • invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
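
A minimal sketch of the invariant-only behaviour, assuming (as "applying the transformation only to the invariant blocks" implies) that all other blocks pass through unchanged. Leaving them untouched matters physically: applying tanh separately to each component of an o3_lambda > 0 block would not commute with rotations. The helper name toy_invariant_tanh and the tuple-based keys are hypothetical.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Key = Tuple[int, ...]
Block = List[List[float]]


def toy_invariant_tanh(
    names: Sequence[str],
    tensor: Dict[Key, Block],
    selection: Optional[Dict[str, int]] = None,
) -> Dict[Key, Block]:
    """tanh on invariant blocks only; every other block is returned unchanged."""
    if selection is None:
        selection = {"o3_lambda": 0, "o3_sigma": 1}
    position = {name: list(names).index(name) for name in selection}
    out: Dict[Key, Block] = {}
    for key, block in tensor.items():
        invariant = all(key[position[n]] == v for n, v in selection.items())
        out[key] = [[math.tanh(x) if invariant else x for x in row] for row in block]
    return out


names = ["o3_lambda", "o3_sigma"]
tensor = {(0, 1): [[1.0, -1.0]], (1, 1): [[1.0, 2.0, 3.0]]}
out = toy_invariant_tanh(names, tensor)
print(out[(1, 1)])  # [[1.0, 2.0, 3.0]]
```
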



class metatensor.learn.nn.ReLU(in_keys: Labels, out_properties: Labels | None = None, *, in_place: bool = False)[source]

Module similar to torch.nn.ReLU that works with metatensor.torch.TensorMap objects.

Applies a rectified linear unit transformation to each block of a TensorMap passed to its forward() method, indexed by in_keys.

Refer to the torch.nn.ReLU documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

  • in_place (bool)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.



class metatensor.learn.nn.InvariantReLU(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None, *, in_place: bool = False)[source]

Module similar to torch.nn.ReLU that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.

Applies a rectified linear unit transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.

Refer to the torch.nn.ReLU documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

  • invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].

  • in_place (bool)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.



class metatensor.learn.nn.SiLU(in_keys: Labels, out_properties: Labels | None = None, *, in_place: bool = False)[source]

Module similar to torch.nn.SiLU that works with metatensor.torch.TensorMap objects.

Applies a sigmoid linear unit transformation to each block of a TensorMap passed to its forward() method, indexed by in_keys.

Refer to the torch.nn.SiLU documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

  • in_place (bool)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
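
For reference, the elementwise functions behind ReLU and SiLU, as defined in the PyTorch documentation, are ReLU(x) = max(0, x) and SiLU(x) = x · σ(x), where σ is the logistic sigmoid. The plain-Python sketch below evaluates both on a few scalars; the metatensor modules apply these functions to the entries of each selected block.

```python
import math


def relu(x: float) -> float:
    return max(0.0, x)


def silu(x: float) -> float:
    # x * sigmoid(x), written as x / (1 + exp(-x))
    return x / (1.0 + math.exp(-x))


for x in (-2.0, 0.0, 1.0):
    print(x, relu(x), round(silu(x), 4))
# -2.0 0.0 -0.2384
# 0.0 0.0 0.0
# 1.0 1.0 0.7311
```

Unlike ReLU, SiLU is smooth and takes small negative values for negative inputs.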



class metatensor.learn.nn.InvariantSiLU(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None, *, in_place: bool = False)[source]

Module similar to torch.nn.SiLU that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.

Applies a sigmoid linear unit transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.

Refer to the torch.nn.SiLU documentation for a more detailed description of the parameters.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.

  • invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].

  • in_place (bool)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.

Note: gradients are currently not supported.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.



class metatensor.learn.nn.LayerNorm(in_keys: Labels, in_features: List[int], out_properties: List[Labels] | None = None, *, eps: float = 1e-05, elementwise_affine: bool = True, bias: bool = True, mean: bool = True, device: device | None = None, dtype: dtype | None = None)[source]

Module similar to torch.nn.LayerNorm that works with metatensor.torch.TensorMap objects.

Applies a layer normalization to each block of a TensorMap passed to its forward() method, indexed by in_keys.

The main difference from torch.nn.LayerNorm is that there is no normalized_shape parameter. Instead, the standard deviation and mean (if applicable) are calculated over all dimensions except the samples (first) dimension of each TensorBlock.

The extra parameter mean controls whether or not the mean over these dimensions is subtracted from the input tensor in the transformation.

Refer to the torch.nn.LayerNorm documentation for a more detailed description of the other parameters.

Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • in_features (List[int]) – list of int, the number of features in the input tensor for each block indexed by the keys in in_keys. If passed as a single value, the same number of features is assumed for all blocks.

  • out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default (if None) the output properties are relabeled using Labels.range.

  • eps (float)

  • elementwise_affine (bool)

  • bias (bool)

  • mean (bool) – whether or not to subtract the mean over all dimensions except the samples (first) dimension of each block of the input passed to forward().

  • device (device | None)

  • dtype (dtype | None)

forward(tensor: TensorMap) TensorMap[source]

Apply the transformation to the input tensor map tensor.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.
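
The normalization described above, with statistics taken over all dimensions except the samples, can be sketched for the default mean=True case. This is a plain-Python illustration, not the metatensor code: the hypothetical toy_layer_norm takes one block as nested lists [sample][component][property], uses the biased variance as torch.nn.LayerNorm does, and leaves out the learnable affine transformation controlled by elementwise_affine and bias.

```python
import math
from typing import List

Block3D = List[List[List[float]]]  # [sample][component][property]


def toy_layer_norm(block: Block3D, eps: float = 1e-5) -> Block3D:
    """Normalize each sample over all of its component and property entries."""
    out: Block3D = []
    for sample in block:
        flat = [x for component in sample for x in component]
        mean = sum(flat) / len(flat)
        var = sum((x - mean) ** 2 for x in flat) / len(flat)  # biased variance
        scale = 1.0 / math.sqrt(var + eps)
        out.append([[(x - mean) * scale for x in component] for component in sample])
    return out


# one block with 2 samples, 3 components and 2 properties
block = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[0.0, 0.0], [0.0, 0.0], [6.0, 6.0]],
]
normed = toy_layer_norm(block)
# each sample of `normed` now has mean close to 0 and variance just below 1 (because of eps)
```
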



class metatensor.learn.nn.InvariantLayerNorm(in_keys: Labels, in_features: List[int], out_properties: List[Labels] | None = None, invariant_keys: Labels | None = None, *, eps: float = 1e-05, elementwise_affine: bool = True, bias: bool = True, mean: bool = True, device: device | None = None, dtype: dtype | None = None)[source]

Module similar to torch.nn.LayerNorm that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.

Applies a layer normalization to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.

The main difference from torch.nn.LayerNorm is that there is no normalized_shape parameter. Instead, the standard deviation and mean (if applicable) are calculated over all dimensions except the samples (first) dimension of each TensorBlock.

The extra parameter mean controls whether or not the mean over these dimensions is subtracted from the input tensor in the transformation.

Refer to the torch.nn.LayerNorm documentation for a more detailed description of the other parameters.

Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.

  • in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.

  • in_features (List[int]) – list of int, the number of features in the input tensor for each block indexed by the keys in in_keys. If passed as a single value, the same number of features is assumed for all blocks.

  • out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default (if None) the output properties are relabeled using Labels.range.

  • invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].

  • eps (float)

  • elementwise_affine (bool)

  • bias (bool)

  • mean (bool) – whether or not to subtract the mean over all dimensions except the samples (first) dimension of each block of the input passed to forward().

  • device (device | None)

  • dtype (dtype | None)

forward(tensor: TensorMap) TensorMap[source]

Apply the layer norm to the input tensor map tensor.


tensor (TensorMap) – TensorMap with the input tensor to be transformed.



