Source code for metatensor.operations.unique_metadata
"""
Module for finding unique metadata for TensorMaps and TensorBlocks
"""
from typing import List, Optional, Tuple, Union
from . import _dispatch
from ._backend import (
    Labels,
    TensorBlock,
    TensorMap,
    check_isinstance,
    torch_jit_is_scripting,
    torch_jit_script,
)
def _unique_from_blocks(
    blocks: List[TensorBlock],
    axis: str,
    names: List[str],
) -> Labels:
    """
    Gather the unique metadata entries found in ``blocks``.

    For every block, the ``names`` dimensions of either its samples or its
    properties (selected by ``axis``) are extracted. The extracted values of all
    blocks are stacked along the first axis and reduced to their unique rows.

    :param blocks: the blocks whose metadata should be inspected.
    :param axis: ``"samples"`` or ``"properties"``; any other value trips an
        assertion as soon as a block is processed.
    :param names: the dimension names to keep from the selected metadata.
    :return: a :py:class:`Labels` with dimensions ``names`` holding the unique
        rows.
    """
    collected = []
    for blk in blocks:
        # Pick the metadata matching the requested axis for this block
        if axis != "samples":
            assert axis == "properties"
            labels = blk.properties
        else:
            labels = blk.samples

        # Only keep the requested dimensions
        collected.append(labels.view(names).values)

    # Stack the rows of all blocks before removing duplicated rows
    stacked = _dispatch.concatenate(collected, axis=0)
    return Labels(names=names, values=_dispatch.unique(stacked, axis=0))
def _check_args(
    axis: str,
    names: List[str],
    gradient: Optional[str] = None,
) -> None:
    """
    Validate the arguments shared by `unique_metadata` and
    `unique_metadata_block`.

    The type checks are skipped when running under TorchScript; the check on the
    value of ``axis`` is always performed.

    :param axis: must be the string ``"samples"`` or ``"properties"``.
    :param names: must be a list containing only strings.
    :param gradient: must be a string when it is not :py:obj:`None`.

    :raises TypeError: if ``axis``, ``names``, one element of ``names`` or
        ``gradient`` does not have the expected type.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.
    """
    if not torch_jit_is_scripting():
        # Type validation only makes sense in eager Python
        if not isinstance(axis, str):
            raise TypeError(f"`axis` must be a string, not {type(axis)}")

        if not isinstance(names, list):
            raise TypeError(f"`names` must be a list of strings, not {type(names)}")

        for entry in names:
            if not isinstance(entry, str):
                raise TypeError(
                    f"`names` elements must be a strings, not {type(entry)}"
                )

        if gradient is not None:
            if not isinstance(gradient, str):
                raise TypeError(f"`gradient` must be a string, not {type(gradient)}")

    # The value of `axis` is validated last, once its type is known to be right
    if axis != "samples" and axis != "properties":
        raise ValueError(
            f"`axis` must be either 'samples' or 'properties', not '{axis}'"
        )
