Source code for metatensor.operations.unique_metadata
"""
Module for finding unique metadata for TensorMaps and TensorBlocks
"""
from typing import List, Optional, Tuple, Union
from . import _dispatch
from ._classes import (
Labels,
TensorBlock,
TensorMap,
check_isinstance,
torch_jit_is_scripting,
)
[docs]
def unique_metadata(
    tensor: TensorMap,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Gather the unique metadata present in any block of a :py:class:`TensorMap`.

    The metadata is read along ``axis`` (``"samples"`` or ``"properties"``)
    and restricted to the dimension(s) listed in ``names``. The result holds
    every distinct entry found across all the blocks of ``tensor``.

    When ``gradient`` names a gradient parameter (for example ``"positions"``
    or ``"cell"``), the gradient blocks for that parameter are inspected
    instead of the blocks themselves. By definition, gradient blocks share the
    properties metadata of their parent :py:class:`TensorBlock`.

    If the inspected (gradient) blocks hold no entries for the requested
    ``axis`` and ``names``, the returned :py:class:`Labels` has zero entries
    but still carries ``names`` as its names.

    Finding the unique ``"structure"`` values among the ``"samples"`` of a
    :py:class:`TensorMap`:

    .. code-block:: python

        import metatensor

        unique_structures = metatensor.unique_metadata(
            tensor,
            axis="samples",
            names=["structure"],
        )

    Finding the unique ``"atom"`` values among the ``"samples"`` of the
    ``"positions"`` gradient blocks:

    .. code-block:: python

        unique_grad_atoms = metatensor.unique_metadata(
            tensor,
            axis="samples",
            names=["atom"],
            gradient="positions",
        )

    The result can drive a split of the :py:class:`TensorMap` into smaller
    :py:class:`TensorMap` objects. Suppose ``unique_structures`` above holds
    the ten structure indices ``0`` to ``9``. The following code then creates
    two tensors, the first with structures 0-3 and the second with
    structures 4-9:

    .. code-block:: python

        import metatensor

        [tensor_1, tensor_2] = metatensor.split(
            tensor,
            axis="samples",
            grouped_labels=[unique_structures[:4], unique_structures[4:]],
        )

    :param tensor: the :py:class:`TensorMap` whose blocks are inspected.
    :param axis: ``"samples"`` or ``"properties"``, the axis along which the
        named metadata is looked up.
    :param names: one name as a ``str``, or several as a ``list`` or
        ``tuple`` of ``str``, selecting the metadata dimension(s) whose unique
        values are collected.
    :param gradient: name of the gradient parameter whose gradient blocks
        should be inspected. With :py:obj:`None` (the default), the regular
        :py:class:`TensorBlock` objects are used.
    :return: a sorted :py:class:`Labels` object with the unique metadata of
        the blocks of ``tensor`` (or of their gradients for ``gradient``).
        Each of its entries has len(``names``) values.
    """
    # The explicit type check on ``tensor`` only runs outside TorchScript.
    if not torch_jit_is_scripting():
        if not check_isinstance(tensor, TensorMap):
            raise TypeError(
                f"`tensor` must be a metatensor TensorMap, not {type(tensor)}"
            )

    # Normalise ``names`` to a list: a lone string becomes a one-element list,
    # a tuple is copied into a list, and a list is passed through unchanged.
    if isinstance(names, str):
        names_list = [names]
    elif isinstance(names, tuple):
        names_list = list(names)
    else:
        names_list = names

    _check_args(axis, names_list, gradient)

    # Start from the blocks themselves and swap in the requested gradient
    # blocks if a gradient parameter was given.
    blocks = tensor.blocks()
    if gradient is not None:
        blocks = [block.gradient(gradient) for block in blocks]

    return _unique_from_blocks(blocks, axis, names_list)
[docs]
def unique_metadata_block(
    block: TensorBlock,
    axis: str,
    names: Union[List[str], Tuple[str], str],
    gradient: Optional[str] = None,
) -> Labels:
    """
    Returns a :py:class:`Labels` object containing the unique metadata in the
    input :py:class:`TensorBlock` ``block``, for the specified ``axis`` (either
    ``"samples"`` or ``"properties"``) and metadata ``names``.

    Passing ``gradient`` as a ``str`` corresponding to a gradient parameter (for
    instance ``"cell"`` or ``"positions"``) returns the unique indices only for
    the gradient block associated with ``block``. Note that gradient blocks by
    definition have the same properties metadata as their parent
    :py:class:`TensorBlock`.

    An empty :py:class:`Labels` object is returned if there are no indices in
    ``block`` (or in its gradient block, when ``gradient`` is given)
    corresponding to the specified ``axis`` and ``names``. This will have
    length zero but the names will be the same as passed in ``names``.

    For example, to find the unique ``"structure"`` indices in the ``"samples"``
    metadata present in a given :py:class:`TensorBlock`:

    .. code-block:: python

        import metatensor

        unique_samples = metatensor.unique_metadata_block(
            block,
            axis="samples",
            names=["structure"],
        )

    To find the unique ``"atom"`` indices along the ``"samples"`` axis present
    in the ``"positions"`` gradient block of a given :py:class:`TensorBlock`:

    .. code-block:: python

        unique_grad_samples = metatensor.unique_metadata_block(
            block,
            axis="samples",
            names=["atom"],
            gradient="positions",
        )

    :param block: the :py:class:`TensorBlock` to find unique indices for.
    :param axis: a str, either ``"samples"`` or ``"properties"``, corresponding
        to the ``axis`` along which the named unique metadata should be found.
    :param names: a ``str``, ``list`` of ``str``, or ``tuple`` of ``str``
        corresponding to the name(s) of the metadata along the specified
        ``axis`` for which the unique indices should be found.
    :param gradient: a ``str`` corresponding to the gradient parameter name of
        the gradient block of ``block`` to find the unique metadata for. If
        :py:obj:`None` (default), the unique metadata of ``block`` itself will
        be calculated.
    :return: a sorted :py:class:`Labels` object containing the unique metadata
        for the input ``block`` or its gradient for the specified parameter.
        Each element in the returned :py:class:`Labels` object has
        len(``names``) entries.
    :raises TypeError: (outside of TorchScript) if ``block`` is not a
        :py:class:`TensorBlock`, or if ``axis``, ``names`` (or one of its
        elements) or ``gradient`` has the wrong type.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.
    """
    # Parse input args; the explicit type check only runs outside TorchScript
    if not torch_jit_is_scripting():
        if not check_isinstance(block, TensorBlock):
            raise TypeError(
                f"`block` must be a metatensor TensorBlock, not {type(block)}"
            )
    # Accept a single name or a tuple of names, but work with a list internally
    names = (
        [names]
        if isinstance(names, str)
        else (list(names) if isinstance(names, tuple) else names)
    )
    _check_args(axis, names, gradient)

    # Make a list of the blocks to find unique indices for (a single block
    # here, wrapped in a list so the shared helper can be reused)
    if gradient is None:
        blocks = [block]
    else:
        blocks = [block.gradient(gradient)]

    return _unique_from_blocks(blocks, axis, names)
def _unique_from_blocks(
    blocks: List[TensorBlock],
    axis: str,
    names: List[str],
) -> Labels:
    """
    Finds the unique metadata of a list of blocks along the given ``axis`` and
    for the specified ``names``.

    :param blocks: the blocks (or gradient blocks) whose metadata is collected.
    :param axis: either ``"samples"`` or ``"properties"``. Callers are expected
        to have validated this already (see ``_check_args``).
    :param names: the metadata names kept when viewing each block's labels.
    :return: a :py:class:`Labels` with the given ``names`` whose values are the
        unique rows across all ``blocks``. If ``blocks`` is empty (e.g. a
        :py:class:`TensorMap` without any block), an empty :py:class:`Labels`
        carrying ``names`` is returned instead of failing.
    """
    if len(blocks) == 0:
        # Nothing to concatenate. Presumably ``_dispatch.concatenate`` forwards
        # to ``np.concatenate``/``torch.cat``, which reject an empty list, so
        # return labels with the requested names and no entries instead.
        return Labels.empty(names)

    all_values = []
    for block in blocks:
        if axis == "samples":
            labels = block.samples
        else:
            assert axis == "properties"
            labels = block.properties
        # keep only the requested dimensions of this block's labels
        all_values.append(labels.view(names).values)

    # Stack the rows from every block, then de-duplicate them row-wise
    unique_values = _dispatch.unique(_dispatch.concatenate(all_values, axis=0), axis=0)
    return Labels(names=names, values=unique_values)
def _check_args(
    axis: str,
    names: List[str],
    gradient: Optional[str] = None,
) -> None:
    """
    Checks the input args shared by ``unique_metadata`` and
    ``unique_metadata_block``.

    Type checks only run outside of TorchScript. The value of ``axis`` is
    checked in every case.

    :param axis: must be the string ``"samples"`` or ``"properties"``.
    :param names: must be a ``list`` whose elements are all ``str``.
    :param gradient: must be :py:obj:`None` or a ``str``.
    :raises TypeError: if ``axis``, ``names`` (or one of its elements) or
        ``gradient`` has the wrong type.
    :raises ValueError: if ``axis`` is neither ``"samples"`` nor
        ``"properties"``.
    """
    if not torch_jit_is_scripting():
        if not isinstance(axis, str):
            raise TypeError(f"`axis` must be a string, not {type(axis)}")
        if not isinstance(names, list):
            raise TypeError(f"`names` must be a list of strings, not {type(names)}")
        for name in names:
            if not isinstance(name, str):
                raise TypeError(f"`names` elements must be a strings, not {type(name)}")
        # ``gradient`` is optional, so only its type is checked when given
        if gradient is not None and not isinstance(gradient, str):
            raise TypeError(f"`gradient` must be a string, not {type(gradient)}")

    if axis not in ["samples", "properties"]:
        raise ValueError(
            f"`axis` must be either 'samples' or 'properties', not '{axis}'"
        )