@torch_jit_script
def remove_gradients_block(
    block: TensorBlock,
    remove: Optional[List[str]] = None,
) -> TensorBlock:
    """
    Remove some or all of the gradients from a :py:class:`TensorBlock`.

    This function is related to, but different from,
    :py:func:`metatensor.detach_block`. This function removes the explicit
    forward-mode gradients stored in the ``block``, while
    :py:func:`metatensor.detach_block` separates the values (as well as any
    potential gradients) from the underlying computational graph used by PyTorch
    to run backward differentiation.

    :param block: :py:class:`TensorBlock` with gradients to be removed

    :param remove: which gradients should be excluded from the new block. If
        this is set to :py:obj:`None` (this is the default), all the gradients
        will be removed.

    :returns: A new :py:class:`TensorBlock` without the gradients in ``remove``.
    """

    if remove is None:
        remove = block.gradients_list()

    # the new block reuses the values and metadata of the input block
    new_block = TensorBlock(
        values=block.values,
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )

    for parameter, gradient in block.gradients():
        if parameter in remove:
            continue

        # only gradients that are kept are checked for nested gradients
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")

        new_block.add_gradient(
            parameter=parameter,
            gradient=TensorBlock(
                values=gradient.values,
                samples=gradient.samples,
                components=gradient.components,
                properties=gradient.properties,
            ),
        )

    return new_block
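The control flow of `remove_gradients_block` can be illustrated without metatensor itself. The sketch below is a toy model under stated assumptions: `ToyBlock` and `toy_remove_gradients_block` are hypothetical stand-ins (a plain list of values plus a dict of named gradients), not the metatensor API, and the gradient names `"positions"` and `"cell"` are only illustrative. It mirrors the same rules: `None` removes everything, listed gradients are skipped, kept gradients must not have gradients of their own, and the new block shares the input's values instead of copying them.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a TensorBlock: values plus named gradients."""

    values: List[float]
    gradients: Dict[str, "ToyBlock"] = field(default_factory=dict)

    def gradients_list(self) -> List[str]:
        return list(self.gradients.keys())


def toy_remove_gradients_block(
    block: ToyBlock, remove: Optional[List[str]] = None
) -> ToyBlock:
    # Same default as remove_gradients_block: None means "remove every gradient"
    if remove is None:
        remove = block.gradients_list()

    # The new block shares the values object of the input, it does not copy it
    new_block = ToyBlock(values=block.values)

    for parameter, gradient in block.gradients.items():
        if parameter in remove:
            continue
        # Nested gradients are only rejected for gradients that are kept
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")
        # Re-wrap the kept gradient in a fresh block, again sharing its values
        new_block.gradients[parameter] = ToyBlock(values=gradient.values)

    return new_block


block = ToyBlock(
    values=[1.0, 2.0],
    gradients={"positions": ToyBlock([0.1, 0.2]), "cell": ToyBlock([0.3, 0.4])},
)

only_cell = toy_remove_gradients_block(block, remove=["positions"])
print(only_cell.gradients_list())  # ['cell']

no_gradients = toy_remove_gradients_block(block)
print(no_gradients.gradients_list())  # []
print(block.gradients_list())  # ['positions', 'cell'] (input is untouched)
```

Because the `remove` check comes before the nested-gradient check, a gradient that itself carries gradients does not trigger the `NotImplementedError` as long as it is one of the gradients being removed.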
@torch_jit_script
def remove_gradients(
    tensor: TensorMap,
    remove: Optional[List[str]] = None,
) -> TensorMap:
    """
    Remove some or all of the gradients from a :py:class:`TensorMap`.

    This function is related to, but different from, :py:func:`metatensor.detach`.
    This function removes the explicit forward-mode gradients stored in the
    blocks, while :py:func:`metatensor.detach` separates the values (as well as
    any potential gradients) from the underlying computational graph used by
    PyTorch to run backward differentiation.

    :param tensor: :py:class:`TensorMap` with gradients to be removed

    :param remove: which gradients should be excluded from the new tensor map.
        If this is set to :py:obj:`None` (this is the default), all the gradients
        will be removed.

    :returns: A new :py:class:`TensorMap` without the gradients in ``remove``.
    """

    # all blocks in a TensorMap carry the same gradients, so the list of
    # gradients in the first block is used as the default
    if remove is None and len(tensor) != 0:
        remove = tensor.block(0).gradients_list()

    blocks: List[TensorBlock] = []
    for block in tensor.blocks():
        blocks.append(remove_gradients_block(block, remove))

    return TensorMap(tensor.keys, blocks)
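At the `TensorMap` level, the operation is the block-level one applied to every block, with the keys passed through unchanged. The sketch below is again a toy model under stated assumptions: `ToyBlock`, `ToyTensorMap`, `toy_remove_gradients_block` and `toy_remove_gradients` are hypothetical stand-ins, not the metatensor API, and the block helper is simplified (it omits the nested-gradient check). One deliberate difference: for an empty map the toy version falls back to `[]`, whereas the function above simply leaves `remove` as `None`, which is harmless since there are no blocks to process.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a TensorBlock (values + named gradients)."""

    values: List[float]
    gradients: Dict[str, "ToyBlock"] = field(default_factory=dict)

    def gradients_list(self) -> List[str]:
        return list(self.gradients.keys())


@dataclass
class ToyTensorMap:
    """Hypothetical stand-in for a TensorMap: keys paired with blocks."""

    keys: List[int]
    blocks: List[ToyBlock]

    def __len__(self) -> int:
        return len(self.blocks)


def toy_remove_gradients_block(block: ToyBlock, remove: List[str]) -> ToyBlock:
    # Simplified block-level helper: keep every gradient not listed in `remove`
    kept = {
        name: ToyBlock(values=grad.values)
        for name, grad in block.gradients.items()
        if name not in remove
    }
    return ToyBlock(values=block.values, gradients=kept)


def toy_remove_gradients(
    tensor: ToyTensorMap, remove: Optional[List[str]] = None
) -> ToyTensorMap:
    # With remove=None, gradient names are read from the first block only,
    # which assumes every block carries the same set of gradients
    if remove is None:
        remove = tensor.blocks[0].gradients_list() if len(tensor) != 0 else []
    # Apply the block-level operation to each block, keeping the keys as-is
    blocks = [toy_remove_gradients_block(block, remove) for block in tensor.blocks]
    return ToyTensorMap(keys=tensor.keys, blocks=blocks)


def make_block(offset: float) -> ToyBlock:
    return ToyBlock(
        values=[offset],
        gradients={
            "positions": ToyBlock([offset + 0.1]),
            "strain": ToyBlock([offset + 0.2]),
        },
    )


tensor = ToyTensorMap(keys=[0, 1], blocks=[make_block(0.0), make_block(1.0)])

stripped = toy_remove_gradients(tensor, remove=["strain"])
print([b.gradients_list() for b in stripped.blocks])  # [['positions'], ['positions']]

print([b.gradients_list() for b in toy_remove_gradients(tensor).blocks])  # [[], []]

empty = toy_remove_gradients(ToyTensorMap(keys=[], blocks=[]))
print(len(empty))  # 0
```

Reading the default from the first block keeps the map-level function a thin loop over the block-level one; it only works because metatensor requires all blocks of a `TensorMap` to share the same gradients.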