[docs]
@torch_jit_script
def unique_metadata(
    tensor: TensorMap,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Find the unique metadata entries over all the blocks of a
    :py:class:`TensorMap`.

    The metadata is taken along ``axis`` (either ``"samples"`` or
    ``"properties"``) and restricted to the dimensions listed in ``names``. The
    result is returned as a :py:class:`Labels` object.

    When ``gradient`` is given as a ``str`` naming a gradient parameter (for
    instance ``"strain"`` or ``"positions"``), only the corresponding gradient
    blocks are inspected. Note that gradient blocks by definition have the same
    properties metadata as their parent :py:class:`TensorBlock`.

    If the (gradient) blocks of ``tensor`` contain no indices for the requested
    ``axis`` and ``names``, the returned :py:class:`Labels` is empty: it has
    length zero, and its names are the ones passed in ``names``.

    For example, to find the unique ``"system"`` indices in the ``"samples"``
    metadata present in a given :py:class:`TensorMap`:

    >>> import numpy as np
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> import metatensor
    >>> block = TensorBlock(
    ...     values=np.random.rand(5, 3),
    ...     samples=Labels(
    ...         names=["system", "atom"],
    ...         values=np.array([[0, 0], [0, 1], [1, 0], [1, 1], [2, 3]]),
    ...     ),
    ...     components=[],
    ...     properties=Labels.range("properties", 3),
    ... )
    >>> keys = Labels(names=["key"], values=np.array([[0]]))
    >>> tensor = TensorMap(keys, [block.copy()])
    >>> unique_systems = metatensor.unique_metadata(
    ...     tensor,
    ...     axis="samples",
    ...     names=["system"],
    ... )
    >>> unique_systems
    Labels(
        system
          0
          1
          2
    )

    Or, to find the unique ``(system, atom)`` pairs of indices in the
    ``"samples"`` metadata present in the ``"positions"`` gradient blocks of a
    given :py:class:`TensorMap`:

    >>> gradient = TensorBlock(
    ...     values=np.random.rand(4, 3, 3),
    ...     samples=Labels(
    ...         names=["sample", "system", "atom"],
    ...         values=np.array([[0, 0, 0], [1, 0, 0], [2, 1, 4], [3, 2, 5]]),
    ...     ),
    ...     components=[Labels.range("xyz", 3)],
    ...     properties=Labels.range("properties", 3),
    ... )
    >>> block.add_gradient("positions", gradient)
    >>> tensor = TensorMap(keys, [block])
    >>> metatensor.unique_metadata(
    ...     tensor,
    ...     axis="samples",
    ...     names=["system", "atom"],
    ...     gradient="positions",
    ... )
    Labels(
        system  atom
          0      0
          1      4
          2      5
    )

    :param tensor: the :py:class:`TensorMap` to find unique indices for.
    :param axis: a ``str``, either ``"samples"`` or ``"properties"``, selecting
        the axis along which the named unique indices are searched.
    :param names: a ``str``, ``list`` of ``str``, or ``tuple`` of ``str`` giving
        the name(s) of the indices along ``axis`` for which the unique values
        should be found.
    :param gradient: a ``str`` naming the gradient parameter whose gradient
        blocks should be inspected. If :py:obj:`None` (default), the blocks
        containing the values are used instead.

    :return: a sorted :py:class:`Labels` object containing the unique metadata
        for the blocks of the input ``tensor`` or its gradient blocks for the
        specified parameter. Each element in the returned :py:class:`Labels`
        object has len(``names``) entries.
    """
    # The type of `tensor` can only be verified outside of TorchScript
    if not torch_jit_is_scripting():
        if not check_isinstance(tensor, TensorMap):
            raise TypeError(
                f"`tensor` must be a metatensor TensorMap, not {type(tensor)}"
            )

    # Normalize `names` to a list: tuples are converted, a single string is
    # wrapped, and anything else is passed through for `_check_args` to validate
    names = (
        list(names)
        if isinstance(names, tuple)
        else ([names] if isinstance(names, str) else names)
    )
    _check_args(axis, names, gradient)

    # Select either the gradient blocks or the blocks holding the values
    if gradient is not None:
        selected = [block.gradient(gradient) for block in tensor.blocks()]
    else:
        selected = tensor.blocks()

    return _unique_from_blocks(selected, axis, names)
[docs]
@torch_jit_script
def unique_metadata_block(
    block: TensorBlock,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Find the unique metadata entries of a single :py:class:`TensorBlock`.

    The metadata is taken along ``axis`` (either ``"samples"`` or
    ``"properties"``) and restricted to the dimensions listed in ``names``. The
    result is returned as a :py:class:`Labels` object.

    When ``gradient`` is given as a ``str`` naming a gradient parameter (for
    instance ``"strain"`` or ``"positions"``), the gradient block associated
    with ``block`` is inspected instead of ``block`` itself. Note that gradient
    blocks by definition have the same properties metadata as their parent
    :py:class:`TensorBlock`.

    If the (gradient) block contains no indices for the requested ``axis`` and
    ``names``, the returned :py:class:`Labels` is empty: it has length zero, and
    its names are the ones passed in ``names``.

    :param block: the :py:class:`TensorBlock` to find unique indices for.
    :param axis: a ``str``, either ``"samples"`` or ``"properties"``, selecting
        the axis along which the named unique metadata is searched.
    :param names: a ``str``, ``list`` of ``str``, or ``tuple`` of ``str`` giving
        the name(s) of the metadata along ``axis`` for which the unique indices
        should be found.
    :param gradient: a ``str`` naming the gradient parameter whose gradient
        block should be inspected. If :py:obj:`None` (default), ``block`` itself
        is used.

    :return: a sorted :py:class:`Labels` object containing the unique metadata
        for the input ``block`` or its gradient for the specified parameter.
        Each element in the returned :py:class:`Labels` object has
        len(``names``) entries.
    """
    # The type of `block` can only be verified outside of TorchScript
    if not torch_jit_is_scripting():
        if not check_isinstance(block, TensorBlock):
            raise TypeError(
                f"`block` must be a metatensor TensorBlock, not {type(block)}"
            )

    # Normalize `names` to a list: tuples are converted, a single string is
    # wrapped, and anything else is passed through for `_check_args` to validate
    names = (
        list(names)
        if isinstance(names, tuple)
        else ([names] if isinstance(names, str) else names)
    )
    _check_args(axis, names, gradient)

    # Wrap the block to inspect in a list, as expected by the shared helper
    if gradient is not None:
        selected = [block.gradient(gradient)]
    else:
        selected = [block]

    return _unique_from_blocks(selected, axis, names)