from typing import Dict, List
from ._classes import (
    Labels,
    TensorBlock,
    TensorMap,
    check_isinstance,
    torch_jit_is_scripting,
)
from .slice import _slice_block
def split(
    tensor: TensorMap,
    axis: str,
    grouped_labels: List[Labels],
) -> List[TensorMap]:
    """Split a :py:class:`TensorMap` into multiple :py:class:`TensorMap`.

    The operation is based on some specified groups of indices, along either the
    "samples" or "properties" ``axis``. The number of returned :py:class:`TensorMap`
    is equal to the number of :py:class:`Labels` objects passed in
    ``grouped_labels``. Each returned :py:class:`TensorMap` will have the same keys
    and number of blocks as the input ``tensor``, but with the dimensions of the
    blocks reduced to only contain the specified indices for the corresponding
    group.

    For example, to split a tensor along the "samples" axis, according to the
    "structure" index, where structures 0, 6, and 7 are in the first returned
    :py:class:`TensorMap`; 2, 3, and 4 in the second; and 1, 5, 8, 9, and 10 in the
    third:

    .. code-block:: python

        import numpy as np
        import metatensor
        from metatensor import Labels

        tensor_splitted = metatensor.split(
            tensor,
            axis="samples",
            grouped_labels=[
                Labels(names=["structure"], values=np.array([[0], [6], [7]])),
                Labels(names=["structure"], values=np.array([[2], [3], [4]])),
                Labels(
                    names=["structure"],
                    values=np.array([[1], [5], [8], [9], [10]]),
                ),
            ],
        )

    :param tensor: a :py:class:`TensorMap` to be split
    :param axis: a str, either "samples" or "properties", that indicates the
        :py:class:`TensorBlock` axis along which the named index (or indices) in
        ``grouped_labels`` belongs. Each :py:class:`TensorBlock` in each returned
        :py:class:`TensorMap` could have a reduced dimension along this axis, but the
        other axes will remain the same size.
    :param grouped_labels: a list of :py:class:`Labels` containing the names and values
        of the indices along the specified ``axis`` which should be in each respective
        output :py:class:`TensorMap`.
    :return: a list of :py:class:`TensorMap` that corresponds to the split input
        ``tensor``. Each tensor in the returned list contains only the named indices in
        the respective :py:class:`Labels` object of ``grouped_labels``.
    :raises TypeError: if ``tensor`` is not a :py:class:`TensorMap`, if ``axis`` is
        not a string, or if ``grouped_labels`` is not a list of :py:class:`Labels`
        (these type checks are skipped under TorchScript)
    :raises ValueError: if ``axis`` is neither "samples" nor "properties", if the
        :py:class:`Labels` in ``grouped_labels`` do not all have the same names, or
        if one of these names is not part of the ``axis`` names of the blocks
    """
    # Check input args
    if not torch_jit_is_scripting():
        if not check_isinstance(tensor, TensorMap):
            raise TypeError(
                f"`tensor` must be a metatensor TensorMap, not {type(tensor)}"
            )
    # The dimension names are only validated against the first block.
    # NOTE(review): presumably relies on all blocks of a TensorMap sharing the same
    # samples/properties names; this also fails for a TensorMap without any block,
    # confirm the intended behavior for empty tensors.
    _check_args(tensor.block(0), axis, grouped_labels)

    # One (initially empty) list of output blocks per group. The empty lists are
    # annotated explicitly so TorchScript does not type them as List[Tensor].
    all_new_blocks: Dict[int, List[TensorBlock]] = {}
    for group_i in range(len(grouped_labels)):
        empty_list: List[TensorBlock] = []
        all_new_blocks[group_i] = empty_list

    for key_index in range(len(tensor.keys)):
        # access the block by position directly, instead of building the key entry
        # and looking the block up by key again
        new_blocks = _split_block(tensor.block(key_index), axis, grouped_labels)
        for group_i, new_block in enumerate(new_blocks):
            all_new_blocks[group_i].append(new_block)

    # every output TensorMap shares the keys of the input tensor
    return [
        TensorMap(keys=tensor.keys, blocks=all_new_blocks[group_i])
        for group_i in range(len(grouped_labels))
    ]
