Neural Network building blocks
Modules
- class metatensor.learn.nn.Module
This class should be used instead of torch.nn.Module when your module contains data stored inside metatensor.torch.Labels, metatensor.torch.TensorBlock or metatensor.torch.TensorMap. It ensures that this data is properly moved to other dtypes and devices when calling .to(), .cuda(), .float() and other related functions. We also handle the corresponding data in state_dict() and load_state_dict().
We support storing these classes either directly as attributes (self.name = ...), or inside arbitrarily nested dict, list, or tuple (self.name = {"dict": [...]}).
Below is an example creating a custom linear model that stores the output properties as an attribute. The corresponding labels will automatically be moved to the new device together with the module.
>>> import torch
>>> from typing import List
>>> from metatensor.torch.learn import nn
>>> from metatensor.torch import Labels, TensorMap, TensorBlock
>>>
>>> class CustomLinear(nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.properties = Labels(
...             ["out_features"], torch.arange(out_features).reshape(-1, 1)
...         )
...         self.linear = torch.nn.Linear(in_features, out_features)
...
...     def forward(self, tensor: TensorMap) -> TensorMap:
...         blocks: List[TensorBlock] = []
...         for block in tensor.blocks():
...             new_values = self.linear(block.values)
...             new_block = TensorBlock(
...                 new_values,
...                 block.samples,
...                 block.components,
...                 self.properties,
...             )
...             blocks.append(new_block)
...         return TensorMap(tensor.keys, blocks)
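The support for data stored inside nested containers can be pictured as a recursive walk over dict, list and tuple values. The plain-Python sketch below is not the library's implementation. The helper name map_nested is hypothetical, and tagging each leaf with a device string is only a stand-in for moving a Labels, TensorBlock or TensorMap.

```python
from typing import Any, Callable


def map_nested(value: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply fn to every leaf of nested dict/list/tuple."""
    if isinstance(value, dict):
        # keep the keys, transform the values
        return {key: map_nested(item, fn) for key, item in value.items()}
    if isinstance(value, list):
        return [map_nested(item, fn) for item in value]
    if isinstance(value, tuple):
        # rebuild a tuple so the container type is preserved
        return tuple(map_nested(item, fn) for item in value)
    # a leaf: in the real module this would be metatensor data
    return fn(value)


# stand-in for "move to device": tag each leaf with a device name
moved = map_nested({"dict": [("a", "b"), "c"]}, lambda leaf: f"{leaf}@cuda")
print(moved)  # {'dict': [('a@cuda', 'b@cuda'), 'c@cuda']}
```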
- get_extra_state()
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict(). Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
- Returns:
object: Any extra state to store in the module’s state_dict
- set_extra_state(extra)
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
- Args:
extra (dict): Extra state from the state_dict
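The get_extra_state()/set_extra_state() pair works as a save/restore protocol. Whatever the first method returns is stored in the state_dict, and it is handed back to the second method on loading. The sketch below mimics that round trip in plain Python, using pickle to check that the extra state is picklable, as recommended above. The class ToyModule and its metadata attribute are hypothetical, and this is not how metatensor stores its own data.

```python
import pickle
from typing import Any, Dict, List


class ToyModule:
    """Hypothetical module keeping some non-tensor metadata."""

    def __init__(self) -> None:
        self.metadata: List[str] = []

    def get_extra_state(self) -> Dict[str, Any]:
        # return a picklable object: a plain dict of Python values
        return {"metadata": list(self.metadata)}

    def set_extra_state(self, extra: Dict[str, Any]) -> None:
        # restore exactly what get_extra_state() produced
        self.metadata = list(extra["metadata"])


source = ToyModule()
source.metadata = ["o3_lambda", "o3_sigma"]

# serialize the extra state, as saving a state_dict would
blob = pickle.dumps(source.get_extra_state())

target = ToyModule()
target.set_extra_state(pickle.loads(blob))
print(target.metadata)  # ['o3_lambda', 'o3_sigma']
```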
- class metatensor.learn.nn.ModuleMap(in_keys: Labels, modules: List[Module], out_properties: List[Labels] | None = None)
A class that imitates torch.nn.ModuleDict. In its forward function, the module at position i in the modules given on construction is applied to the tensor block corresponding to the i-th key in in_keys.
- Parameters:
in_keys (Labels) – A metatensor.Labels object with the keys of the module map that are assumed to be in the input tensor map in the forward() function.
modules (List[Module]) – A sequence of modules applied in the forward() function on the input TensorMap. Each module corresponds to one LabelsEntry in in_keys, which determines the TensorBlock the module is applied to. The modules parameter and in_keys must match in length.
out_properties (List[Labels] | None) – A list of labels that is used to determine the properties labels of the output. Because a module could change the number of properties, the labels of the properties cannot be preserved. By default the output properties are relabeled using Labels.range with “_” as the name.
>>> import torch
>>> import numpy as np
>>> from copy import deepcopy
>>> from metatensor import Labels, TensorBlock, TensorMap
>>> from metatensor.learn.nn import ModuleMap
Create simple blocks
>>> block_1 = TensorBlock(
...     values=torch.tensor(
...         [
...             [1.0, 2.0, 4.0],
...             [3.0, 5.0, 6.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
>>> block_2 = TensorBlock(
...     values=torch.tensor(
...         [
...             [5.0, 8.0, 2.0],
...             [1.0, 2.0, 8.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
... )
>>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
>>> tensor = TensorMap(keys, [block_1, block_2])
Create modules
>>> linear = torch.nn.Linear(3, 1, bias=False)
>>> with torch.no_grad():
...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
>>> modules = [linear, deepcopy(linear)]
>>> # you could also extend the module by some nonlinear activation function
Create a ModuleMap from these modules and apply it
>>> module_map = ModuleMap(tensor.keys, modules)
>>> out = module_map(tensor)
>>> out
TensorMap with 2 blocks
keys: key
       0
       1
>>> out[0].values
tensor([[ 7.],
        [14.]], grad_fn=<MmBackward0>)
>>> out[1].values
tensor([[15.],
        [11.]], grad_fn=<MmBackward0>)
Let’s look at the metadata
>>> tensor[0]
TensorBlock
    samples (2): ['system', 'atom']
    components (): []
    properties (3): ['properties']
    gradients: None
>>> out[0]
TensorBlock
    samples (2): ['system', 'atom']
    components (): []
    properties (1): ['_']
    gradients: None
The properties metadata got completely lost, because the ModuleMap cannot know in general what the output properties are. You can pass the intended output properties when initializing the ModuleMap, as a list of Labels (one per key) given to out_properties.
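The pairing rule of ModuleMap can be sketched in plain Python using the numbers from the example above. Under this rule, the i-th module acts on the block of the i-th key, and by default the output properties become a range named “_”. In the sketch, blocks are lists of rows, keys are tuples and properties are (name, values) pairs. The class ToyModuleMap is hypothetical and only mirrors the behavior described in the text.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Key = Tuple[int, ...]
Row = List[float]
Properties = Tuple[str, List[int]]  # (dimension name, values)


class ToyModuleMap:
    """Hypothetical stand-in for ModuleMap, working on dicts of lists."""

    def __init__(
        self,
        in_keys: Sequence[Key],
        modules: Sequence[Callable[[Row], Row]],
        out_properties: Optional[Sequence[Properties]] = None,
    ) -> None:
        if len(in_keys) != len(modules):
            raise ValueError("modules and in_keys must match in length")
        self.in_keys = list(in_keys)
        self.modules = list(modules)
        self.out_properties = out_properties

    def __call__(
        self, tensor: Dict[Key, List[Row]]
    ) -> Dict[Key, Tuple[List[Row], Properties]]:
        out = {}
        for i, key in enumerate(self.in_keys):
            # module i is applied, sample by sample, to the block of key i
            values = [self.modules[i](row) for row in tensor[key]]
            if self.out_properties is not None:
                properties = self.out_properties[i]
            else:
                # default: a range over the output size, named "_"
                n_out = len(values[0]) if values else 0
                properties = ("_", list(range(n_out)))
            out[key] = (values, properties)
        return out


def sum_features(row: Row) -> Row:
    # acts like torch.nn.Linear(3, 1, bias=False) with all weights set to 1
    return [sum(row)]


tensor = {
    (0,): [[1.0, 2.0, 4.0], [3.0, 5.0, 6.0]],
    (1,): [[5.0, 8.0, 2.0], [1.0, 2.0, 8.0]],
}
module_map = ToyModuleMap([(0,), (1,)], [sum_features, sum_features])
out = module_map(tensor)
print(out[(0,)])  # ([[7.0], [14.0]], ('_', [0]))
```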
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- classmethod from_module(in_keys: Labels, module: Module, out_properties: List[Labels] | None = None)
A wrapper around one torch.nn.Module applying the same type of module on each tensor block.
- Parameters:
in_keys (Labels) – A metatensor.Labels object that determines the keys of the module map that are assumed to be in the input tensor map in the forward() function.
module (Module) – The module that is applied on each block.
out_properties (List[Labels] | None) – A list of labels that is used to determine the properties labels of the output. Because a module could change the number of properties, the labels of the properties cannot be preserved. By default the output properties are relabeled using Labels.range with “_” as the name.
>>> import torch
>>> import numpy as np
>>> from metatensor import Labels, TensorBlock, TensorMap
>>> block_1 = TensorBlock(
...     values=torch.tensor(
...         [
...             [1.0, 2.0, 4.0],
...             [3.0, 5.0, 6.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
>>> block_2 = TensorBlock(
...     values=torch.tensor(
...         [
...             [5.0, 8.0, 2.0],
...             [1.0, 2.0, 8.0],
...         ]
...     ),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
>>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
>>> tensor = TensorMap(keys, [block_1, block_2])
>>> linear = torch.nn.Linear(3, 1, bias=False)
>>> with torch.no_grad():
...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
>>> # you could also extend the module by some nonlinear activation function
>>> from metatensor.learn.nn import ModuleMap
>>> module_map = ModuleMap.from_module(tensor.keys, linear)
>>> out = module_map(tensor)
>>> out[0].values
tensor([[ 7.],
        [14.]], grad_fn=<MmBackward0>)
>>> out[1].values
tensor([[15.],
        [11.]], grad_fn=<MmBackward0>)
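When a map is built from a single module and each block should have its own trainable parameters, every key needs an independent copy rather than a shared object. The sketch below shows that pattern with copy.deepcopy on a hypothetical ToyLinear class. Whether from_module copies the module internally is an assumption here: the text above only says the same type of module is applied to each block.

```python
import copy
from typing import Dict, List, Tuple

Key = Tuple[int, ...]


class ToyLinear:
    """Hypothetical one-output linear layer without bias."""

    def __init__(self, weights: List[float]) -> None:
        self.weights = list(weights)

    def __call__(self, row: List[float]) -> List[float]:
        return [sum(w * x for w, x in zip(self.weights, row))]


def from_module(in_keys: List[Key], module: ToyLinear) -> Dict[Key, ToyLinear]:
    # one deep copy per key, so parameters are not shared between blocks
    return {key: copy.deepcopy(module) for key in in_keys}


modules = from_module([(0,), (1,)], ToyLinear([1.0, 1.0, 1.0]))
modules[(0,)].weights[0] = 2.0  # "train" only the first copy
print(modules[(0,)]([1.0, 2.0, 4.0]))  # [8.0]
print(modules[(1,)]([1.0, 2.0, 4.0]))  # [7.0]
```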
- forward(tensor: TensorMap) → TensorMap
Apply the modules on each block in tensor. tensor must have the same set of keys as the modules used to initialize this ModuleMap.
- get_module(key: LabelsEntry)
- Parameters:
key (LabelsEntry) – key of the module which should be returned
- Returns:
the torch.nn.Module corresponding to key
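forward() requires the input keys to match those of the module map, and get_module() looks up a module by its key. The plain-Python sketch below illustrates both checks. The functions check_keys and get_module here are hypothetical, and keys are represented as tuples. Reporting a mismatch with ValueError or KeyError is an assumption, not documented library behavior.

```python
from typing import Callable, Dict, Set, Tuple

Key = Tuple[int, ...]


def check_keys(module_keys: Set[Key], tensor_keys: Set[Key]) -> None:
    # forward() needs the same set of keys on both sides
    if module_keys != tensor_keys:
        missing = sorted(module_keys - tensor_keys)
        unexpected = sorted(tensor_keys - module_keys)
        raise ValueError(f"key mismatch: missing {missing}, unexpected {unexpected}")


def get_module(modules: Dict[Key, Callable], key: Key) -> Callable:
    # return the module registered for this key
    if key not in modules:
        raise KeyError(f"no module for key {key}")
    return modules[key]


modules = {(0,): abs, (1,): round}
check_keys(set(modules), {(0,), (1,)})  # passes silently
print(get_module(modules, (1,))(2.6))  # 3
```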
- property in_keys: Labels
The Labels defining the keys of this module map. Each key has a corresponding module.
- class metatensor.learn.nn.Sequential(in_keys: Labels, *args: List[ModuleMap])
A sequential model that applies a list of ModuleMaps to the input in order.
- Parameters:
in_keys (Labels)
args (List[ModuleMap])
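Sequential applies its ModuleMaps one after the other, so the output of each map becomes the input of the next. In plain Python this is a left fold over the list of maps. The sketch below uses hypothetical per-key functions on a dict of lists in place of real ModuleMaps and TensorMaps, and the helper name sequential is also hypothetical.

```python
from functools import reduce
from typing import Callable, Dict, List, Tuple

Key = Tuple[int, ...]
Tensor = Dict[Key, List[float]]
MapFn = Callable[[Tensor], Tensor]


def sequential(*maps: MapFn) -> MapFn:
    """Hypothetical composition helper: apply the maps in the given order."""

    def apply(tensor: Tensor) -> Tensor:
        # the output of each map is the input of the next one
        return reduce(lambda acc, fn: fn(acc), maps, tensor)

    return apply


def double(tensor: Tensor) -> Tensor:
    return {key: [2 * v for v in values] for key, values in tensor.items()}


def shift(tensor: Tensor) -> Tensor:
    return {key: [v + 1 for v in values] for key, values in tensor.items()}


print(sequential(double, shift)({(0,): [1.0, 2.0]}))  # {(0,): [3.0, 5.0]}
print(sequential(shift, double)({(0,): [1.0, 2.0]}))  # {(0,): [4.0, 6.0]}
```

The two calls differ only in order, which shows why the order of the ModuleMaps matters.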
- class metatensor.learn.nn.Linear(in_keys: Labels, in_features: int | List[int], out_features: int | List[int] | None = None, out_properties: List[Labels] | None = None, *, bias: bool = True, device: device | None = None, dtype: dtype | None = None)
Module similar to torch.nn.Linear that works with metatensor.torch.TensorMap.
Applies a linear transformation to each block of a TensorMap passed to its forward method, indexed by in_keys.
Refer to the torch.nn.Linear documentation for a more detailed description of the other parameters. Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
in_features (int | List[int]) – int or list of int, the number of input features for each block. If passed as a single value, the same feature size is used for all blocks.
out_features (int | List[int] | None) – int or list of int, the number of output features for each block. If passed as a single value, the same feature size is used for all blocks.
out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range. If provided, out_features can be inferred and does not need to be given.
bias (bool)
device (device | None)
dtype (dtype | None)
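The rules for in_features, out_features and out_properties can be summarized as follows:
- A single int is repeated for every key.
- A list gives one value per key.
- When out_properties is given, the output sizes can be read from the length of each Labels.

The helpers below are a hypothetical plain-Python sketch of that normalization, using property counts in place of real Labels. They are not the library's code.

```python
from typing import List, Optional, Sequence, Union


def per_block(value: Union[int, Sequence[int]], n_keys: int) -> List[int]:
    # a single int applies to all blocks, a list gives one entry per block
    if isinstance(value, int):
        return [value] * n_keys
    if len(value) != n_keys:
        raise ValueError("expected one value per key")
    return list(value)


def resolve_out_features(
    n_keys: int,
    out_features: Union[int, Sequence[int], None] = None,
    n_out_properties: Optional[Sequence[int]] = None,
) -> List[int]:
    # n_out_properties stands in for [len(labels) for labels in out_properties]
    if out_features is None:
        if n_out_properties is None:
            raise ValueError("give out_features or out_properties")
        return per_block(list(n_out_properties), n_keys)
    return per_block(out_features, n_keys)


print(per_block(3, 2))  # [3, 3]
print(resolve_out_features(2, n_out_properties=[4, 1]))  # [4, 1]
```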
- class metatensor.learn.nn.EquivariantLinear(in_keys: Labels, in_features: int | List[int], out_features: int | List[int] | None = None, out_properties: List[Labels] | None = None, invariant_keys: Labels | None = None, *, bias: bool = True, device: device | None = None, dtype: dtype | None = None)
Module similar to torch.nn.Linear that works with equivariant metatensor.torch.TensorMap objects.
Applies a linear transformation to each block of a TensorMap passed to its forward method, indexed by in_keys.
Refer to the torch.nn.Linear documentation for a more detailed description of the other parameters.
For EquivariantLinear, by contrast to Linear, the bias parameter only affects modules corresponding to invariant blocks, i.e. the keys in in_keys that match the selection in invariant_keys.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
in_features (int | List[int]) – int or list of int, the number of input features for each block. If passed as a single value, the same feature size is used for all blocks.
out_features (int | List[int] | None) – int or list of int, the number of output features for each block. If passed as a single value, the same feature size is used for all blocks.
out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range. If provided, out_features can be inferred and does not need to be given.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
bias (bool)
device (device | None)
dtype (dtype | None)
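By default, a key counts as invariant when its o3_lambda and o3_sigma entries equal 0 and 1. EquivariantLinear only uses a bias for those keys, since a constant offset would not transform correctly for blocks that change under rotations or inversion. The sketch below encodes keys as plain dicts from dimension name to value. The functions is_invariant and use_bias, as well as the extra center_type key dimension, are hypothetical illustrations of the selection rule described above.

```python
from typing import Dict, List

Key = Dict[str, int]


def is_invariant(key: Key, selection: Dict[str, int]) -> bool:
    # a key is selected when it matches every dimension of the selection
    return all(key.get(name) == value for name, value in selection.items())


def use_bias(keys: List[Key], bias: bool, selection: Dict[str, int]) -> List[bool]:
    # the bias flag only takes effect on invariant blocks
    return [bias and is_invariant(key, selection) for key in keys]


DEFAULT_SELECTION = {"o3_lambda": 0, "o3_sigma": 1}
keys = [
    {"o3_lambda": 0, "o3_sigma": 1, "center_type": 1},   # invariant
    {"o3_lambda": 1, "o3_sigma": 1, "center_type": 1},   # covariant
    {"o3_lambda": 0, "o3_sigma": -1, "center_type": 1},  # pseudoscalar
]
print(use_bias(keys, True, DEFAULT_SELECTION))  # [True, False, False]
```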
- class metatensor.learn.nn.Tanh(in_keys: Labels, out_properties: Labels | None = None)
Module similar to torch.nn.Tanh that works with metatensor.torch.TensorMap objects.
Applies a hyperbolic tangent transformation to each block of a TensorMap passed to its forward method, indexed by in_keys.
Refer to the torch.nn.Tanh documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels)
out_properties (Labels | None)
- class metatensor.learn.nn.InvariantTanh(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None)
Module similar to torch.nn.Tanh that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.
Applies a hyperbolic tangent transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.
Refer to the torch.nn.Tanh documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
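Applying tanh separately to each component of a covariant block would not commute with rotations, which is why only the invariant blocks are transformed. The same idea applies to InvariantReLU and InvariantSiLU. In the minimal sketch below, keys are (o3_lambda, o3_sigma) tuples, and the function invariant_tanh is hypothetical. Passing non-selected blocks through unchanged is an assumption, since the text only says that the transformation is applied to the invariant blocks.

```python
import math
from typing import Dict, List, Set, Tuple

Key = Tuple[int, int]  # (o3_lambda, o3_sigma)


def invariant_tanh(
    tensor: Dict[Key, List[float]], invariant_keys: Set[Key]
) -> Dict[Key, List[float]]:
    out = {}
    for key, values in tensor.items():
        if key in invariant_keys:
            # elementwise nonlinearity on an invariant block
            out[key] = [math.tanh(v) for v in values]
        else:
            # assumption: other blocks are returned unchanged
            out[key] = list(values)
    return out


tensor = {(0, 1): [0.0, 1.0], (1, 1): [0.5, -2.0]}
out = invariant_tanh(tensor, {(0, 1)})
print(out[(1, 1)])  # [0.5, -2.0]
```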
- class metatensor.learn.nn.ReLU(in_keys: Labels, out_properties: Labels | None = None, *, in_place: bool = False)
Module similar to torch.nn.ReLU that works with metatensor.torch.TensorMap objects.
Applies a rectified linear unit transformation to each block of a TensorMap passed to its forward method, indexed by in_keys.
Refer to the torch.nn.ReLU documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels)
out_properties (Labels | None)
in_place (bool)
- class metatensor.learn.nn.InvariantReLU(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None, *, in_place: bool = False)
Module similar to torch.nn.ReLU that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.
Applies a rectified linear unit transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.
Refer to the torch.nn.ReLU documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
in_place (bool)
- class metatensor.learn.nn.SiLU(in_keys: Labels, out_properties: Labels | None = None, *, in_place: bool = False)
Module similar to torch.nn.SiLU that works with metatensor.torch.TensorMap objects.
Applies a sigmoid linear unit transformation to each block of a TensorMap passed to its forward method, indexed by in_keys.
Refer to the torch.nn.SiLU documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels)
out_properties (Labels | None)
in_place (bool)
- class metatensor.learn.nn.InvariantSiLU(in_keys: Labels, out_properties: Labels | None = None, invariant_keys: Labels | None = None, *, in_place: bool = False)
Module similar to torch.nn.SiLU that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.
Applies a sigmoid linear unit transformation to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.
Refer to the torch.nn.SiLU documentation for a more detailed description of the parameters.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
out_properties (Labels | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
in_place (bool)
- class metatensor.learn.nn.LayerNorm(in_keys: Labels, in_features: List[int], out_properties: List[Labels] | None = None, *, eps: float = 1e-05, elementwise_affine: bool = True, bias: bool = True, mean: bool = True, device: device | None = None, dtype: dtype | None = None)
Module similar to torch.nn.LayerNorm that works with metatensor.torch.TensorMap objects.
Applies a layer normalization to each block of a TensorMap passed to its forward() method, indexed by in_keys.
The main difference from torch.nn.LayerNorm is that there is no normalized_shape parameter. Instead, the standard deviation and mean (if applicable) are calculated over all dimensions except the samples (first) dimension of each TensorBlock.
The extra parameter mean controls whether or not the mean over these dimensions is subtracted from the input tensor in the transformation.
Refer to the torch.nn.LayerNorm documentation for a more detailed description of the other parameters. Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
in_features (List[int]) – list of int, the number of features in the input tensor for each block indexed by the keys in in_keys. If passed as a single value, the same number of features is assumed for all blocks.
out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default (if None) the output properties are relabeled using Labels.range.
eps (float)
elementwise_affine (bool)
bias (bool)
mean (bool) – whether or not to subtract the mean over all dimensions except the samples (first) dimension of each block of the input passed to forward().
device (device | None)
dtype (dtype | None)
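The normalization described above pools every entry of a sample, across all components and properties together, to compute the statistics. The sketch below does this for one block stored as nested lists, using a hypothetical layer_norm_block function with no affine scale or shift (elementwise_affine is omitted). As in torch.nn.LayerNorm, it uses the biased variance and adds eps inside the square root. For mean=False, the sketch divides by the root mean square of the uncentered values, which is an assumption about the library's behavior.

```python
import math
from typing import List

Sample = List[List[float]]  # components x properties for one sample


def layer_norm_block(
    block: List[Sample], eps: float = 1e-5, subtract_mean: bool = True
) -> List[Sample]:
    out = []
    for sample in block:
        # pool every value of this sample: all components and properties
        flat = [x for row in sample for x in row]
        center = sum(flat) / len(flat) if subtract_mean else 0.0
        # centered: biased variance as in torch.nn.LayerNorm
        # uncentered: mean square (assumed behavior for mean=False)
        var = sum((x - center) ** 2 for x in flat) / len(flat)
        scale = math.sqrt(var + eps)
        out.append([[(x - center) / scale for x in row] for row in sample])
    return out


# two samples, each with 2 components and 2 properties
block = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 8.0]]]
normed = layer_norm_block(block)
```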
- class metatensor.learn.nn.InvariantLayerNorm(in_keys: Labels, in_features: List[int], out_properties: List[Labels] | None = None, invariant_keys: Labels | None = None, *, eps: float = 1e-05, elementwise_affine: bool = True, bias: bool = True, mean: bool = True, device: device | None = None, dtype: dtype | None = None)
Module similar to torch.nn.LayerNorm that works with metatensor.torch.TensorMap objects, applying the transformation only to the invariant blocks.
Applies a layer normalization to each invariant block of a TensorMap passed to its forward() method. These are indexed by the keys in in_keys that correspond to the selection passed in invariant_keys.
The main difference from torch.nn.LayerNorm is that there is no normalized_shape parameter. Instead, the standard deviation and mean (if applicable) are calculated over all dimensions except the samples (first) dimension of each TensorBlock.
The extra parameter mean controls whether or not the mean over these dimensions is subtracted from the input tensor in the transformation.
Refer to the torch.nn.LayerNorm documentation for a more detailed description of the other parameters. Each of these other parameters is passed as a single value of its expected type, which is then used for all blocks.
- Parameters:
in_keys (Labels) – Labels, the keys that are assumed to be in the input tensor map in the forward() method.
in_features (List[int]) – list of int, the number of features in the input tensor for each block indexed by the keys in in_keys. If passed as a single value, the same number of features is assumed for all blocks.
out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default (if None) the output properties are relabeled using Labels.range.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
eps (float)
elementwise_affine (bool)
bias (bool)
mean (bool) – whether or not to subtract the mean over all dimensions except the samples (first) dimension of each block of the input passed to forward().
device (device | None)
dtype (dtype | None)
- class metatensor.learn.nn.EquivariantTransformation(modules: List[Module], in_keys: Labels, in_features: int | List[int], out_features: int | List[int] | None = None, out_properties: List[Labels] | None = None, invariant_keys: Labels | None = None)
A custom torch.nn.Module that applies an arbitrary shape- and equivariance-preserving transformation to an input TensorMap.
For invariant blocks (specified with invariant_keys), the respective transformation contained in the modules parameter is applied as is. For covariant blocks, an invariant multiplier is created by applying the transformation to the norm of the block over the component dimensions.
- Parameters:
modules (List[Module]) – a list of torch.nn.Module containing the transformations to be applied to each block indexed by in_keys. Transformations for invariant and covariant blocks differ, as described above.
in_keys (Labels) – Labels, the keys that are assumed to be in the input TensorMap in the forward() method.
in_features (int | List[int]) – int or list of int, the number of features in the input tensor for each block indexed by the keys in in_keys. If passed as a single value, the same number of features is assumed for all blocks.
out_properties (List[Labels] | None) – list of Labels (optional), the properties labels of the output. By default the output properties are relabeled using Labels.range.
invariant_keys (Labels | None) – a Labels object that is used to select the invariant keys from in_keys. If not provided, the invariant keys are assumed to be those where the key dimensions ["o3_lambda", "o3_sigma"] are equal to [0, 1].
>>> import torch
>>> import numpy as np
>>> from metatensor import Labels, TensorBlock, TensorMap
>>> from metatensor.learn.nn import EquivariantTransformation
Define a dummy invariant TensorBlock
>>> block_1 = TensorBlock(
...     values=torch.randn(2, 1, 3),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[Labels(["o3_mu"], np.array([[0]]))],
...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )
Define a dummy covariant TensorBlock
>>> block_2 = TensorBlock(
...     values=torch.randn(2, 3, 3),
...     samples=Labels(
...         ["system", "atom"],
...         np.array(
...             [
...                 [0, 0],
...                 [0, 1],
...             ]
...         ),
...     ),
...     components=[Labels(["o3_mu"], np.array([[-1], [0], [1]]))],
...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
... )
Create a TensorMap containing the dummy TensorBlocks
>>> keys = Labels(names=["o3_lambda"], values=np.array([[0], [1]]))
>>> tensor = TensorMap(keys, [block_1, block_2])
Define the transformation to apply to the TensorMap
>>> modules = [torch.nn.Tanh(), torch.nn.Tanh()]
>>> in_features = [len(tensor.block(key).properties) for key in tensor.keys]
Define the EquivariantTransformation module
>>> transformation = EquivariantTransformation(
...     modules,
...     tensor.keys,
...     in_features,
...     out_properties=[tensor.block(key).properties for key in tensor.keys],
...     invariant_keys=Labels(
...         ["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1)
...     ),
... )
The output metadata are the same as the input
>>> transformation(tensor)
TensorMap with 2 blocks
keys: o3_lambda
          0
          1
>>> transformation(tensor)[0]
TensorBlock
    samples (2): ['system', 'atom']
    components (1): ['o3_mu']
    properties (3): ['properties']
    gradients: None
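For a covariant block, the transformation acts on the norm over the component dimension, and the result rescales the original components. The plain-Python sketch below assumes that the invariant multiplier is simply f(norm), applied to every component of one (sample, property) vector. The exact form of the multiplier in the library is not stated above. The functions covariant_transform and rotate are hypothetical. The point illustrated is that an orthogonal map changes the components but not their norm, so rescaling commutes with rotation, whereas applying tanh componentwise does not.

```python
import math
from typing import Callable, List

Vector = List[float]  # components (e.g. o3_mu = -1, 0, 1) for one sample/property


def covariant_transform(vector: Vector, fn: Callable[[float], float]) -> Vector:
    # the norm over the components is unchanged by rotations
    norm = math.sqrt(sum(c * c for c in vector))
    multiplier = fn(norm)  # assumption: the invariant multiplier is fn(norm)
    return [multiplier * c for c in vector]


def rotate(vector: Vector, angle: float) -> Vector:
    # orthogonal map mixing the first and last components
    a, b, c = vector
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    return [cos_a * a - sin_a * c, b, sin_a * a + cos_a * c]


v = [0.3, -1.2, 0.8]
lhs = covariant_transform(rotate(v, 0.7), math.tanh)
rhs = rotate(covariant_transform(v, math.tanh), 0.7)
print(all(abs(x - y) < 1e-12 for x, y in zip(lhs, rhs)))  # True

# by contrast, a componentwise tanh does not commute with the rotation
naive_lhs = [math.tanh(x) for x in rotate(v, 0.7)]
naive_rhs = rotate([math.tanh(x) for x in v], 0.7)
```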