Source code for metatensor.operations.sort

from typing import List, Union

from . import _dispatch
from ._backend import (
    Labels,
    TensorBlock,
    TensorMap,
    torch_jit_annotate,
    torch_jit_is_scripting,
    torch_jit_script,
)


def _sort_single_gradient_block(
    block: TensorBlock,
    gradient_block: TensorBlock,
    axes: List[str],
    descending: bool,
) -> TensorBlock:
    """
    Sort one gradient block of ``block`` along ``axes``, without validating inputs.

    Compared to ``_sort_single_block``, the samples need special care: the first
    column of the gradient samples holds row indices into the values of ``block``.
    When sorting samples, these indices are first translated to the rows ``block``
    will have once its own samples are sorted. Only then are the gradient rows
    themselves sorted.

    :param block: the block ``gradient_block`` is attached to, in its original
        (not yet sorted) order
    :param gradient_block: the gradient block to sort
    :param axes: subset of ``"samples"``, ``"components"`` and ``"properties"``
    :param descending: sort in descending instead of ascending order
    :return: a new gradient block with sorted metadata and permuted values
    """
    new_values = gradient_block.values

    samples = gradient_block.samples.values
    if "samples" in axes:
        # permutation that sorting applies to the samples of the parent block
        parent_order = _dispatch.argsort_labels_values(
            block.samples.values, reverse=descending
        )
        # invert it, giving: old parent row index -> new parent row index
        new_row = _dispatch.empty_like(parent_order, shape=(len(parent_order),))
        new_row[parent_order] = _dispatch.int_array_like(
            list(range(len(parent_order))), parent_order
        )
        # work on a copy, then make the "sample" column point to the new rows
        samples = _dispatch.copy(samples)
        samples[:, 0] = new_row[_dispatch.to_index_array(samples[:, 0])]

        # finally order the gradient rows lexicographically over all columns
        order = _dispatch.argsort_labels_values(samples, reverse=descending)
        samples = samples[order]
        new_values = new_values[order]

    components: List[Labels] = []
    for dim, component in enumerate(gradient_block.components):
        component_values = component.values
        if "components" in axes:
            order = _dispatch.argsort_labels_values(
                component_values, reverse=descending
            )
            component_values = component_values[order]
            # component number `dim` is stored along array dimension `dim + 1`
            new_values = _dispatch.take(new_values, order, axis=dim + 1)
        components.append(Labels(names=component.names, values=component_values))

    properties = gradient_block.properties.values
    if "properties" in axes:
        order = _dispatch.argsort_labels_values(properties, reverse=descending)
        properties = properties[order]
        new_values = _dispatch.take(new_values, order, axis=-1)

    return TensorBlock(
        values=new_values,
        samples=Labels(names=gradient_block.samples.names, values=samples),
        components=components,
        properties=Labels(names=gradient_block.properties.names, values=properties),
    )


def _sort_single_block(
    block: TensorBlock,
    axes: List[str],
    descending: bool,
) -> TensorBlock:
    """
    Reorder the metadata and values of ``block`` along the requested ``axes``.

    The entries of ``axes`` are not validated here, and gradients are neither
    sorted nor copied to the result. Both are left to the caller.

    :param block: the block to sort
    :param axes: subset of ``"samples"``, ``"components"`` and ``"properties"``
    :param descending: sort in descending instead of ascending order
    :return: a new block with sorted metadata and correspondingly permuted values
    """
    new_values = block.values

    samples = block.samples.values
    if "samples" in axes:
        order = _dispatch.argsort_labels_values(samples, reverse=descending)
        samples = samples[order]
        new_values = new_values[order]

    components: List[Labels] = []
    for dim, component in enumerate(block.components):
        component_values = component.values
        if "components" in axes:
            order = _dispatch.argsort_labels_values(
                component_values, reverse=descending
            )
            component_values = component_values[order]
            # component number `dim` is stored along array dimension `dim + 1`
            new_values = _dispatch.take(new_values, order, axis=dim + 1)
        components.append(Labels(names=component.names, values=component_values))

    properties = block.properties.values
    if "properties" in axes:
        order = _dispatch.argsort_labels_values(properties, reverse=descending)
        properties = properties[order]
        new_values = _dispatch.take(new_values, order, axis=-1)

    return TensorBlock(
        values=new_values,
        samples=Labels(names=block.samples.names, values=samples),
        components=components,
        properties=Labels(names=block.properties.names, values=properties),
    )


