remove_gradients

metatensor.remove_gradients(tensor: TensorMap, remove: List[str] | None = None) → TensorMap

Remove some or all of the gradients from a TensorMap.

This function is related to, but different from, metatensor.detach(). It removes the explicit forward-mode gradients stored in the blocks, while metatensor.detach() separates the values (as well as any gradients) from the underlying computational graph used by PyTorch to run backward differentiation.

Parameters:
  • tensor (TensorMap) – TensorMap with gradients to be removed

  • remove (List[str] | None) – names of the gradients to exclude from the new tensor map. If set to None (the default), all gradients are removed.

Returns:

A new TensorMap without the gradients in remove.

Return type:

TensorMap

metatensor.remove_gradients_block(block: TensorBlock, remove: List[str] | None = None) → TensorBlock

Remove some or all of the gradients from a TensorBlock.

This function is related to, but different from, metatensor.detach_block(). It removes the explicit forward-mode gradients stored in the block, while metatensor.detach_block() separates the values (as well as any gradients) from the underlying computational graph used by PyTorch to run backward differentiation.

Parameters:
  • block (TensorBlock) – TensorBlock with gradients to be removed

  • remove (List[str] | None) – names of the gradients to exclude from the new block. If set to None (the default), all gradients are removed.

Returns:

A new TensorBlock without the gradients in remove.

Return type:

TensorBlock