Source code for metatensor.learn.nn.module_map

from copy import deepcopy
from typing import List, Optional, Union

import torch

from metatensor.operations import _dispatch

from .._backend import Labels, LabelsEntry, TensorBlock, TensorMap
from ._module import Module


@torch.jit.interface
class ModuleMapInterface(torch.nn.Module):
    """
    This interface is required for TorchScript to index the
    :py:class:`torch.nn.ModuleList` with non-literals in ModuleMap. Any module that
    is used with ModuleMap must implement this interface to be TorchScript
    compilable.

    Note that the *type annotations and argument names must match exactly* so that
    the interface is correctly implemented.

    Reference
    ---------
    https://github.com/pytorch/pytorch/pull/45716
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

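# A minimal sketch, not part of the original source, of a custom module that
# satisfies ``ModuleMapInterface``: the argument must be named ``input`` and the
# annotations must be exactly ``torch.Tensor``. The name ``_ExampleScaledTanh``
# and its ``scale`` attribute are hypothetical, for illustration only.
class _ExampleScaledTanh(torch.nn.Module):
    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # same argument name and type annotations as ModuleMapInterface.forward
        return torch.tanh(self.scale * input)
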

class ModuleMap(Module):
    """
    A class that imitates :py:class:`torch.nn.ModuleDict`. In its forward function
    the module at position ``i`` given on construction by :param modules: is applied
    to the tensor block corresponding to the ``i``-th key in :param in_keys:.

    :param in_keys:
        A :py:class:`metatensor.Labels` object with the keys of the module map that
        are assumed to be in the input tensor map in the :py:meth:`forward` function.
    :param modules:
        A sequence of modules applied in the :py:meth:`forward` function on the input
        :py:class:`TensorMap`. Each module corresponds to one :py:class:`LabelsEntry`
        in :param in_keys: that determines on which :py:class:`TensorBlock` the
        module is applied. :param modules: and :param in_keys: must match in length.
    :param out_properties:
        A list of labels that is used to determine the properties labels of the
        output. Because a module could change the number of properties, the labels
        of the properties cannot be preserved. By default the output properties are
        relabeled using ``Labels.range`` with "_" as the key.

    >>> import torch
    >>> import numpy as np
    >>> from copy import deepcopy
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> from metatensor.learn.nn import ModuleMap

    Create simple blocks

    >>> block_1 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [1.0, 2.0, 4.0],
    ...             [3.0, 5.0, 6.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
    ... )
    >>> block_2 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [5.0, 8.0, 2.0],
    ...             [1.0, 2.0, 8.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["system", "atom"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
    ... )
    >>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
    >>> tensor = TensorMap(keys, [block_1, block_2])

    Create modules

    >>> linear = torch.nn.Linear(3, 1, bias=False)
    >>> with torch.no_grad():
    ...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
    >>> modules = [linear, deepcopy(linear)]
    >>> # you could also extend the modules with a nonlinear activation function

    Create a ModuleMap from this list of modules and apply it

    >>> module_map = ModuleMap(tensor.keys, modules)
    >>> out = module_map(tensor)
    >>> out
    TensorMap with 2 blocks
    keys: key
           0
           1
    >>> out[0].values
    tensor([[ 7.],
            [14.]], grad_fn=<MmBackward0>)
    >>> out[1].values
    tensor([[15.],
            [11.]], grad_fn=<MmBackward0>)

    Let's look at the metadata

    >>> tensor[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (): []
        properties (3): ['properties']
        gradients: None
    >>> out[0]
    TensorBlock
        samples (2): ['system', 'atom']
        components (): []
        properties (1): ['_']
        gradients: None

    The properties metadata got completely lost because we cannot know in general
    what a module outputs. You can pass a list of labels with the intended output
    properties on initialization of the ModuleMap (see the commented sketch below).
    """

    def __init__(
        self,
        in_keys: Labels,
        modules: List[Module],
        out_properties: Optional[List[Labels]] = None,
    ):
        super().__init__()

        self.module_list = torch.nn.ModuleList(modules)

        if len(in_keys) != len(modules):
            raise ValueError(
                "in_keys and modules must match in length, but found "
                f"{len(in_keys)} != {len(modules)} [len(in_keys) != len(modules)]"
            )

        self._in_keys: Labels = in_keys
        self._out_properties = out_properties
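
    # A sketch in comments, not part of the original source, of how
    # ``out_properties`` keeps meaningful metadata. Continuing the doctest above,
    # the "energy" property name is an arbitrary, hypothetical choice:
    #
    #     >>> out_properties = [
    #     ...     Labels(["energy"], np.array([[0]])),  # for key 0
    #     ...     Labels(["energy"], np.array([[0]])),  # for key 1
    #     ... ]
    #     >>> module_map = ModuleMap(tensor.keys, modules, out_properties)
    #     >>> module_map(tensor)[0].properties.names
    #     ['energy']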
    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        module: Module,
        out_properties: Optional[List[Labels]] = None,
    ):
        """
        A wrapper around one :py:class:`torch.nn.Module` or
        :py:class:`metatensor.learn.nn.Module` applying the same type of module on
        each tensor block.

        :param in_keys:
            A :py:class:`metatensor.Labels` object with the keys of the module map
            that are assumed to be in the input tensor map in the :py:meth:`forward`
            function.
        :param module:
            The module that is applied on each block.
        :param out_properties:
            A list of labels that is used to determine the properties labels of the
            output. Because a module could change the number of properties, the
            labels of the properties cannot be preserved. By default the output
            properties are relabeled using ``Labels.range`` with "_" as the key.

        >>> import torch
        >>> import numpy as np
        >>> from metatensor import Labels, TensorBlock, TensorMap
        >>> block_1 = TensorBlock(
        ...     values=torch.tensor(
        ...         [
        ...             [1.0, 2.0, 4.0],
        ...             [3.0, 5.0, 6.0],
        ...         ]
        ...     ),
        ...     samples=Labels(["system", "atom"], np.array([[0, 0], [0, 1]])),
        ...     components=[],
        ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
        ... )
        >>> block_2 = TensorBlock(
        ...     values=torch.tensor(
        ...         [
        ...             [5.0, 8.0, 2.0],
        ...             [1.0, 2.0, 8.0],
        ...         ]
        ...     ),
        ...     samples=Labels(["system", "atom"], np.array([[0, 0], [0, 1]])),
        ...     components=[],
        ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
        ... )
        >>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
        >>> tensor = TensorMap(keys, [block_1, block_2])
        >>> linear = torch.nn.Linear(3, 1, bias=False)
        >>> with torch.no_grad():
        ...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
        >>> # you could also extend the module with a nonlinear activation function
        >>> from metatensor.learn.nn import ModuleMap
        >>> module_map = ModuleMap.from_module(tensor.keys, linear)
        >>> out = module_map(tensor)
        >>> out[0].values
        tensor([[ 7.],
                [14.]], grad_fn=<MmBackward0>)
        >>> out[1].values
        tensor([[15.],
                [11.]], grad_fn=<MmBackward0>)
        """
        module = deepcopy(module)
        modules = []
        for _ in range(len(in_keys)):
            modules.append(deepcopy(module))

        return cls(in_keys, modules, out_properties)
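
    # Note, as a sketch and not part of the original source: because
    # ``from_module`` deep-copies the template module, the per-block copies start
    # with equal weights but are independent parameters that can be trained
    # separately. Continuing the doctest above:
    #
    #     >>> m0, m1 = module_map.module_list[0], module_map.module_list[1]
    #     >>> torch.equal(m0.weight, m1.weight), m0 is m1
    #     (True, False)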
    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the modules on each block in ``tensor``. ``tensor`` must have the same
        set of keys as the modules used to initialize this :py:class:`ModuleMap`.
        Gradient blocks are passed through the corresponding module as well (see the
        sketch at the end of this listing).

        :param tensor:
            input tensor map
        """
        out_blocks: List[TensorBlock] = []
        for key, block in tensor.items():
            out_block = self._forward_block(key, block)

            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )
                out_block.add_gradient(
                    parameter=parameter,
                    gradient=self._forward_block(key, gradient),
                )
            out_blocks.append(out_block)

        return TensorMap(tensor.keys, out_blocks)
    def _forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
        module_idx = self.in_keys.position(key)
        if module_idx is None:
            raise KeyError(f"key {key} not found in modules.")

        module: ModuleMapInterface = self.module_list[module_idx]
        out_values = module.forward(block.values)

        if self._out_properties is None:
            # we do not use `Labels.range` because of metatensor/issues/410
            properties = Labels(
                "_",
                _dispatch.int_array_like(
                    list(range(out_values.shape[-1])), block.samples.values
                ).reshape(-1, 1),
                # unique because `list(range(...))` produces unique entries
                assume_unique=True,
            )
        else:
            properties = self._out_properties[module_idx]

        return TensorBlock(
            values=out_values,
            properties=properties,
            components=block.components,
            samples=block.samples,
        )

    @property
    def in_keys(self) -> Labels:
        """
        Labels that define the keys this module map expects as input
        """
        return self._in_keys

    @property
    def out_properties(self) -> Union[None, List[Labels]]:
        """
        A list of labels that is used to determine the properties of the output
        """
        return self._out_properties
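

# A self-contained sketch, not part of the original source, showing that
# ``ModuleMap.forward`` maps gradient blocks through the same per-key module.
# The "positions" parameter name and the single-sample data are arbitrary
# illustration choices.
def _example_forward_with_gradients() -> TensorMap:
    import numpy as np

    block = TensorBlock(
        values=torch.tensor([[1.0, 2.0, 4.0]]),
        samples=Labels(["system"], np.array([[0]])),
        components=[],
        properties=Labels(["properties"], np.array([[0], [1], [2]])),
    )
    # gradient samples must start with a "sample" dimension pointing back into
    # the samples of the parent block
    block.add_gradient(
        parameter="positions",
        gradient=TensorBlock(
            values=torch.ones(1, 3, 3),
            samples=Labels(["sample", "atom"], np.array([[0, 0]])),
            components=[Labels(["xyz"], np.array([[0], [1], [2]]))],
            properties=block.properties,
        ),
    )
    tensor = TensorMap(Labels(["key"], np.array([[0]])), [block])

    module_map = ModuleMap(tensor.keys, [torch.nn.Linear(3, 1, bias=False)])
    out = module_map(tensor)
    # the output block carries a "positions" gradient with values of shape
    # (1, 3, 1): the linear layer was applied to the gradient values as well
    return out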