# Source code for metatensor.operations._utils

from typing import List, Union

from ._backend import TensorBlock, TensorMap


class NotEqualError(Exception):
    """Exception used to indicate that two metatensor objects are different"""

    pass
def _check_same_keys(a: TensorMap, b: TensorMap, fname: str) -> bool:
    """
    Returns true if the keys of 2 TensorMaps are the same, without specification of
    the order, and false otherwise.

    See :py:func:`_check_same_keys_impl` for the details of the comparison.
    """
    return not bool(_check_same_keys_impl(a, b, fname))


def _check_same_keys_raise(a: TensorMap, b: TensorMap, fname: str) -> None:
    """
    If the keys of 2 TensorMaps are not the same, raises a NotEqualError, otherwise
    returns None.

    See :py:func:`_check_same_keys_impl` for the details of the comparison.

    :raises: :py:class:`metatensor.NotEqualError` if the keys are different
    """
    message = _check_same_keys_impl(a, b, fname)
    if message != "":
        raise NotEqualError(message)


def _check_same_keys_impl(a: TensorMap, b: TensorMap, fname: str) -> str:
    """
    Checks if the keys of 2 TensorMaps are the same, without specification of the
    order. Returns an empty str if they are the same, otherwise returns a str
    message describing the first difference found.

    This function verifies that the key names and the number of keys are the same,
    and that every key of ``b`` is also a key of ``a``, in any order.

    :param a: first :py:class:`TensorMap` for check
    :param b: second :py:class:`TensorMap` for check
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    """
    keys_a = a.keys
    keys_b = b.keys

    if keys_a.names != keys_b.names:
        return (
            f"inputs to '{fname}' should have the same keys names, "
            f"got '{keys_a.names}' and '{keys_b.names}'"
        )

    if len(keys_a) != len(keys_b):
        return (
            f"inputs to '{fname}' should have the same number of blocks, "
            f"got {len(keys_a)} and {len(keys_b)}"
        )

    # With equal lengths, "every key of b is in a" means the same set of keys,
    # presuming key entries are unique -- confirm against the Labels contract.
    # NOTE(review): index-based list comprehension kept as-is; this module may need
    # to stay TorchScript-compatible (no generator expressions there) -- confirm
    # before refactoring.
    if not all([keys_b[i] in keys_a for i in range(len(keys_b))]):
        return f"inputs to '{fname}' should have the same keys"

    return ""


def _metadata_to_check(check: Union[List[str], str]) -> List[str]:
    """
    Normalize the ``check`` argument shared by the block and gradient checks.

    :param check: either the string ``'all'`` (meaning samples, components and
        properties) or a list of metadata names
    :returns: the list of metadata names to check. Names are not validated here;
        each caller validates them with its own error message.
    :raises: :py:class:`ValueError` if ``check`` is a string other than ``'all'``
    """
    if isinstance(check, str):
        if check == "all":
            metadata_to_check = ["samples", "components", "properties"]
        else:
            raise ValueError("`check` must be a list of strings or 'all'")
    else:
        metadata_to_check = check
    return metadata_to_check


