Source code for metatensor.operations.subtract

from typing import Union

from ._backend import (
    TensorMap,
    is_metatensor_class,
    torch_jit_is_scripting,
    torch_jit_script,
)
from .add import add
from .multiply import multiply


[docs] @torch_jit_script def subtract(A: TensorMap, B: Union[float, int, TensorMap]) -> TensorMap: r"""Return a new :class:`TensorMap` with the values being the subtract of ``A`` and ``B``. If ``B`` is a :py:class:`TensorMap` it has to have the same metadata as ``A``. If gradients are present in ``A``: * ``B`` is a scalar: .. math:: \nabla(A - B) = \nabla A * ``B`` is a :py:class:`TensorMap` with the same metadata of ``A``: .. math:: \nabla(A - B) = \nabla A - \nabla B :param A: First :py:class:`TensorMap` for the subtraction. :param B: Second instance for the subtraction. Parameter can be a scalar or a :py:class:`TensorMap`. In the latter case ``B`` must have the same metadata of ``A``. :return: New :py:class:`TensorMap` with the same metadata as ``A``. """ if not torch_jit_is_scripting(): if not is_metatensor_class(A, TensorMap): raise TypeError(f"`A` must be a metatensor TensorMap, not {type(A)}") if torch_jit_is_scripting(): is_tensor_map = isinstance(B, TensorMap) else: is_tensor_map = is_metatensor_class(B, TensorMap) if isinstance(B, (float, int)): B = -float(B) elif is_tensor_map: B = multiply(B, -1) else: if torch_jit_is_scripting(): extra = "" else: extra = f", not {type(B)}" raise TypeError("`B` must be a metatensor TensorMap or a scalar value" + extra) return add(A=A, B=B)