from typing import List, Union
from . import _dispatch
from ._classes import Labels, TensorBlock, TensorMap, torch_jit_is_scripting
def _sort_single_block(
    block: TensorBlock,
    axes: List[str],
    descending: bool,
) -> TensorBlock:
    """
    Sorts a single TensorBlock without the user input checking and sorting of gradients
    """
    sample_names = block.samples.names
    sample_values = block.samples.values
    component_names: List[List[str]] = []
    components_values = []
    for component in block.components:
        component_names.append(component.names)
        components_values.append(component.values)
    property_names = block.properties.names
    properties_values = block.properties.values
    values = block.values
    if "samples" in axes:
        sorted_idx = _dispatch.argsort_labels_values(sample_values, reverse=descending)
        sample_values = sample_values[sorted_idx]
        values = values[sorted_idx]
    if "components" in axes:
        for i, _ in enumerate(block.components):
            sorted_idx = _dispatch.argsort_labels_values(
                components_values[i], reverse=descending
            )
            components_values[i] = components_values[i][sorted_idx]
            values = _dispatch.take(values, sorted_idx, axis=i + 1)
    if "properties" in axes:
        sorted_idx = _dispatch.argsort_labels_values(
            properties_values, reverse=descending
        )
        properties_values = properties_values[sorted_idx]
        values = _dispatch.take(values, sorted_idx, axis=-1)
    samples_labels = Labels(names=sample_names, values=sample_values)
    properties_labels = Labels(names=property_names, values=properties_values)
    components_labels = [
        Labels(names=component_names[i], values=components_values[i])
        for i in range(len(component_names))
    ]
    return TensorBlock(
        values=values,
        samples=samples_labels,
        components=components_labels,
        properties=properties_labels,
    )
[docs]
def sort_block(
    block: TensorBlock,
    axes: Union[str, List[str]] = "all",
    descending: bool = False,
) -> TensorBlock:
    """
    Rearrange the values of a block according to the order given by the sorted metadata
    of the given axes.
    This function creates copies of the metadata on the CPU to sort the metadata.
    :param axes: axes to sort. The labels entries along these axes will be sorted in
        lexicographic order, and the arrays values will be reordered accordingly.
        Possible values are ``'samples'``, ``'components'``, ``'properties'`` and
        ``'all'`` to sort everything.
    :param descending: if false, the order is ascending
    :return: sorted tensor block
    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import TensorBlock, TensorMap, Labels
    >>> block = TensorBlock(
    ...     values=np.arange(9).reshape(3, 3),
    ...     samples=Labels(
    ...         ["structures", "centers"], np.array([[0, 3], [0, 1], [0, 2]])
    ...     ),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[2, 0], [3, 0], [1, 0]])),
    ... )
    >>> print(block)
    TensorBlock
        samples (3): ['structures', 'centers']
        components (): []
        properties (3): ['n', 'l']
        gradients: None
    >>> # sorting axes one by one
    >>> block_sorted_stepwise = metatensor.sort_block(block, axes=["properties"])
    >>> # properties (last dimension of the array) are sorted
    >>> block_sorted_stepwise.values
    array([[2, 0, 1],
           [5, 3, 4],
           [8, 6, 7]])
    >>> block_sorted_stepwise = metatensor.sort_block(
    ...     block_sorted_stepwise, axes=["samples"]
    ... )
    >>> # samples (first dimension of the array) are sorted
    >>> block_sorted_stepwise.values
    array([[5, 3, 4],
           [8, 6, 7],
           [2, 0, 1]])
    >>> # sorting both samples and properties at the same time
    >>> sorted_block = metatensor.sort_block(block)
    >>> np.all(sorted_block.values == block_sorted_stepwise.values)
    True
    >>> # This function can also sort gradients:
    >>> sorted_block.add_gradient(
    ...     parameter="g",
    ...     gradient=TensorBlock(
    ...         values=np.arange(18).reshape(3, 2, 3),
    ...         samples=Labels(["sample"], np.array([[1], [2], [0]])),
    ...         components=[Labels.range("direction", 2)],
    ...         properties=sorted_block.properties,
    ...     ),
    ... )
    >>> sorted_grad_block = metatensor.sort_block(sorted_block)
    >>> sorted_grad_block.gradient("g").samples == Labels.range("sample", 3)
    True
    >>> sorted_grad_block.gradient("g").properties == sorted_block.properties
    True
    >>> # the components (middle dimensions) are also sorted:
    >>> sorted_grad_block.gradient("g").values
    array([[[12, 13, 14],
            [15, 16, 17]],
    <BLANKLINE>
           [[ 0,  1,  2],
            [ 3,  4,  5]],
    <BLANKLINE>
           [[ 6,  7,  8],
            [ 9, 10, 11]]])
    """
    if isinstance(axes, str):
        if axes == "all":
            axes_list = ["samples", "components", "properties"]
        else:
            axes_list = [axes]
    elif isinstance(axes, list):
        axes_list = axes
    else:
        if torch_jit_is_scripting():
            extra = ""
        else:
            extra = f", not {type(axes)}"
        raise TypeError("'axes' should be a string or list of strings" + extra)
    for axis in axes_list:
        if axis not in ["samples", "components", "properties"]:
            raise ValueError(
                "`axes` must be one of 'samples', 'components' or 'properties', "
                f"not '{axis}'"
            )
    result_block = _sort_single_block(block, axes_list, descending)
    for parameter, gradient in block.gradients():
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")
        result_block.add_gradient(
            parameter=parameter,
            gradient=_sort_single_block(gradient, axes_list, descending),
        )
    return result_block 
