Source code for metatensor.operations.remove_gradients
from typing import List, Optional
from ._classes import TensorBlock, TensorMap
[docs]
def remove_gradients(
tensor: TensorMap,
remove: Optional[List[str]] = None,
) -> TensorMap:
"""Remove some or all of the gradients from a :py:class:`TensorMap`.
:param tensor:
input :py:class:`TensorMap`, with gradients to remove
:param remove:
which gradients should be excluded from the new tensor map. If this is
set to :py:obj:`None` (this is the default), all the gradients will be
removed.
:returns:
A new tensormap without the gradients.
"""
if remove is None:
remove = tensor.block(0).gradients_list()
blocks: List[TensorBlock] = []
for block in tensor.blocks():
new_block = TensorBlock(
values=block.values,
samples=block.samples,
components=block.components,
properties=block.properties,
)
for parameter, gradient in block.gradients():
if parameter in remove:
continue
if len(gradient.gradients_list()) != 0:
raise NotImplementedError("gradients of gradients are not supported")
new_block.add_gradient(
parameter=parameter,
gradient=TensorBlock(
values=gradient.values,
samples=gradient.samples,
components=gradient.components,
properties=gradient.properties,
),
)
blocks.append(new_block)
return TensorMap(tensor.keys, blocks)