"""Module to find the absolute values of a :py:class:`TensorMap`, returning a new:py:class:`TensorMap`."""fromtypingimportListfrom.import_dispatchfrom._backendimportTensorBlock,TensorMap,torch_jit_scriptdef_abs_block(block:TensorBlock)->TensorBlock:""" Returns a :py:class:`TensorBlock` with the absolute values of the block values and associated gradient data. """values=_dispatch.abs(block.values)result_block=TensorBlock(values=values,samples=block.samples,components=block.components,properties=block.properties,)iflen(block.gradients_list())==0:returnresult_blocksign_values=_dispatch.sign(block.values)_shape:List[int]=[]forcinblock.components:_shape+=[len(c)]_shape+=[len(block.properties)]forparameter,gradientinblock.gradients():iflen(gradient.gradients_list())!=0:raiseNotImplementedError("gradients of gradients are not supported")diff_components=len(gradient.components)-len(block.components)# The sign_values have the same dimensions as that of the block.values.# Reshape the sign_values to allow multiplication with gradient.valuesnew_grad=gradient.values[:]*sign_values[_dispatch.to_index_array(gradient.samples.column("sample"))].reshape([-1]+[1]*diff_components+_shape)gradient=TensorBlock(new_grad,gradient.samples,gradient.components,gradient.properties)result_block.add_gradient(parameter,gradient)returnresult_block
@torch_jit_script
def abs(A: TensorMap) -> TensorMap:
    r"""
    Return a new :py:class:`TensorMap` with the same metadata as ``A`` and the
    absolute values of ``A``.

    .. math::

        A \rightarrow \vert A \vert

    If gradients are present in ``A``, they are propagated using the sign of
    ``A``:

    .. math::

        \nabla(A) \rightarrow \nabla(\vert A \vert) = \mathrm{sign}(A) \, \nabla A

    This equals :math:`(A / \vert A \vert) \, \nabla A` wherever ``A`` is
    non-zero.

    :param A: the input :py:class:`TensorMap`.

    :return: a new :py:class:`TensorMap` with the same metadata as ``A`` and
        absolute values of ``A``.

    :raises NotImplementedError: if a gradient of ``A`` has gradients itself.
    """
    blocks: List[TensorBlock] = []
    keys = A.keys
    # Index-based loop over the keys; each key's block is transformed
    # independently, and the original keys are reused for the output map.
    for i in range(len(keys)):
        blocks.append(_abs_block(block=A.block(keys.entry(i))))
    return TensorMap(keys, blocks)