# Source code for metatensor.operations.unique_metadata

"""
Module for finding unique metadata for TensorMaps and TensorBlocks
"""

from typing import List, Optional, Tuple, Union

from . import _dispatch
from ._classes import (
    Labels,
    TensorBlock,
    TensorMap,
    check_isinstance,
    torch_jit_is_scripting,
)


def unique_metadata(
    tensor: TensorMap,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Collect the unique metadata found across every block of a
    :py:class:`TensorMap` and return it as a :py:class:`Labels` object.

    Uniqueness is computed along the requested ``axis`` (``"samples"`` or
    ``"properties"``), restricted to the metadata dimensions listed in
    ``names``.

    When ``gradient`` names a gradient parameter (for instance ``"cell"`` or
    ``"positions"``), only the corresponding gradient blocks are inspected.
    Gradient blocks always share the properties metadata of their parent
    :py:class:`TensorBlock`.

    If the (gradient) blocks of ``tensor`` contain no entries for the given
    ``axis`` and ``names``, the result is an empty :py:class:`Labels`: it has
    length zero, but its names are still the ones passed in ``names``.

    Finding the unique ``"structure"`` indices in the ``"samples"`` metadata of
    a :py:class:`TensorMap`:

    .. code-block:: python

        import metatensor

        unique_structures = metatensor.unique_metadata(
            tensor,
            axis="samples",
            names=["structure"],
        )

    Finding the unique ``"atom"`` indices in the ``"samples"`` metadata of the
    ``"positions"`` gradient blocks of a :py:class:`TensorMap`:

    .. code-block:: python

        unique_grad_atoms = metatensor.unique_metadata(
            tensor,
            axis="samples",
            names=["atom"],
            gradient="positions",
        )

    The resulting labels are convenient for splitting a :py:class:`TensorMap`
    into smaller ones. Suppose ``unique_structures`` from above is:

    .. code-block:: python

        Labels(
            [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)],
            dtype=[("structure", "<i4")],
        )

    The following then produces two :py:class:`TensorMap` objects, the first
    holding structure indices 0-3 and the second structure indices 4-9:

    .. code-block:: python

        import metatensor

        [tensor_1, tensor_2] = metatensor.split(
            tensor,
            axis="samples",
            grouped_labels=[unique_structures[:4], unique_structures[4:]],
        )

    :param tensor: the :py:class:`TensorMap` whose unique metadata is wanted.
    :param axis: a ``str``, either ``"samples"`` or ``"properties"``, naming
        the axis along which unique values are searched for.
    :param names: a ``str``, ``list`` of ``str`` or ``tuple`` of ``str`` giving
        the metadata dimension(s) along ``axis`` to take unique values of.
    :param gradient: a ``str`` naming the gradient parameter whose gradient
        blocks should be inspected. With :py:obj:`None` (the default) the
        regular :py:class:`TensorBlock` objects are used.

    :raises TypeError: if ``tensor`` is not a :py:class:`TensorMap` (checked
        only outside TorchScript), or if ``axis``, ``names`` or ``gradient``
        have the wrong type.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.

    :return: a sorted :py:class:`Labels` object with the unique metadata of the
        blocks of ``tensor`` (or of their gradients for the given parameter).
        Each entry has len(``names``) values.
    """
    # Input type check; skipped when ``torch_jit_is_scripting()`` is true
    if not torch_jit_is_scripting():
        if not check_isinstance(tensor, TensorMap):
            raise TypeError(
                f"`tensor` must be a metatensor TensorMap, not {type(tensor)}"
            )

    # Work with a list of names internally, whatever form the caller used
    if isinstance(names, str):
        name_list = [names]
    elif isinstance(names, tuple):
        name_list = list(names)
    else:
        name_list = names

    _check_args(axis, name_list, gradient)

    # Pick either the blocks themselves or their gradient blocks
    if gradient is None:
        selected = tensor.blocks()
    else:
        selected = [block.gradient(gradient) for block in tensor.blocks()]

    return _unique_from_blocks(selected, axis, name_list)
def unique_metadata_block(
    block: TensorBlock,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Returns a :py:class:`Labels` object containing the unique metadata in the
    input :py:class:`TensorBlock` ``block``, for the specified ``axis`` (either
    ``"samples"`` or ``"properties"``) and metadata ``names``.

    Passing ``gradient`` as a ``str`` corresponding to a gradient parameter (for
    instance ``"cell"`` or ``"positions"``) returns the unique indices only for
    the gradient block associated with ``block``. Note that gradient blocks by
    definition have the same properties metadata as their parent
    :py:class:`TensorBlock`.

    An empty :py:class:`Labels` object is returned if there are no indices in
    ``block`` (or in its gradient block, when ``gradient`` is given)
    corresponding to the specified ``axis`` and ``names``. This will have
    length zero but the names will be the same as passed in ``names``.

    For example, to find the unique ``"structure"`` indices in the
    ``"samples"`` metadata present in a given :py:class:`TensorBlock`:

    .. code-block:: python

        import metatensor

        unique_samples = metatensor.unique_metadata_block(
            block,
            axis="samples",
            names=["structure"],
        )

    To find the unique ``"atom"`` indices along the ``"samples"`` axis present
    in the ``"positions"`` gradient block of a given :py:class:`TensorBlock`:

    .. code-block:: python

        unique_grad_samples = metatensor.unique_metadata_block(
            block,
            axis="samples",
            names=["atom"],
            gradient="positions",
        )

    :param block: the :py:class:`TensorBlock` to find unique indices for.
    :param axis: a ``str``, either ``"samples"`` or ``"properties"``,
        corresponding to the ``axis`` along which the named unique metadata
        should be found.
    :param names: a ``str``, ``list`` of ``str``, or ``tuple`` of ``str``
        corresponding to the name(s) of the metadata along the specified
        ``axis`` for which the unique indices should be found.
    :param gradient: a ``str`` corresponding to the gradient parameter name of
        the gradient block to find the unique metadata for. If :py:obj:`None`
        (default), the unique metadata of ``block`` itself will be calculated.

    :raises TypeError: if ``block`` is not a :py:class:`TensorBlock` (checked
        only outside TorchScript), or if ``axis``, ``names`` or ``gradient``
        have the wrong type.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.

    :return: a sorted :py:class:`Labels` object containing the unique metadata
        for the input ``block`` or its gradient for the specified parameter.
        Each element in the returned :py:class:`Labels` object has
        len(``names``) entries.
    """
    # Parse input args; the type check is skipped when
    # ``torch_jit_is_scripting()`` is true
    if not torch_jit_is_scripting():
        if not check_isinstance(block, TensorBlock):
            raise TypeError(
                f"`block` must be a metatensor TensorBlock, not {type(block)}"
            )

    # Accept a single name or a tuple of names, but work with a list internally
    names = (
        [names]
        if isinstance(names, str)
        else (list(names) if isinstance(names, tuple) else names)
    )
    _check_args(axis, names, gradient)

    # The shared helper works on a list, so wrap the single (gradient) block
    if gradient is None:
        blocks = [block]
    else:
        blocks = [block.gradient(gradient)]

    return _unique_from_blocks(blocks, axis, names)