@torch_jit_script
def sort_block(
    block: TensorBlock,
    axes: Union[str, List[str]] = "all",
    descending: bool = False,
) -> TensorBlock:
    """
    Sort the metadata of ``block`` along the given axes, and permute its values
    (and gradients) to follow the new metadata order.

    The metadata is copied to the CPU in order to be sorted.

    :param block: the block to sort
    :param axes: axes to sort. The labels entries along these axes will be sorted
        in lexicographic order, and the arrays values will be reordered accordingly.
        Possible values are ``'samples'``, ``'components'``, ``'properties'`` and
        ``'all'`` to sort everything.
    :param descending: if false, the order is ascending
    :return: sorted tensor block

    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import TensorBlock, TensorMap, Labels
    >>> block = TensorBlock(
    ...     values=np.arange(9).reshape(3, 3),
    ...     samples=Labels(["system", "atom"], np.array([[0, 3], [0, 1], [0, 2]])),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[2, 0], [3, 0], [1, 0]])),
    ... )
    >>> print(block)
    TensorBlock
        samples (3): ['system', 'atom']
        components (): []
        properties (3): ['n', 'l']
        gradients: None
    >>> # sorting axes one by one
    >>> block_sorted_stepwise = metatensor.sort_block(block, axes=["properties"])
    >>> # properties (last dimension of the array) are sorted
    >>> block_sorted_stepwise.values
    array([[2, 0, 1],
           [5, 3, 4],
           [8, 6, 7]])
    >>> block_sorted_stepwise = metatensor.sort_block(
    ...     block_sorted_stepwise, axes=["samples"]
    ... )
    >>> # samples (first dimension of the array) are sorted
    >>> block_sorted_stepwise.values
    array([[5, 3, 4],
           [8, 6, 7],
           [2, 0, 1]])
    >>> # sorting both samples and properties at the same time
    >>> sorted_block = metatensor.sort_block(block)
    >>> bool(np.all(sorted_block.values == block_sorted_stepwise.values))
    True
    >>> # This function can also sort gradients:
    >>> sorted_block.add_gradient(
    ...     parameter="g",
    ...     gradient=TensorBlock(
    ...         values=np.arange(18).reshape(3, 2, 3),
    ...         samples=Labels(["sample"], np.array([[1], [2], [0]])),
    ...         components=[Labels.range("direction", 2)],
    ...         properties=sorted_block.properties,
    ...     ),
    ... )
    >>> sorted_grad_block = metatensor.sort_block(sorted_block)
    >>> sorted_grad_block.gradient("g").samples == Labels.range("sample", 3)
    True
    >>> sorted_grad_block.gradient("g").properties == sorted_block.properties
    True
    >>> # the components (middle dimensions) are also sorted:
    >>> sorted_grad_block.gradient("g").values
    array([[[12, 13, 14],
            [15, 16, 17]],
    <BLANKLINE>
           [[ 0,  1,  2],
            [ 3,  4,  5]],
    <BLANKLINE>
           [[ 6,  7,  8],
            [ 9, 10, 11]]])
    """
    valid_axes = ["samples", "components", "properties"]

    # normalize `axes` into a list of axis names
    if isinstance(axes, str):
        if axes == "all":
            requested = valid_axes
        else:
            requested = [axes]
    elif isinstance(axes, list):
        requested = axes
    else:
        # `type(axes)` can not be formatted when running under TorchScript
        if torch_jit_is_scripting():
            extra = ""
        else:
            extra = f", not {type(axes)}"
        raise TypeError("'axes' should be a string or list of strings" + extra)

    for axis in requested:
        if axis not in valid_axes:
            raise ValueError(
                "`axes` must be one of 'samples', 'components' or 'properties', "
                f"not '{axis}'"
            )

    sorted_block = _sort_single_block(block, requested, descending)

    # gradients are sorted against the *original* block, so that their "sample"
    # column can be remapped to the new row order of the values
    for parameter, gradient in block.gradients():
        if len(gradient.gradients_list()) > 0:
            raise NotImplementedError("gradients of gradients are not supported")
        sorted_gradient = _sort_single_gradient_block(
            block, gradient, requested, descending
        )
        sorted_block.add_gradient(parameter=parameter, gradient=sorted_gradient)

    return sorted_block
