from typing import List, Union

from . import _dispatch
from ._backend import (
    TensorBlock,
    TensorMap,
    is_metatensor_class,
    torch_jit_is_scripting,
    torch_jit_script,
)


def _pow_block_constant(block: TensorBlock, constant: float) -> TensorBlock:
    values = block.values[:] ** constant

    result_block = TensorBlock(
        values=values,
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )

    _shape: List[int] = []
    for c in block.components:
        _shape.append(len(c))
    _shape.append(len(block.properties))

    for parameter, gradient in block.gradients():
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")

        gradient_values = gradient.values
        # we find the difference between the number of components
        # of the gradients and the values and then use it to create
        # empty dimensions for broadcasting
        diff_components = len(gradient_values.shape) - len(block.values.shape)
        gradient_samples_to_values_samples = gradient.samples.column("sample")
        values_grad = (
            constant
            * gradient_values
            * block.values[
                _dispatch.to_index_array(gradient_samples_to_values_samples)
            ].reshape([-1] + [1] * diff_components + _shape)
            ** (constant - 1)
        )

        result_block.add_gradient(
            parameter=parameter,
            gradient=TensorBlock(
                values=values_grad,
                samples=gradient.samples,
                components=gradient.components,
                properties=gradient.properties,
            ),
        )

    return result_block
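
# A minimal numpy sketch, not part of this module, of the chain-rule
# broadcasting implemented above: gradient values carry extra "direction"
# components compared to the block values, so the values are reshaped with
# singleton axes before computing B * dA * A**(B - 1). All names and shapes
# below are illustrative only.
def _gradient_broadcasting_sketch() -> None:
    import numpy as np

    constant = 3.0
    values = np.random.rand(4, 5)  # (samples, properties)
    gradient_values = np.random.rand(6, 3, 5)  # (grad. samples, direction, properties)
    # map each gradient row back to the values row it differentiates,
    # playing the role of gradient.samples.column("sample") above
    sample_map = np.array([0, 0, 1, 2, 3, 3])
    # number of extra component axes in the gradient (here: 1)
    diff_components = gradient_values.ndim - values.ndim
    _shape = [5]  # component sizes of the values (none here) + properties
    values_grad = (
        constant
        * gradient_values
        * values[sample_map].reshape([-1] + [1] * diff_components + _shape)
        ** (constant - 1)
    )
    # the chain-rule result broadcasts to the gradient's own shape
    assert values_grad.shape == gradient_values.shape
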
@torch_jit_script
def pow(A: TensorMap, B: Union[float, int]) -> TensorMap:
    r"""Return a new :class:`TensorMap` with the same metadata as ``A`` and
    the values being the element-wise ``B``-power of ``A.values``.

    ``B`` can only be a scalar. If gradients are present in ``A``, the
    gradients of the resulting :class:`TensorMap` are given by the standard
    formula:

    .. math::
        \nabla(A ^ B) = B * \nabla A * A^{(B-1)}

    :param A: :py:class:`TensorMap` to be raised to the power of ``B``.
    :param B: The power to which ``A`` is raised. Can only be a scalar or
        something that can be converted to a scalar.

    :return: New :py:class:`TensorMap` with the same metadata as ``A``.
    """
    if not torch_jit_is_scripting():
        if not is_metatensor_class(A, TensorMap):
            raise TypeError(f"`A` must be a metatensor TensorMap, not {type(A)}")

        if not isinstance(B, (float, int)):
            raise TypeError(f"`B` must be a scalar value, not {type(B)}")

    B = float(B)
    blocks: List[TensorBlock] = []
    for block_A in A.blocks():
        blocks.append(_pow_block_constant(block=block_A, constant=B))

    return TensorMap(A.keys, blocks)
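
# Hypothetical usage sketch, not part of this module: it assumes the public
# ``metatensor`` package (with its ``Labels``, ``TensorBlock`` and
# ``TensorMap`` classes) is installed, and shows ``pow`` raising every block
# of a TensorMap to a scalar power.
def _pow_usage_sketch() -> None:
    import numpy as np
    import metatensor
    from metatensor import Labels, TensorBlock, TensorMap

    block = TensorBlock(
        values=np.array([[1.0, 2.0], [3.0, 4.0]]),
        samples=Labels.range("sample", 2),
        components=[],
        properties=Labels.range("property", 2),
    )
    tensor = TensorMap(keys=Labels.range("key", 1), blocks=[block])

    squared = metatensor.pow(tensor, 2.0)
    # element-wise square of the values: [[1.0, 4.0], [9.0, 16.0]]
    print(squared.block(0).values)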