def _check_blocks(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> bool:
    """
    Checks if the metadata of 2 TensorBlocks are the same. Returns true if they are,
    false otherwise.

    See :py:func:`_check_blocks_impl` for the details of the comparison and the
    accepted values of ``check``.

    :raises: :py:class:`ValueError` if ``check`` is invalid
    """
    return not bool(_check_blocks_impl(a, b, fname, check))


def _check_blocks_raise(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> None:
    """
    Checks if the metadata of two TensorBlocks are the same. If not, raises a
    NotEqualError. If an invalid ``check`` is given, this function raises a
    ``ValueError``. Otherwise it returns None.

    The message associated with the exception will contain more information on
    where the two :py:class:`TensorBlock` differ. See :py:func:`_check_blocks_impl`
    for more information on when two :py:class:`TensorBlock` are considered as
    equal.

    :param a: first :py:class:`TensorBlock` for check
    :param b: second :py:class:`TensorBlock` for check
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    :param check: Which parts of the metadata to check. This can be a list
        containing any of ``'samples'``, ``'components'``, and ``'properties'``; or
        the string ``'all'`` to check everything. Defaults to ``'all'``.

    :raises: :py:class:`metatensor.NotEqualError` if the metadata of the blocks are
        different
    :raises: :py:class:`ValueError` if an invalid prop name in :param check: is
        given. See :param check: description for valid prop names
    """
    message = _check_blocks_impl(a, b, fname, check)
    if message != "":
        raise NotEqualError(message)


def _check_blocks_impl(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> str:
    """
    Check if metadata between two TensorBlocks is consistent for an operation.

    This function verifies that the metadata selected by ``check`` is the same,
    in terms of length, dimension names, and order of the values.

    If they are not the same, an error message as a str is returned. Otherwise, an
    empty str is returned.

    :param a: first :py:class:`TensorBlock` for check
    :param b: second :py:class:`TensorBlock` for check
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    :param check: Which parts of the metadata to check. This can be a list
        containing any of ``'samples'``, ``'components'``, and ``'properties'``; or
        the string ``'all'`` to check everything. Defaults to ``'all'``.

    :raises: :py:class:`ValueError` if ``check`` is a string other than ``'all'`` or
        contains an invalid name. All names are validated before any comparison,
        so an invalid name is never masked by an earlier metadata mismatch.
    """
    metadata_to_check = _metadata_to_check(check)

    # validate every requested name up front, so the documented ValueError is
    # raised regardless of whether the blocks happen to differ
    for metadata in metadata_to_check:
        if metadata not in ["samples", "components", "properties"]:
            raise ValueError(
                f"'{metadata}' does not refer to metadata to check, "
                "choose from 'samples', 'properties' and 'components'"
            )

    for metadata in metadata_to_check:
        if metadata == "samples":
            if not a.samples == b.samples:
                return (
                    f"inputs to '{fname}' should have the same samples, "
                    "but they are not the same or not in the same order"
                )

        elif metadata == "properties":
            if not a.properties == b.properties:
                return (
                    f"inputs to '{fname}' should have the same properties, "
                    "but they are not the same or not in the same order"
                )

        elif metadata == "components":
            if len(a.components) != len(b.components):
                return f"inputs to '{fname}' have a different number of components"

            for c1, c2 in zip(a.components, b.components):
                if not c1 == c2:
                    return (
                        f"inputs to '{fname}' should have the same components, "
                        "but they are not the same or not in the same order"
                    )

    return ""


def _check_same_gradients(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> bool:
    """
    Check if metadata between the gradients of 2 TensorBlocks is consistent for an
    operation. If they are the same, true is returned, otherwise false.

    This function verifies that the two blocks have the same gradient parameters,
    then that the gradient metadata selected by ``check`` is the same, in terms of
    length, dimension names, and order of the values.

    :param a: first :py:class:`TensorBlock` for check
    :param b: second :py:class:`TensorBlock` for check
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    :param check: Which parts of the metadata to check. This can be a list
        containing any of ``'samples'``, ``'components'``, and ``'properties'``; or
        the string ``'all'`` to check everything. Defaults to ``'all'``. If you only
        want to check if the two blocks have the same gradients, pass an empty list
        ``check=[]``.

    :raises: :py:class:`ValueError` if ``check`` is invalid
    """
    return not bool(_check_same_gradients_impl(a, b, fname, check))


def _check_same_gradients_raise(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> None:
    """
    Check if two TensorBlocks gradients have identical metadata.

    The message associated with the exception will contain more information on
    where the gradients of the two :py:class:`TensorBlock` differ. See
    :py:func:`_check_same_gradients_impl` for more information on when gradients of
    :py:class:`TensorBlock` are considered as equal.

    :param a: first :py:class:`TensorBlock` for check
    :param b: second :py:class:`TensorBlock` for check
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    :param check: Which parts of the metadata to check. This can be a list
        containing any of ``'samples'``, ``'components'``, and ``'properties'``; or
        the string ``'all'`` to check everything. Defaults to ``'all'``. If you only
        want to check if the two blocks have the same gradients, pass an empty list
        ``check=[]``.

    :raises: :py:class:`metatensor.NotEqualError` if the gradients of the blocks
        are different
    :raises: :py:class:`ValueError` if an invalid prop name in :param check: is
        given. See :param check: description for valid prop names
    """
    message = _check_same_gradients_impl(a, b, fname, check)
    if message != "":
        raise NotEqualError(message)


def _check_same_gradients_impl(
    a: TensorBlock,
    b: TensorBlock,
    fname: str,
    check: Union[List[str], str] = "all",
) -> str:
    """
    Check if metadata between the gradients of two TensorBlocks is consistent for an
    operation.

    This function verifies that the 2 TensorBlocks have the same gradient
    parameters, then checks that the metadata selected by ``check`` is the same, in
    terms of length, dimension names, and order of the values.

    If they are not the same, an error message as a str is returned. Otherwise, an
    empty str is returned. If the 2 blocks have no gradients (and ``check`` is
    valid), an empty string is returned.

    :param a: first :py:class:`TensorBlock` whose gradients are to be checked
    :param b: second :py:class:`TensorBlock` whose gradients are to be checked
    :param fname: name of the function where the check is performed. The input will
        be used to generate a meaningful error message.
    :param check: Which parts of the metadata to check. This can be a list
        containing any of ``'samples'``, ``'components'``, and ``'properties'``; or
        the string ``'all'`` to check everything. Defaults to ``'all'``. If you only
        want to check if the two blocks have the same gradients, pass an empty list
        ``check=[]``.

    :raises: :py:class:`ValueError` if ``check`` is a string other than ``'all'`` or
        contains an invalid name. Names are validated before anything else, so the
        error is raised even when the blocks have no gradients.
    """
    metadata_to_check = _metadata_to_check(check)

    # validate every requested name up front, so the documented ValueError does
    # not depend on the blocks having (matching) gradients
    for metadata in metadata_to_check:
        if metadata not in ["samples", "components", "properties"]:
            raise ValueError(
                f"{metadata} is not a valid property to check, "
                "choose from 'samples', 'properties' and 'components'"
            )

    err_msg = f"inputs to '{fname}' should have the same gradients: "

    gradients_list_a = a.gradients_list()
    gradients_list_b = b.gradients_list()
    if len(gradients_list_a) != len(gradients_list_b) or (
        not all([parameter in gradients_list_b for parameter in gradients_list_a])
    ):
        return f"inputs to '{fname}' should have the same gradient parameters"

    for parameter, grad_a in a.gradients():
        # the parameter check above guarantees b has this gradient too
        grad_b = b.gradient(parameter)

        for metadata in metadata_to_check:
            err_msg_1 = (
                f"gradient '{parameter}' {metadata} are not the same or not in the "
                "same order"
            )

            if metadata == "samples":
                if not grad_a.samples == grad_b.samples:
                    return err_msg + err_msg_1

            elif metadata == "properties":
                if not grad_a.properties == grad_b.properties:
                    return err_msg + err_msg_1

            elif metadata == "components":
                if len(grad_a.components) != len(grad_b.components):
                    extra = (
                        f"gradient '{parameter}' have different number of components"
                    )
                    return err_msg + extra

                for c1, c2 in zip(grad_a.components, grad_b.components):
                    if not c1 == c2:
                        return err_msg + err_msg_1

    return ""


def _check_gradient_presence_raise(
    block: TensorBlock,
    parameters: List[str],
    fname: str,
) -> None:
    """
    For a single TensorBlock checks if each of the passed ``parameters`` are present
    as parameters of its gradients. If all of them are present, None is returned.
    Otherwise a ValueError is raised.

    :param block: :py:class:`TensorBlock` whose gradients are inspected
    :param parameters: gradient parameter names that must be present
    :param fname: name of the function where the check is performed, used in the
        error message

    :raises: :py:class:`ValueError` for the first requested parameter that is not a
        gradient of ``block``
    """
    # fetch the available gradient parameters once instead of once per parameter
    available_parameters = block.gradients_list()
    for parameter in parameters:
        if parameter not in available_parameters:
            raise ValueError(
                f"requested gradient '{parameter}' in '{fname}' is not defined "
                "in this tensor"
            )