Source code for metatensor.learn.nn.relu
from typing import List, Optional

import torch
from torch.nn import Module

from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from .module_map import ModuleMap
class ReLU(Module):
"""
Module similar to :py:class:`torch.nn.ReLU` that works with
:py:class:`metatensor.torch.TensorMap` objects.
Applies a rectified linear unit transformation transformation to each block of a
:py:class:`TensorMap` passed to its forward method, indexed by :param in_keys:.
Refer to the :py:class:`torch.nn.ReLU` documentation for a more detailed description
of the parameters.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
tensor map in the :py:meth:`forward` method.
:param out_properties: list of :py:class:`Labels` (optional), the properties labels
of the output. By default the output properties are relabeled using
Labels.range.
"""
    def __init__(
        self,
        in_keys: Labels,
        out_properties: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()
        modules: List[Module] = [
            torch.nn.ReLU(inplace=in_place) for _ in range(len(in_keys))
        ]
        self.module_map = ModuleMap(in_keys, modules, out_properties)
    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.

        :return: the transformed :py:class:`TensorMap`
        """
        # Gradients are currently not supported
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
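
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): applies
# ``ReLU`` to a minimal single-block TensorMap holding torch values. It
# assumes the plain ``metatensor`` backend (hence the direct ``TensorBlock``
# import); the helper name ``_example_relu`` and the toy shapes are made up
# for this example.
# ---------------------------------------------------------------------------
def _example_relu() -> TensorMap:
    import numpy as np

    from metatensor import TensorBlock  # assumes the non-torch backend

    keys = Labels(names=["o3_lambda"], values=np.array([[0]]))
    block = TensorBlock(
        values=torch.randn(4, 3),  # (samples, properties)
        samples=Labels.range("sample", 4),
        components=[],
        properties=Labels.range("property", 3),
    )
    tensor = TensorMap(keys=keys, blocks=[block])

    relu = ReLU(in_keys=keys)
    return relu(tensor)  # negative entries of the block are clamped to zero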
class InvariantReLU(Module):
    """
    Module similar to :py:class:`torch.nn.ReLU` that works with
    :py:class:`metatensor.torch.TensorMap` objects, applying the transformation only
    to the invariant blocks.

    Applies a rectified linear unit transformation to each invariant block of a
    :py:class:`TensorMap` passed to its :py:meth:`forward` method. These are indexed
    by the keys in :param in_keys: that match the selection passed in
    :param invariant_keys:.

    Refer to the :py:class:`torch.nn.ReLU` documentation for a more detailed
    description of the parameters.

    :param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
        tensor map in the :py:meth:`forward` method.
    :param out_properties: list of :py:class:`Labels` (optional), the properties labels
        of the output. By default the output properties are relabeled using
        Labels.range.
    :param invariant_keys: a :py:class:`Labels` object that is used to select the
        invariant keys from ``in_keys``. If not provided, the invariant keys are
        assumed to be those where the key dimensions ``["o3_lambda", "o3_sigma"]``
        are equal to ``[0, 1]``.
    """
    def __init__(
        self,
        in_keys: Labels,
        out_properties: Optional[Labels] = None,
        invariant_keys: Optional[Labels] = None,
        *,
        in_place: bool = False,
    ) -> None:
        super().__init__()

        # Set a default for the invariant keys: o3_lambda = 0 and o3_sigma = 1
        if invariant_keys is None:
            invariant_keys = Labels(
                names=["o3_lambda", "o3_sigma"],
                values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
            )
        invariant_key_idxs = in_keys.select(invariant_keys)

        modules: List[Module] = []
        for i in range(len(in_keys)):
            if i in invariant_key_idxs:  # invariant block: apply ReLU
                module = torch.nn.ReLU(inplace=in_place)
            else:  # covariant block: apply the identity operator
                module = torch.nn.Identity()
            modules.append(module)

        self.module_map: ModuleMap = ModuleMap(in_keys, modules, out_properties)
    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the transformation to the input tensor map ``tensor``.

        Note: gradients are currently not supported.

        :param tensor: :py:class:`TensorMap` with the input tensor to be transformed.

        :return: the transformed :py:class:`TensorMap`
        """
        # Gradients are currently not supported
        if len(tensor[0].gradients_list()) != 0:
            raise ValueError(
                "Gradients not supported. Please use metatensor.remove_gradients()"
                " before using this module"
            )
        return self.module_map(tensor)
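
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module):
# ``InvariantReLU`` on a two-block TensorMap, where only the invariant block
# (``o3_lambda=0, o3_sigma=1``) goes through ReLU while the covariant block
# passes through the identity. As above, the helper name and shapes are
# assumptions, using the plain ``metatensor`` backend with torch values.
# ---------------------------------------------------------------------------
def _example_invariant_relu() -> TensorMap:
    import numpy as np

    from metatensor import TensorBlock  # assumes the non-torch backend

    keys = Labels(
        names=["o3_lambda", "o3_sigma"],
        values=np.array([[0, 1], [1, 1]]),
    )
    invariant = TensorBlock(
        values=torch.randn(2, 5),  # (samples, properties)
        samples=Labels.range("sample", 2),
        components=[],
        properties=Labels.range("property", 5),
    )
    covariant = TensorBlock(
        values=torch.randn(2, 3, 5),  # (samples, o3_mu components, properties)
        samples=Labels.range("sample", 2),
        components=[
            Labels(names=["o3_mu"], values=np.arange(-1, 2).reshape(-1, 1))
        ],
        properties=Labels.range("property", 5),
    )
    tensor = TensorMap(keys=keys, blocks=[invariant, covariant])

    relu = InvariantReLU(in_keys=keys)  # default invariant_keys selection
    return relu(tensor)  # ReLU on the invariant block, identity elsewhere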