Source code for metatensor.operations.remove_gradients

from typing import List, Optional

from ._backend import TensorBlock, TensorMap, torch_jit_script


@torch_jit_script
def remove_gradients_block(
    block: TensorBlock,
    remove: Optional[List[str]] = None,
) -> TensorBlock:
    """Remove some or all of the gradients from a :py:class:`TensorBlock`.

    This function is related to, but different from,
    :py:func:`metatensor.detach_block`. This function removes the explicit
    forward mode gradients stored in the ``block``, while
    :py:func:`metatensor.detach_block` separates the values (as well as any
    potential gradients) from the underlying computational graph used by
    PyTorch to run backward differentiation.

    :param block: :py:class:`TensorBlock` with gradients to be removed
    :param remove: which gradients should be excluded from the new block. If
        this is set to :py:obj:`None` (this is the default), all the gradients
        will be removed.

    :returns: A new :py:class:`TensorBlock` without the gradients in
        ``remove``.
    """
    if remove is None:
        remove = block.gradients_list()

    new_block = TensorBlock(
        values=block.values,
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )

    for parameter, gradient in block.gradients():
        if parameter in remove:
            continue

        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")

        new_block.add_gradient(
            parameter=parameter,
            gradient=TensorBlock(
                values=gradient.values,
                samples=gradient.samples,
                components=gradient.components,
                properties=gradient.properties,
            ),
        )

    return new_block
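
As a quick illustration (not part of the module source), here is a minimal
sketch of stripping a single gradient from a block, assuming the NumPy-based
``metatensor`` package and a gradient parameter arbitrarily named
``"positions"``:

import numpy as np
from metatensor import Labels, TensorBlock, remove_gradients_block

# a block with 3 samples and 2 properties (hypothetical data)
block = TensorBlock(
    values=np.zeros((3, 2)),
    samples=Labels.range("sample", 3),
    components=[],
    properties=Labels.range("property", 2),
)

# attach an explicit forward-mode gradient, one gradient row per sample
block.add_gradient(
    parameter="positions",
    gradient=TensorBlock(
        values=np.ones((3, 2)),
        samples=Labels.range("sample", 3),
        components=[],
        properties=Labels.range("property", 2),
    ),
)

assert block.gradients_list() == ["positions"]

# the returned block shares the same values but carries no "positions" gradient
stripped = remove_gradients_block(block, remove=["positions"])
assert stripped.gradients_list() == []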
@torch_jit_script
def remove_gradients(
    tensor: TensorMap,
    remove: Optional[List[str]] = None,
) -> TensorMap:
    """Remove some or all of the gradients from a :py:class:`TensorMap`.

    This function is related to, but different from,
    :py:func:`metatensor.detach`. This function removes the explicit forward
    mode gradients stored in the blocks, while :py:func:`metatensor.detach`
    separates the values (as well as any potential gradients) from the
    underlying computational graph used by PyTorch to run backward
    differentiation.

    :param tensor: :py:class:`TensorMap` with gradients to be removed
    :param remove: which gradients should be excluded from the new tensor map.
        If this is set to :py:obj:`None` (this is the default), all the
        gradients will be removed.

    :returns: A new :py:class:`TensorMap` without the gradients in ``remove``.
    """
    if remove is None and len(tensor) != 0:
        remove = tensor.block(0).gradients_list()

    blocks: List[TensorBlock] = []
    for block in tensor.blocks():
        blocks.append(remove_gradients_block(block, remove))

    return TensorMap(tensor.keys, blocks)
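
And a similar sketch at the :py:class:`TensorMap` level (again with
hypothetical data and a hypothetical ``"positions"`` gradient parameter);
leaving ``remove`` at its default of :py:obj:`None` drops every gradient from
every block:

import numpy as np
import metatensor
from metatensor import Labels, TensorBlock, TensorMap

block = TensorBlock(
    values=np.zeros((2, 1)),
    samples=Labels.range("sample", 2),
    components=[],
    properties=Labels.range("property", 1),
)
block.add_gradient(
    parameter="positions",
    gradient=TensorBlock(
        values=np.zeros((2, 1)),
        samples=Labels.range("sample", 2),
        components=[],
        properties=Labels.range("property", 1),
    ),
)

# a single-block tensor map keyed by one arbitrary key
tensor = TensorMap(Labels.range("key", 1), [block])

# remove=None (the default) removes all gradients from all blocks
plain = metatensor.remove_gradients(tensor)
assert plain.block(0).gradients_list() == []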