@torch_jit_script
def sort(
    tensor: TensorMap,
    axes: Union[str, List[str]] = "all",
    descending: bool = False,
) -> TensorMap:
    """
    Sort the ``tensor`` according to the key values and the blocks for each specified
    axis in ``axes`` according to the label values along these axes.

    Each block is sorted separately, see :py:func:`sort_block` for more information.

    Note: This function duplicates metadata on the CPU for the purpose of sorting.

    :param tensor: the tensor map to sort
    :param axes: axes to sort. The labels entries along these axes will be sorted
        in lexicographic order, and the arrays values will be reordered accordingly.
        Possible values are ``'keys'``, ``'samples'``, ``'components'``,
        ``'properties'`` and ``'all'`` to sort everything. A list of these names is
        also accepted, and is not modified by this function.
    :param descending: if false, the order is ascending
    :return: sorted tensor map

    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import TensorBlock, TensorMap, Labels
    >>> block_1 = TensorBlock(
    ...     values=np.arange(9).reshape(3, 3),
    ...     samples=Labels(["system", "atom"], np.array([[0, 3], [0, 1], [0, 2]])),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[1, 0], [2, 0], [0, 0]])),
    ... )
    >>> block_2 = TensorBlock(
    ...     values=np.arange(3).reshape(1, 3),
    ...     samples=Labels(["system", "atom"], np.array([[0, 0]])),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[1, 0], [2, 0], [0, 0]])),
    ... )
    >>> tensor = TensorMap(
    ...     keys=Labels(["types"], np.array([[1], [0]])), blocks=[block_2, block_1]
    ... )
    >>> metatensor.sort(tensor, axes="keys")
    TensorMap with 2 blocks
    keys: types
            0
            1
    >>> # a list can mix "keys" with block axes, and is left unchanged
    >>> axes = ["keys", "properties"]
    >>> sorted_tensor = metatensor.sort(tensor, axes=axes)
    >>> axes
    ['keys', 'properties']
    """
    if isinstance(axes, str):
        axes_list: List[str] = []
        if axes == "all":
            axes_list = ["samples", "components", "properties"]
            sort_keys = True
        elif axes == "keys":
            axes_list = torch_jit_annotate(List[str], [])
            sort_keys = True
        else:
            axes_list = [axes]
            sort_keys = False
    elif isinstance(axes, list):
        # Split "keys" out of the requested axes into a *new* list. Removing it
        # from `axes` in place (as was done before) silently modified the
        # caller's list. All occurrences of "keys" are handled here.
        axes_list = torch_jit_annotate(List[str], [])
        sort_keys = False
        for axis in axes:
            if axis == "keys":
                sort_keys = True
            else:
                axes_list.append(axis)
    else:
        # `type(axes)` can not be formatted when running under TorchScript
        if torch_jit_is_scripting():
            extra = ""
        else:
            extra = f", not {type(axes)}"
        raise TypeError("'axes' should be a string or list of strings" + extra)

    # Do we need to sort the keys?
    if sort_keys:
        sorted_idx = _dispatch.argsort_labels_values(
            tensor.keys.values, reverse=descending
        )
        new_keys = Labels(
            names=tensor.keys.names,
            values=tensor.keys.values[sorted_idx],
        )
    else:
        new_keys = tensor.keys
        # identity permutation, so that the loop below serves both branches
        sorted_idx = _dispatch.int_array_like(
            int_list=list(range(len(new_keys))),
            like=new_keys.values,
        )

    # Do any required sorting on the blocks, in the (possibly new) key order
    new_blocks: List[TensorBlock] = []
    for i in sorted_idx:
        new_blocks.append(
            sort_block(
                block=tensor.block(tensor.keys[int(i)]),
                axes=axes_list,
                descending=descending,
            )
        )

    return TensorMap(new_keys, new_blocks)