def _unique_from_blocks(
    blocks: List[TensorBlock],
    axis: str,
    names: List[str],
) -> Labels:
    """
    Finds the unique metadata of a list of blocks along the given ``axis`` and
    for the specified ``names``.

    :param blocks: the blocks (or gradient blocks) to inspect.
    :param axis: either ``"samples"`` or ``"properties"``; expected to have
        been validated beforehand by :py:func:`_check_args`.
    :param names: the metadata names to extract along ``axis``.

    :return: a :py:class:`Labels` with the given ``names`` holding every
        unique combination of values found across ``blocks``. If ``blocks`` is
        empty (e.g. a :py:class:`TensorMap` without any block), an empty
        :py:class:`Labels` with the same ``names`` is returned.
    """
    if len(blocks) == 0:
        # ``concatenate`` cannot handle an empty list of arrays; the public
        # functions document an empty Labels with the requested names here.
        return Labels.empty(names)

    all_values = []
    for block in blocks:
        if axis == "samples":
            all_values.append(block.samples.view(names).values)
        else:
            # ``axis`` was already validated by ``_check_args``
            assert axis == "properties"
            all_values.append(block.properties.view(names).values)

    unique_values = _dispatch.unique(_dispatch.concatenate(all_values, axis=0), axis=0)
    return Labels(names=names, values=unique_values)


def _check_args(
    axis: str,
    names: List[str],
    gradient: Optional[str] = None,
):
    """
    Checks the input arguments shared by :py:func:`unique_metadata` and
    :py:func:`unique_metadata_block`.

    Type checks are only performed when ``torch_jit_is_scripting()`` is false;
    the value check on ``axis`` always runs.

    :param axis: the axis name to validate.
    :param names: the (already list-normalized) metadata names to validate.
    :param gradient: the optional gradient parameter name to validate.

    :raises TypeError: if ``axis`` is not a ``str``, ``names`` is not a
        ``list`` of ``str``, or ``gradient`` is neither :py:obj:`None` nor a
        ``str``.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.
    """
    if not torch_jit_is_scripting():
        if not isinstance(axis, str):
            raise TypeError(f"`axis` must be a string, not {type(axis)}")

        if not isinstance(names, list):
            raise TypeError(f"`names` must be a list of strings, not {type(names)}")

        for name in names:
            if not isinstance(name, str):
                raise TypeError(f"`names` elements must be strings, not {type(name)}")

        if gradient is not None and not isinstance(gradient, str):
            raise TypeError(f"`gradient` must be a string, not {type(gradient)}")

    if axis not in ["samples", "properties"]:
        raise ValueError(
            f"`axis` must be either 'samples' or 'properties', not '{axis}'"
        )