def split_block(
    block: TensorBlock,
    axis: str,
    grouped_labels: List[Labels],
) -> List[TensorBlock]:
    """
    Splits an input :py:class:`TensorBlock` into multiple :py:class:`TensorBlock`
    objects based on some specified ``grouped_labels``, along either the "samples" or
    "properties" ``axis``. The number of returned :py:class:`TensorBlock` is equal to
    the number of :py:class:`Labels` objects passed in ``grouped_labels``. Each
    returned :py:class:`TensorBlock` keeps the size of the input ``block`` along the
    other axes, but its dimension along ``axis`` is reduced to only contain the
    specified indices for the corresponding group.

    For example, to split a block along the "samples" axis, according to the
    "structure" index, where structures 0 and 6 are in the first returned
    :py:class:`TensorBlock`; 2 and 3 in the second; and 1, 5, and 10 in the third:

    .. code-block:: python

        import numpy as np
        import metatensor
        from metatensor import Labels

        block_splitted = metatensor.split_block(
            block,
            axis="samples",
            grouped_labels=[
                Labels(names=["structure"], values=np.array([[0], [6]])),
                Labels(names=["structure"], values=np.array([[2], [3]])),
                Labels(names=["structure"], values=np.array([[1], [5], [10]])),
            ],
        )

    :param block: a :py:class:`TensorBlock` to be split
    :param axis: a str, either "samples" or "properties", that indicates the
        :py:class:`TensorBlock` axis along which the named index (or indices) in
        ``grouped_labels`` belongs. Each :py:class:`TensorBlock` returned could have a
        reduced dimension along this axis, but the other axes will remain the same size.
    :param grouped_labels: a list of :py:class:`Labels` containing the names and values
        of the indices along the specified ``axis`` which should be in each respective
        output :py:class:`TensorBlock`.
    :return: a list of :py:class:`TensorBlock` that corresponds to the split input
        ``block``. Each block in the returned list contains only the named indices in
        the respective :py:class:`Labels` object of ``grouped_labels``.
    :raises TypeError: if ``block`` is not a :py:class:`TensorBlock`, if ``axis`` is
        not a string, or if ``grouped_labels`` is not a list of :py:class:`Labels`
        (these type checks are skipped under TorchScript)
    :raises ValueError: if ``axis`` is neither "samples" nor "properties", if the
        :py:class:`Labels` in ``grouped_labels`` do not all have the same names, or
        if one of these names is not part of the ``axis`` names of ``block``
    """
    # Check input args
    if not torch_jit_is_scripting():
        if not check_isinstance(block, TensorBlock):
            raise TypeError(
                f"`block` must be a metatensor TensorBlock, not {type(block)}"
            )
    _check_args(block, axis, grouped_labels)
    return _split_block(block, axis, grouped_labels)
def _split_block(
    block: TensorBlock,
    axis: str,
    grouped_labels: List[Labels],
) -> List[TensorBlock]:
    """
    Split ``block`` into one new block per entry of ``grouped_labels``.

    This is the unchecked implementation shared by :py:func:`split` and
    :py:func:`split_block`, which both validate their arguments before calling it.
    Every output block comes from an independent slice of ``block`` along ``axis``,
    so splitting into N groups costs N slice operations. A single-pass
    implementation might be more efficient, but is not implemented yet.
    """
    # one slice per group, in the same order as ``grouped_labels``
    return [
        _slice_block(block, axis=axis, labels=group_labels)
        for group_labels in grouped_labels
    ]
def _check_args(block: TensorBlock, axis: str, grouped_labels: List[Labels]):
    """
    Validate the arguments shared by :py:func:`split` and :py:func:`split_block`.

    :param block: block whose ``axis`` names must contain the names used in
        ``grouped_labels``
    :param axis: either ``"samples"`` or ``"properties"``
    :param grouped_labels: list of :py:class:`Labels`, all with the same names

    :raises TypeError: if ``axis`` is not a string, ``grouped_labels`` is not a
        list, or one of its elements is not a :py:class:`Labels` (these type checks
        are skipped under TorchScript)
    :raises ValueError: if ``axis`` is invalid, if the Labels in
        ``grouped_labels`` do not share the same names, or if one of these names
        is not a dimension of ``block`` along ``axis``
    """
    # Python-only type checks, skipped when compiled with TorchScript
    if not torch_jit_is_scripting():
        if not isinstance(axis, str):
            raise TypeError(f"axis must be a string, not {type(axis)}")
        if not isinstance(grouped_labels, list):
            raise TypeError(
                f"`grouped_labels` must be a list, not {type(grouped_labels)}"
            )
        for candidate in grouped_labels:
            if not check_isinstance(candidate, Labels):
                raise TypeError(
                    "`grouped_labels` elements must be metatensor Labels, "
                    f"not {type(candidate)}"
                )

    if axis != "samples" and axis != "properties":
        raise ValueError("axis must be either 'samples' or 'properties'")

    # nothing more to validate for an empty list of groups
    if len(grouped_labels) == 0:
        return

    # all groups must use exactly the same dimension names as the first one
    reference_names = grouped_labels[0].names
    for group_i in range(1, len(grouped_labels)):
        other_names = grouped_labels[group_i].names
        if other_names != reference_names:
            raise ValueError(
                "the dimensions names of all Labels in `grouped_labels`"
                f" must be the same, got {reference_names} and {other_names}"
            )

    # and these names must exist on the block along the requested axis
    if axis == "samples":
        axis_names = block.samples.names
    else:
        axis_names = block.properties.names
    for name in reference_names:
        if name not in axis_names:
            raise ValueError(
                f"the '{name}' dimension name in `grouped_labels` is not part of "
                f"the {axis} names of the input tensor"
            )