Source code for metatensor.learn.nn.linear

from typing import List, Optional, Union

import torch

from .._classes import Labels, TensorMap
from .module_map import ModuleMap


class Linear(ModuleMap):
    """
    Construct a module map with only linear modules.

    :param in_keys:
        The keys that are assumed to be in the input tensor map in the
        :py:meth:`forward` function.
    :param in_features:
        Specifies the dimensionality of the input after the module has been
        applied to the input tensor map. If a list of integers is given, it
        specifies the input dimension for each block, therefore it should have
        the same length as :param in_keys:. If only one integer is given, it is
        assumed that the same value applies to each block.
    :param out_features:
        Specifies the dimensionality of the output after the module has been
        applied to the input tensor map. If a list of integers is given, it
        specifies the output dimension for each block, therefore it should have
        the same length as :param in_keys:. If only one integer is given, it is
        assumed that the same value applies to each block.
    :param bias:
        Specifies if a bias term (offset) should be applied to the input tensor
        map in the forward function. If a list of bools is given, it specifies
        the bias term for each block, therefore it should have the same length
        as :param in_keys:. If only one bool value is given, it is assumed that
        the same value applies to each block.
    :param device:
        Specifies the torch device of the values. If None, the default torch
        device is taken.
    :param dtype:
        Specifies the torch dtype of the values. If None, the default torch
        dtype is taken.
    :param out_properties:
        A list of labels that is used to determine the properties labels of the
        output. Because a module could change the number of properties, the
        labels of the properties cannot be preserved. By default the output
        properties are relabeled using Labels.range.
    """

    def __init__(
        self,
        in_keys: Labels,
        in_features: Union[int, List[int]],
        out_features: Union[int, List[int]],
        bias: Union[bool, List[bool]] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        out_properties: Optional[List[Labels]] = None,
    ):
        if isinstance(in_features, int):
            in_features = [in_features] * len(in_keys)
        elif not isinstance(in_features, List):
            raise TypeError(
                "`in_features` must be integer or List of integers, but not"
                f" {type(in_features)}."
            )
        elif len(in_keys) != len(in_features):
            raise ValueError(
                "`in_features` must have same length as `in_keys`, but"
                f" len(in_features) != len(in_keys) [{len(in_features)} !="
                f" {len(in_keys)}]"
            )

        if isinstance(out_features, int):
            out_features = [out_features] * len(in_keys)
        elif not isinstance(out_features, List):
            raise TypeError(
                "`out_features` must be integer or List of integers, but not"
                f" {type(out_features)}."
            )
        elif len(in_keys) != len(out_features):
            raise ValueError(
                "`out_features` must have same length as `in_keys`, but"
                f" len(out_features) != len(in_keys) [{len(out_features)} !="
                f" {len(in_keys)}]"
            )

        if isinstance(bias, bool):
            bias = [bias] * len(in_keys)
        elif not isinstance(bias, List):
            raise TypeError(
                f"`bias` must be bool or List of bools, but not {type(bias)}."
            )
        elif len(in_keys) != len(bias):
            raise ValueError(
                "`bias` must have same length as `in_keys`, but len(bias) !="
                f" len(in_keys) [{len(bias)} != {len(in_keys)}]"
            )

        modules = []
        for i in range(len(in_keys)):
            module = torch.nn.Linear(
                in_features[i],
                out_features[i],
                bias[i],
                device,
                dtype,
            )
            modules.append(module)

        super().__init__(in_keys, modules, out_properties)
    @classmethod
    def from_weights(cls, weights: TensorMap, bias: Optional[TensorMap] = None):
        """
        Construct a linear module map from a tensor map for the weights and
        bias.

        :param weights:
            The weight tensor map from which we create the linear modules. The
            samples of the tensor map describe the input dimension and the
            properties describe the output dimension.
        :param bias:
            The bias tensor map from which we create the bias terms of the
            linear layers.
        """
        linear = cls(
            weights.keys,
            in_features=[len(weights[i].samples) for i in range(len(weights))],
            out_features=[len(weights[i].properties) for i in range(len(weights))],
            bias=False,
            device=weights.device,
            # due to bug metatensor/issues/464 we use the dtype of the values
            dtype=weights[0].values.dtype,
            out_properties=[weights[i].properties for i in range(len(weights))],
        )
        for i in range(len(weights)):
            key = weights.keys[i]
            weights_block = weights[i]
            linear[i].weight = torch.nn.Parameter(weights_block.values.T)
            if bias is not None:
                linear[i].bias = torch.nn.Parameter(bias.block(key).values)

        return linear
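

# Usage sketch (illustrative, not part of the module source): build a
# two-block TensorMap with torch values and apply a Linear module map to it.
# This assumes the pure-Python `metatensor` package; with the TorchScript
# version, import Labels/TensorBlock/TensorMap from `metatensor.torch`
# instead. The key names and block shapes below are made up.

import torch

from metatensor import Labels, TensorBlock, TensorMap
from metatensor.learn.nn import Linear

keys = Labels.range("key", 2)
blocks = [
    TensorBlock(
        values=torch.randn(4, 3),  # 4 samples, 3 properties per block
        samples=Labels.range("sample", 4),
        components=[],
        properties=Labels.range("property", 3),
    )
    for _ in range(2)
]
tensor = TensorMap(keys, blocks)

# one torch.nn.Linear per key: 3 input features -> 2 output features;
# a single int is broadcast to every block
linear = Linear(in_keys=keys, in_features=3, out_features=2)
output = linear(tensor)  # TensorMap with the same keys and 2 properties each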
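
# Continuing the sketch above, a hedged example of `from_weights`: the block
# values of `weights` store in_features along the samples and out_features
# along the properties, so the (3, 2) block below becomes a torch.nn.Linear
# mapping 3 inputs to 2 outputs, with `weight = values.T` and no bias. The
# "in"/"out" label names are illustrative.

weights = TensorMap(
    Labels.range("key", 1),
    [
        TensorBlock(
            values=torch.randn(3, 2),  # in_features=3, out_features=2
            samples=Labels.range("in", 3),
            components=[],
            properties=Labels.range("out", 2),
        )
    ],
)

linear = Linear.from_weights(weights)  # bias=None, so no offset is set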