from typing import List, Union
from . import _dispatch
from ._classes import (
    Labels,
    TensorBlock,
    TensorMap,
    check_isinstance,
    torch_jit_is_scripting,
)
from ._utils import _check_same_keys_raise
from .manipulate_dimension import remove_dimension
def _disjoint_tensor_labels(tensors: List[TensorMap], axis: str) -> bool:
    """Check whether the ``axis`` labels of all tensors are pairwise disjoint.

    Every tensor is compared with every tensor that comes after it in
    ``tensors``. For each key of the first tensor of a pair, the ``samples`` (or
    ``properties``) of its block are intersected with those of the block stored
    under the same key in the second tensor. A single shared label is enough to
    make the labels non-disjoint.

    :param tensors: tensors to compare; each key of a tensor is looked up in all
        the tensors following it
    :param axis: either ``"samples"`` or ``"properties"``
    :return: ``True`` if no label is shared between any pair of tensors,
        ``False`` otherwise
    :raises ValueError: if ``axis`` is not ``"samples"`` or ``"properties"``.
        This is only detected once a pair of blocks is actually compared.
    """
    n_tensors = len(tensors)
    for i_first in range(n_tensors - 1):
        first_tensor = tensors[i_first]
        for i_second in range(i_first + 1, n_tensors):
            second_tensor = tensors[i_second]
            for key, first_block in first_tensor.items():
                second_block = second_tensor.block(key)

                if axis == "properties":
                    common = first_block.properties.intersection(
                        second_block.properties
                    )
                elif axis == "samples":
                    common = first_block.samples.intersection(second_block.samples)
                else:
                    raise ValueError(
                        "Only `'properties'` or `'samples'` are "
                        "valid values for the `axis` parameter."
                    )

                if len(common) > 0:
                    # at least one label appears in both blocks
                    return False
    return True
def _unique_str(str_list: List[str]):
    unique_strings: List[str] = []
    for string in str_list:
        if string not in unique_strings:
            unique_strings.append(string)
    return unique_strings