[docs]
def sort(
    tensor: TensorMap,
    axes: Union[str, List[str]] = "all",
    descending: bool = False,
) -> TensorMap:
    """
    Sort the ``tensor`` according to the key values and the blocks for each specified
    axis in ``axes`` according to the label values along these axes.
    Each block is sorted separately, see :py:func:`sort_block` for more information
    Note: This function duplicates metadata on the CPU for the purpose of sorting.
    :param axes: axes to sort. The labels entries along these axes will be sorted in
        lexicographic order, and the arrays values will be reordered accordingly.
        Possible values are ``'keys'``, ``'samples'``, ``'components'``,
        ``'properties'`` and ``'all'`` to sort everything.
    :param descending: if false, the order is ascending
    :return: sorted tensor map
    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import TensorBlock, TensorMap, Labels
    >>> block_1 = TensorBlock(
    ...     values=np.arange(9).reshape(3, 3),
    ...     samples=Labels(
    ...         ["structures", "centers"], np.array([[0, 3], [0, 1], [0, 2]])
    ...     ),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[1, 0], [2, 0], [0, 0]])),
    ... )
    >>> block_2 = TensorBlock(
    ...     values=np.arange(3).reshape(1, 3),
    ...     samples=Labels(["structures", "centers"], np.array([[0, 0]])),
    ...     components=[],
    ...     properties=Labels(["n", "l"], np.array([[1, 0], [2, 0], [0, 0]])),
    ... )
    >>> tm = TensorMap(
    ...     keys=Labels(["species"], np.array([[1], [0]])), blocks=[block_2, block_1]
    ... )
    >>> metatensor.sort(tm, axes="keys")
    TensorMap with 2 blocks
    keys: species
             0
             1
    """
    if isinstance(axes, str):
        axes_list: List[str] = []
        if axes == "all":
            axes_list = ["samples", "components", "properties"]
            sort_keys = True
        elif axes == "keys":
            axes_list = []
            sort_keys = True
        else:
            axes_list = [axes]
            sort_keys = False
    elif isinstance(axes, list):
        axes_list = axes
        if "keys" in axes_list:
            keys_index = axes_list.index("keys")
            sort_keys = True
            axes_list.pop(keys_index)
        else:
            sort_keys = False
    else:
        if torch_jit_is_scripting():
            extra = ""
        else:
            extra = f", not {type(axes)}"
        raise TypeError("'axes' should be a string or list of strings" + extra)
    # Do we need to sort the keys?
    if sort_keys:
        sorted_idx = _dispatch.argsort_labels_values(
            tensor.keys.values, reverse=descending
        )
        new_keys = Labels(
            names=tensor.keys.names,
            values=tensor.keys.values[sorted_idx],
        )
    else:
        new_keys = tensor.keys
        sorted_idx = _dispatch.int_array_like(
            int_list=list(range(len(new_keys))),
            like=new_keys.values,
        )
    # Do any required sorting on the blocks
    new_blocks: List[TensorBlock] = []
    for i in sorted_idx:
        new_blocks.append(
            sort_block(
                block=tensor.block(tensor.keys[int(i)]),
                axes=axes_list,
                descending=descending,
            )
        )
    return TensorMap(new_keys, new_blocks)