Source code for metatensor.operations.requires_grad
from typing import List
from . import _dispatch
from ._classes import TensorBlock, TensorMap
[docs]
def requires_grad(tensor: TensorMap, requires_grad: bool = True) -> TensorMap:
    """
    Set ``requires_grad`` on all arrays (blocks and gradients of blocks) in this
    ``tensor`` to the provided value.
    This is mainly intended for torch arrays, and will warn if trying to set
    ``requires_grad=True`` with numpy arrays.
    :param tensor: :py:class:`TensorMap` to modify
    :param requires_grad: new value for ``requires_grad``
    """
    blocks: List[TensorBlock] = []
    for block in tensor.blocks():
        blocks.append(requires_grad_block(block, requires_grad=requires_grad))
    return TensorMap(tensor.keys, blocks)
[docs]
def requires_grad_block(block: TensorBlock, requires_grad: bool = True) -> TensorBlock:
    """
    Set ``requires_grad`` on the values and all gradients in this ``block`` to the
    provided value.
    This is mainly intended for torch arrays, and will warn if trying to set
    ``requires_grad=True`` with numpy arrays.
    :param block: :py:class:`TensorBlock` to modify
    :param requires_grad: new value for ``requires_grad``
    """
    new_block = TensorBlock(
        values=_dispatch.requires_grad(block.values, value=requires_grad),
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )
    for parameter, gradient in block.gradients():
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")
        new_block.add_gradient(
            parameter=parameter,
            gradient=TensorBlock(
                values=_dispatch.requires_grad(gradient.values, value=requires_grad),
                samples=gradient.samples,
                components=gradient.components,
                properties=gradient.properties,
            ),
        )
    return new_block