def join(
    tensors: List[TensorMap],
    axis: str,
    different_keys: str = "error",
    remove_tensor_name: bool = False,
) -> TensorMap:
    """Join a sequence of :py:class:`TensorMap` with the same blocks along an axis.

    The ``axis`` parameter specifies the type of joining. For example, if
    ``axis='properties'`` the tensor maps in ``tensors`` will be joined along the
    ``properties`` dimension and for ``axis='samples'`` they will be joined along
    the ``samples`` dimension.

    If ``tensors`` contains a single :py:class:`TensorMap`, it is returned as-is.

    :param tensors:
        sequence of :py:class:`TensorMap` to join
    :param axis:
        A string indicating how the tensormaps are stacked. Allowed
        values are ``'properties'`` or ``'samples'``.
    :param different_keys: Method to handle different keys between the tensors. For
        ``"error"`` keys in all tensors have to be the same. For ``"intersection"``
        only blocks present in all tensors will be taken into account. For
        ``"union"`` missing keys will be treated as if they were associated with an
        empty block.
    :param remove_tensor_name:
        Remove the extra ``tensor`` dimension from labels if possible. See the
        examples below for the case where this is applicable.

    :return tensor_joined:
        The stacked :py:class:`TensorMap` with more properties or samples
        than the input TensorMap.

    :raises TypeError: if ``tensors`` is not a list or a tuple of
        :py:class:`TensorMap` (only checked when not running under TorchScript)
    :raises ValueError: if ``tensors`` is empty, if ``axis`` is not valid, if
        ``different_keys`` is not valid (only checked when more than one tensor is
        given), or if ``axis='samples'`` and the sample names (including their
        order) differ between the tensors
    :raises NotImplementedError: if a block has gradients of gradients

    Examples
    --------
    Possible clashes of the meta data like ``samples``/``properties`` will be resolved
    by one of the three following strategies:

    1. If Labels names are the same, the values are unique and
       ``remove_tensor_name=True`` we keep the names and join the values

       >>> import numpy as np
       >>> import metatensor
       >>> from metatensor import Labels, TensorBlock, TensorMap

       >>> values = np.array([[1.1, 2.1, 3.1]])
       >>> samples = Labels("sample", np.array([[0]]))

       Define two disjoint :py:class:`Labels`.

       >>> properties_1 = Labels("n", np.array([[0], [2], [3]]))
       >>> properties_2 = Labels("n", np.array([[1], [4], [5]]))

       >>> block_1 = TensorBlock(
       ...     values=values,
       ...     samples=Labels.single(),
       ...     components=[],
       ...     properties=properties_1,
       ... )
       >>> block_2 = TensorBlock(
       ...     values=values,
       ...     samples=Labels.single(),
       ...     components=[],
       ...     properties=properties_2,
       ... )

       >>> tensor_1 = TensorMap(keys=Labels.single(), blocks=[block_1])
       >>> tensor_2 = TensorMap(keys=Labels.single(), blocks=[block_2])

       Joining along the properties leads to

       >>> joined_tensor = metatensor.join(
       ...     [tensor_1, tensor_2], axis="properties", remove_tensor_name=True
       ... )
       >>> joined_tensor[0].properties
       Labels(
           n
           0
           2
           3
           1
           4
           5
       )

       If ``remove_tensor_name=False`` there will be an extra dimension ``tensor``
       added

       >>> joined_tensor = metatensor.join(
       ...     [tensor_1, tensor_2], axis="properties", remove_tensor_name=False
       ... )
       >>> joined_tensor[0].properties
       Labels(
           tensor  n
             0     0
             0     2
             0     3
             1     1
             1     4
             1     5
       )

    2. If Labels names are the same but the values are not unique, a new dimension
       ``"tensor"`` is added to the names.

       >>> properties_3 = Labels("n", np.array([[0], [2], [3]]))

       ``properties_3`` has the same name and also shares values with ``properties_1``
       as defined above.

       >>> block_3 = TensorBlock(
       ...     values=values,
       ...     samples=Labels.single(),
       ...     components=[],
       ...     properties=properties_3,
       ... )
       >>> tensor_3 = TensorMap(keys=Labels.single(), blocks=[block_3])

       Joining along properties leads to

       >>> joined_tensor = metatensor.join([tensor_1, tensor_3], axis="properties")
       >>> joined_tensor[0].properties
       Labels(
           tensor  n
             0     0
             0     2
             0     3
             1     0
             1     2
             1     3
       )

    3. If Labels names are different (names are compared including their order) we
       change the names to ("tensor", "property"). This case is only supposed to
       happen when joining in the property dimension, hence the choice of names:

       >>> properties_4 = Labels(["a", "b"], np.array([[0, 0], [1, 2], [1, 3]]))

       ``properties_4`` has different names compared to ``properties_1``
       defined above.

       >>> block_4 = TensorBlock(
       ...     values=values,
       ...     samples=Labels.single(),
       ...     components=[],
       ...     properties=properties_4,
       ... )
       >>> tensor_4 = TensorMap(keys=Labels.single(), blocks=[block_4])

       Joining along properties leads to

       >>> joined_tensor = metatensor.join([tensor_1, tensor_4], axis="properties")
       >>> joined_tensor[0].properties
       Labels(
           tensor  property
             0        0
             0        1
             0        2
             1        0
             1        1
             1        2
       )
    """
    if not torch_jit_is_scripting():
        # Dynamic type checks are only done in Python mode (presumably because
        # TorchScript already enforces the annotated types).
        if not isinstance(tensors, (list, tuple)):
            raise TypeError(f"`tensor` must be a list or a tuple, not {type(tensors)}")

        for tensor in tensors:
            if not check_isinstance(tensor, TensorMap):
                raise TypeError(
                    "`tensors` elements must be metatensor TensorMap, "
                    f"not {type(tensor)}"
                )

    if len(tensors) < 1:
        raise ValueError("provide at least one `TensorMap` for joining")

    if axis not in ("samples", "properties"):
        raise ValueError(
            "Only `'properties'` or `'samples'` are "
            "valid values for the `axis` parameter."
        )

    if len(tensors) == 1:
        # nothing to join
        return tensors[0]

    # Check that all tensors share the same keys, or bring them to a common set.
    if different_keys == "error":
        for ts_to_join in tensors[1:]:
            _check_same_keys_raise(tensors[0], ts_to_join, "join")
    elif different_keys == "intersection":
        tensors = _tensors_intersection(tensors)
    elif different_keys == "union":
        tensors = _tensors_union(tensors, axis=axis)
    else:
        raise ValueError(
            f"'{different_keys}' is not a valid option for `different_keys`. Choose "
            "either 'error', 'intersection' or 'union'."
        )

    # Deduce if sample/property names are the same in all tensors. If they are
    # not, the property labels are replaced by a generic "property" range below.
    if axis == "samples":
        names_list = [tensor.sample_names for tensor in tensors]
    else:
        names_list = [tensor.property_names for tensor in tensors]

    # Compare the full, *ordered* list of names of every tensor with the first one.
    # Comparing only the number of unique names would treat permutations such as
    # ("a", "b") and ("b", "a") as identical. The blocks would then be handed
    # unchanged to `keys_to_samples` / `keys_to_properties` below with
    # mismatching label names.
    first_names = list(names_list[0])
    names_are_same = True
    for names in names_list[1:]:
        if list(names) != first_names:
            names_are_same = False

    # It's fine to lose metadata on the property axis, less so on the sample axis!
    if axis == "samples" and not names_are_same:
        raise ValueError(
            "Sample names are not the same! Joining along samples with different "
            "sample names will loose information and is not supported."
        )

    # Prefix every key with the index of the tensor it comes from. This lets the
    # blocks of all tensors live in a single TensorMap. Sizing the array from
    # `tensors[0]` assumes all tensors now have the same number of keys; this is
    # presumably ensured by the `different_keys` handling above.
    keys = tensors[0].keys
    n_tensors = len(tensors)
    n_keys_dimensions = 1 + keys.values.shape[1]
    new_keys_values = _dispatch.empty_like(
        array=keys.values,
        shape=[n_tensors, keys.values.shape[0], n_keys_dimensions],
    )
    for i, tensor in enumerate(tensors):
        for j, value in enumerate(tensor.keys.values):
            new_keys_values[i, j, 0] = i
            new_keys_values[i, j, 1:] = value

    keys = Labels(
        names=["tensor"] + keys.names,
        values=new_keys_values.reshape(-1, n_keys_dimensions),
    )

    # Blocks are appended tensor by tensor, in the same order as the keys above.
    blocks: List[TensorBlock] = []
    for tensor in tensors:
        for block in tensor.blocks():
            # We would already have raised an error if `axis == "samples"` and the
            # names differ. Therefore, this relabeling only concerns properties.
            if names_are_same:
                properties = block.properties
            else:
                properties = Labels.range("property", len(block.properties))

            new_block = TensorBlock(
                values=block.values,
                samples=block.samples,
                components=block.components,
                properties=properties,
            )

            for parameter, gradient in block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )

                # gradients share the (possibly relabeled) properties of the block
                new_block.add_gradient(
                    parameter=parameter,
                    gradient=TensorBlock(
                        values=gradient.values,
                        samples=gradient.samples,
                        components=gradient.components,
                        properties=new_block.properties,
                    ),
                )

            blocks.append(new_block)

    stacked = TensorMap(keys=keys, blocks=blocks)

    # Move the "tensor" key dimension to the joined axis. This presumably merges
    # blocks that only differ by their "tensor" key value.
    if axis == "samples":
        tensor_joined = stacked.keys_to_samples("tensor")
    else:
        tensor_joined = stacked.keys_to_properties("tensor")

    if remove_tensor_name and _disjoint_tensor_labels(tensors, axis):
        return remove_dimension(tensor_joined, name="tensor", axis=axis)
    else:
        return tensor_joined
