# Source code for metatensor.operations.abs

"""
Module to find the absolute values of a :py:class:`TensorMap`, returning a new
:py:class:`TensorMap`.
"""

from typing import List

from . import _dispatch
from ._classes import TensorBlock, TensorMap


def abs(A: TensorMap) -> TensorMap:
    r"""
    Return a new :py:class:`TensorMap` with the same metadata as ``A`` and the
    absolute values of ``A``.

    .. math::

        A \rightarrow \vert A \vert

    If gradients are present in ``A``, they are transformed with the chain rule,
    using the sign of ``A`` (i.e. :math:`A / \vert A \vert` where :math:`A \neq 0`):

    .. math::

        \nabla A \rightarrow \nabla \vert A \vert = \mathrm{sign}(A) \, \nabla A

    :param A: the input :py:class:`TensorMap`.

    :return: a new :py:class:`TensorMap` with the same metadata as ``A`` and
        absolute values of ``A``.

    :raises NotImplementedError: if a gradient of ``A`` has gradients of its own
        (gradients of gradients are not supported).
    """
    blocks: List[TensorBlock] = []
    keys = A.keys
    # NOTE(review): index-based loop kept as-is; it may be required for
    # TorchScript compatibility of this module -- confirm before switching to
    # direct iteration over the blocks.
    for i in range(len(keys)):
        blocks.append(_abs_block(block=A.block(keys.entry(i))))
    return TensorMap(keys, blocks)
def _abs_block(block: TensorBlock) -> TensorBlock:
    """
    Return a :py:class:`TensorBlock` holding the absolute values of
    ``block.values``, with every gradient transformed accordingly.

    Each gradient row is multiplied by the sign of the block values at the
    sample that row refers to (its ``"sample"`` column), which is the chain rule
    for the absolute value.

    :param block: the input :py:class:`TensorBlock`.

    :return: a new block with the same samples, components and properties as
        ``block``.

    :raises NotImplementedError: if any gradient of ``block`` has gradients of
        its own.
    """
    values = _dispatch.abs(block.values)
    result_block = TensorBlock(
        values=values,
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )

    if len(block.gradients_list()) == 0:
        return result_block

    sign_values = _dispatch.sign(block.values)

    # Sizes of the non-sample axes of ``block.values``: one per component,
    # followed by the properties axis.
    _shape: List[int] = []
    for c in block.components:
        _shape += [len(c)]
    _shape += [len(block.properties)]

    for parameter, gradient in block.gradients():
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")

        # Number of components the gradient has in addition to the block's
        # own; the reshape below places them right after the samples axis.
        diff_components = len(gradient.components) - len(block.components)

        # sign_values has the same shape as block.values. Pick, for each
        # gradient row, the sign of the block row it refers to, then insert
        # singleton axes for the extra gradient components so the product
        # broadcasts against gradient.values.
        # The leading size is passed explicitly rather than as -1: when the
        # array is zero-sized (e.g. no properties), numpy/torch cannot infer a
        # -1 dimension and raise.
        gradient_sign = sign_values[
            _dispatch.to_index_array(gradient.samples.column("sample"))
        ].reshape([len(gradient.samples)] + [1] * diff_components + _shape)

        new_gradient = TensorBlock(
            values=gradient.values * gradient_sign,
            samples=gradient.samples,
            components=gradient.components,
            properties=gradient.properties,
        )
        result_block.add_gradient(parameter, new_gradient)

    return result_block