Source code for metatensor.learn.nn.module_map

from copy import deepcopy
from typing import List, Optional, Union

import torch
from torch.nn import Module, ModuleList

from metatensor.operations import _dispatch

from .._classes import Labels, LabelsEntry, TensorBlock, TensorMap


@torch.jit.interface
class ModuleMapInterface(torch.nn.Module):
    """
    This interface is required for TorchScript to index the
    :py:class:`torch.nn.ModuleList` underlying ModuleMap with non-literals. Any
    module that is used with ModuleMap must implement this interface to be
    TorchScript compilable.

    Note that the *type annotations and argument names must match exactly* so
    that an interface is correctly implemented.

    Reference
    ---------
    https://github.com/pytorch/pytorch/pull/45716
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass
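
# Illustrative sketch (not part of the original module): any module whose
# ``forward`` takes a single ``torch.Tensor`` argument named ``input`` and
# returns a ``torch.Tensor`` implements this interface, e.g.
# ``torch.nn.Linear`` or a small custom module such as
#
#     class Scale(torch.nn.Module):
#         def forward(self, input: torch.Tensor) -> torch.Tensor:
#             return 2.0 * input
#
# Renaming the argument (e.g. ``forward(self, x)``) would break TorchScript
# compilation of ModuleMap, since interface implementations must match
# argument names exactly.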


class ModuleMap(ModuleList):
    """
    A class that imitates :py:class:`torch.nn.ModuleDict`. In its forward
    function the module at position ``i`` given on construction by
    :param modules: is applied to the tensor block corresponding to the
    ``i``-th key in :param in_keys:.

    :param in_keys:
        A :py:class:`metatensor.Labels` object with the keys of the module map
        that are assumed to be in the input tensor map in the
        :py:meth:`forward` function.
    :param modules:
        A sequence of modules applied in the :py:meth:`forward` function on the
        input :py:class:`TensorMap`. Each module corresponds to one
        :py:class:`LabelsEntry` in :param in_keys: that determines on which
        :py:class:`TensorBlock` the module is applied. :param modules: and
        :param in_keys: must match in length.
    :param out_properties:
        A list of labels that is used to determine the properties labels of the
        output. Because a module could change the number of properties, the
        labels of the properties cannot be preserved. By default the output
        properties are relabeled using ``Labels.range`` with "_" as key.

    >>> import torch
    >>> import numpy as np
    >>> from copy import deepcopy
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> from metatensor.learn.nn import ModuleMap

    Create simple blocks

    >>> block_1 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [1.0, 2.0, 4.0],
    ...             [3.0, 5.0, 6.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["structure", "center"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
    ... )
    >>> block_2 = TensorBlock(
    ...     values=torch.tensor(
    ...         [
    ...             [5.0, 8.0, 2.0],
    ...             [1.0, 2.0, 8.0],
    ...         ]
    ...     ),
    ...     samples=Labels(
    ...         ["structure", "center"],
    ...         np.array(
    ...             [
    ...                 [0, 0],
    ...                 [0, 1],
    ...             ]
    ...         ),
    ...     ),
    ...     components=[],
    ...     properties=Labels(["properties"], np.array([[3], [4], [5]])),
    ... )
    >>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
    >>> tensor = TensorMap(keys, [block_1, block_2])

    Create modules

    >>> linear = torch.nn.Linear(3, 1, bias=False)
    >>> with torch.no_grad():
    ...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
    ...
    >>> modules = [linear, deepcopy(linear)]
    >>> # you could also extend the module by some nonlinear activation function

    Create a ModuleMap from this list of modules and apply it

    >>> module_map = ModuleMap(tensor.keys, modules)
    >>> out = module_map(tensor)
    >>> out
    TensorMap with 2 blocks
    keys: key
           0
           1
    >>> out[0].values
    tensor([[ 7.],
            [14.]], grad_fn=<MmBackward0>)
    >>> out[1].values
    tensor([[15.],
            [11.]], grad_fn=<MmBackward0>)

    Let's look at the metadata

    >>> tensor[0]
    TensorBlock
        samples (2): ['structure', 'center']
        components (): []
        properties (3): ['properties']
        gradients: None
    >>> out[0]
    TensorBlock
        samples (2): ['structure', 'center']
        components (): []
        properties (1): ['_']
        gradients: None

    The property labels got completely lost, because we cannot know in general
    what a module outputs. You can pass ``out_properties`` when initializing
    the ModuleMap to specify the intended output labels.
    """

    def __init__(
        self,
        in_keys: Labels,
        modules: List[Module],
        out_properties: Optional[List[Labels]] = None,
    ):
        super().__init__(modules)

        if len(in_keys) != len(modules):
            raise ValueError(
                "in_keys and modules must match in length, but found "
                f"{len(in_keys)} != {len(modules)} [len(in_keys) != len(modules)]"
            )

        self._in_keys: Labels = in_keys
        self._out_properties = out_properties
    @classmethod
    def from_module(
        cls,
        in_keys: Labels,
        module: Module,
        out_properties: Optional[List[Labels]] = None,
    ):
        """
        A wrapper around one :py:class:`torch.nn.Module` applying the same type
        of module on each tensor block.

        :param in_keys:
            A :py:class:`metatensor.Labels` object with the keys of the module
            map that are assumed to be in the input tensor map in the
            :py:meth:`forward` function.
        :param module:
            The module that is applied on each block.
        :param out_properties:
            A list of labels that is used to determine the properties labels of
            the output. Because a module could change the number of properties,
            the labels of the properties cannot be preserved. By default the
            output properties are relabeled using ``Labels.range`` with "_" as
            key.

        >>> import torch
        >>> import numpy as np
        >>> from metatensor import Labels, TensorBlock, TensorMap
        >>> block_1 = TensorBlock(
        ...     values=torch.tensor(
        ...         [
        ...             [1.0, 2.0, 4.0],
        ...             [3.0, 5.0, 6.0],
        ...         ]
        ...     ),
        ...     samples=Labels(
        ...         ["structure", "center"],
        ...         np.array(
        ...             [
        ...                 [0, 0],
        ...                 [0, 1],
        ...             ]
        ...         ),
        ...     ),
        ...     components=[],
        ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
        ... )
        >>> block_2 = TensorBlock(
        ...     values=torch.tensor(
        ...         [
        ...             [5.0, 8.0, 2.0],
        ...             [1.0, 2.0, 8.0],
        ...         ]
        ...     ),
        ...     samples=Labels(
        ...         ["structure", "center"],
        ...         np.array(
        ...             [
        ...                 [0, 0],
        ...                 [0, 1],
        ...             ]
        ...         ),
        ...     ),
        ...     components=[],
        ...     properties=Labels(["properties"], np.array([[0], [1], [2]])),
        ... )
        >>> keys = Labels(names=["key"], values=np.array([[0], [1]]))
        >>> tensor = TensorMap(keys, [block_1, block_2])
        >>> linear = torch.nn.Linear(3, 1, bias=False)
        >>> with torch.no_grad():
        ...     _ = linear.weight.copy_(torch.tensor([1.0, 1.0, 1.0]))
        ...
        >>> # you could also extend the module by some nonlinear activation function
        >>> from metatensor.learn.nn import ModuleMap
        >>> module_map = ModuleMap.from_module(tensor.keys, linear)
        >>> out = module_map(tensor)
        >>> out[0].values
        tensor([[ 7.],
                [14.]], grad_fn=<MmBackward0>)
        >>> out[1].values
        tensor([[15.],
                [11.]], grad_fn=<MmBackward0>)
        """
        module = deepcopy(module)
        modules = []
        for _ in range(len(in_keys)):
            modules.append(deepcopy(module))

        return cls(in_keys, modules, out_properties)
    def forward(self, tensor: TensorMap) -> TensorMap:
        """
        Apply the modules on each block in ``tensor``. ``tensor`` must have the
        same set of keys as the modules used to initialize this
        :py:class:`ModuleMap`.

        :param tensor:
            input tensor map
        """
        out_blocks: List[TensorBlock] = []
        for key, block in tensor.items():
            out_block = self.forward_block(key, block)

            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )
                out_block.add_gradient(
                    parameter=parameter,
                    gradient=self.forward_block(key, gradient),
                )
            out_blocks.append(out_block)

        return TensorMap(tensor.keys, out_blocks)
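    # Usage sketch (illustrative, assuming the ``tensor`` and ``module_map``
    # from the class docstring above): calling the map is a plain forward pass,
    # and every output block keeps the samples and components of its input:
    #
    #     out = module_map(tensor)
    #     assert out.keys == tensor.keys
    #     assert out[0].samples == tensor[0].samples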
    def forward_block(self, key: LabelsEntry, block: TensorBlock) -> TensorBlock:
        position = self._in_keys.position(key)
        if position is None:
            raise KeyError(f"key {key} not found in modules.")
        module_idx: int = position
        module: ModuleMapInterface = self[module_idx]

        out_values = module.forward(block.values)
        if self._out_properties is None:
            # we do not use range because of metatensor/issues/410
            properties = Labels(
                "_",
                _dispatch.int_array_like(
                    list(range(out_values.shape[-1])), block.samples.values
                ).reshape(-1, 1),
            )
        else:
            properties = self._out_properties[module_idx]

        return TensorBlock(
            values=out_values,
            properties=properties,
            components=block.components,
            samples=block.samples,
        )
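    # Illustrative note on the default relabeling above: for a module with a
    # single output feature and no ``out_properties``, the output block's
    # properties are equivalent to
    #
    #     Labels("_", np.array([[0]]))
    #
    # i.e. a running index named "_" (assuming ``np`` is numpy, as in the
    # docstring examples).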
    @torch.jit.export
    def get_module(self, key: LabelsEntry):
        """
        :param key: key of the module which should be returned

        :return module: the torch.nn.Module corresponding to :param key:
        """
        # type annotation in function signature had to be removed because of
        # TorchScript
        position = self._in_keys.position(key)
        if position is None:
            raise KeyError(f"key {key} not found in modules.")
        module_idx: int = position
        module: ModuleMapInterface = self[module_idx]
        return module
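    # Usage sketch (illustrative): keys are ``LabelsEntry`` objects, so with
    # the docstring's ``tensor`` the module attached to the first key is
    #
    #     module = module_map.get_module(tensor.keys[0])
    #     assert module is module_map[0]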
    @property
    def in_keys(self) -> Labels:
        """
        The labels defining the keys, with corresponding modules, that this
        module map was initialized with.
        """
        return self._in_keys

    @property
    def out_properties(self) -> Union[None, List[Labels]]:
        """
        A list of labels that is used to determine the properties labels of the
        output of the forward function.
        """
        return self._out_properties
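    # Sketch of ``out_properties`` in action (illustrative, reusing the
    # docstring's ``tensor`` and ``modules``; the "energy" name is made up):
    #
    #     out_properties = [Labels(["energy"], np.array([[0]]))] * len(tensor.keys)
    #     module_map = ModuleMap(tensor.keys, modules, out_properties)
    #     module_map(tensor)[0].properties.names  # ["energy"]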
    def repr_as_module_dict(self) -> str:
        """
        Returns a string that is easier to read than the standard ``__repr__``,
        showing the mapping from label entry key to module.
        """
        representation = "ModuleMap(\n"
        for i, key in enumerate(self._in_keys):
            representation += f"  ({key!r}): {self[i]!r}\n"
        representation += ")"
        return representation
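
# Usage sketch (illustrative): the readable representation lists one
# "(<key repr>): <module repr>" line per key, e.g. for the two-Linear example
# in the class docstring:
#
#     print(module_map.repr_as_module_dict())
#     # ModuleMap(
#     #   (<repr of key 0>): Linear(in_features=3, out_features=1, bias=False)
#     #   (<repr of key 1>): Linear(in_features=3, out_features=1, bias=False)
#     # )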