def _tensors_intersection(tensors: List[TensorMap]) -> List[TensorMap]:
    """Restrict every tensor to the keys shared by all ``tensors``.

    Blocks whose key is missing from at least one tensor are discarded. The
    remaining blocks are copied into new TensorMaps that all use the same
    intersected keys, in the same order.

    :param tensors: tensors to restrict to their common keys
    :return: one new TensorMap per input tensor, in the same order
    """
    # Keys that are present in every tensor
    common_keys = tensors[0].keys
    for other in tensors[1:]:
        common_keys = common_keys.intersection(other.keys)

    n_common = common_keys.values.shape[0]
    result: List[TensorMap] = []
    for tensor in tensors:
        # Blocks are copied; presumably the new TensorMap takes ownership of
        # them and the input tensors must stay untouched.
        kept: List[TensorBlock] = []
        for position in range(n_common):
            key = common_keys.entry(position)
            kept.append(tensor.block(key).copy())
        result.append(TensorMap(keys=common_keys, blocks=kept))
    return result
def _tensors_union(tensors: List[TensorMap], axis: str) -> List[TensorMap]:
    """Create a new tensors list where keys are based on the union from all tensors.

    Every returned TensorMap contains all keys present in at least one of the
    input ``tensors``. Blocks of keys that a tensor already has are copied.
    Missing keys are filled with empty blocks that contain no labels in the
    ``axis`` dimension. All other metadata of such an empty block, and of its
    gradients, is taken from the first tensor in ``tensors`` that has a block for
    this key.

    :param tensors: tensors to extend to the union of their keys
    :param axis: ``"samples"`` or ``"properties"``, the dimension along which the
        filler blocks are empty
    :return: one new TensorMap per input tensor, in the same order. The existing
        keys come first, followed by the keys that were missing.
    :raises NotImplementedError: if a reference block has gradients of gradients
    """
    # Construct a Labels object with all keys
    all_keys = tensors[0].keys
    for tensor in tensors[1:]:
        all_keys = all_keys.union(tensor.keys)

    # Create empty blocks for missing keys for each TensorMap
    new_tensors: List[TensorMap] = []
    for tensor in tensors:
        # `mapping` is -1 for the entries of `all_keys` that are not in `tensor.keys`
        # (named `mapping` rather than `map` to avoid shadowing the builtin)
        _, mapping, _ = all_keys.intersection_and_mapping(tensor.keys)
        missing_keys = Labels(
            names=tensor.keys.names, values=all_keys.values[mapping == -1]
        )

        new_keys = tensor.keys.union(missing_keys)
        new_blocks = [block.copy() for block in tensor.blocks()]

        for i_key in range(missing_keys.values.shape[0]):
            key = missing_keys.entry(i_key)

            # Find corresponding block with the missing key
            reference_block: Union[None, TensorBlock] = None
            for reference_tensor in tensors:
                if key in reference_tensor.keys:
                    reference_block = reference_tensor.block(key)
                    break

            # There should be a block with the key otherwise we did something wrong.
            # The assert also narrows the Optional type for TorchScript.
            assert reference_block is not None

            # Construct new block with zero entries along `axis`, based on the
            # metadata of reference_block
            if axis == "samples":
                values = _dispatch.empty_like(
                    array=reference_block.values,
                    shape=(0,) + reference_block.values.shape[1:],
                )
                samples = Labels.empty(reference_block.samples.names)
                properties = reference_block.properties
            else:
                assert axis == "properties"
                values = _dispatch.empty_like(
                    array=reference_block.values,
                    shape=reference_block.values.shape[:-1] + (0,),
                )
                samples = reference_block.samples
                properties = Labels.empty(reference_block.properties.names)

            new_block = TensorBlock(
                values=values,
                samples=samples,
                components=reference_block.components,
                properties=properties,
            )

            # Gradients are emptied along the same axis as the block itself
            for parameter, gradient in reference_block.gradients():
                if len(gradient.gradients_list()) != 0:
                    raise NotImplementedError(
                        "gradients of gradients are not supported"
                    )

                if axis == "samples":
                    values = _dispatch.empty_like(
                        array=gradient.values,
                        shape=(0,) + gradient.values.shape[1:],
                    )
                    gradient_samples = Labels.empty(gradient.samples.names)
                else:
                    values = _dispatch.empty_like(
                        array=gradient.values,
                        shape=gradient.values.shape[:-1] + (0,),
                    )
                    gradient_samples = gradient.samples

                new_block.add_gradient(
                    parameter=parameter,
                    gradient=TensorBlock(
                        values=values,
                        samples=gradient_samples,
                        components=gradient.components,
                        properties=properties,
                    ),
                )

            new_blocks.append(new_block)

        new_tensors.append(TensorMap(keys=new_keys, blocks=new_blocks))

    return new_tensors