from typing import List, Optional, Union
from . import _dispatch
from ._backend import (
Labels,
TensorBlock,
TensorMap,
isinstance_metatensor,
torch_jit_is_scripting,
torch_jit_script,
)
from ._utils import (
_check_blocks_raise,
_check_same_gradients_raise,
_check_same_keys_raise,
)
def _unique_str(str_list: List[str]):
unique_strings: List[str] = []
for string in str_list:
if string not in unique_strings:
unique_strings.append(string)
return unique_strings
def _tensors_intersection(tensors: List[TensorMap]) -> List[TensorMap]:
    """
    Restrict every tensor in ``tensors`` to the keys shared by all of them.

    Blocks whose key is missing from at least one of the tensors are dropped. Every
    kept block is ``copy()``-ed into the new tensors.

    :param tensors: tensors to restrict; assumed non-empty since ``tensors[0]`` is
        accessed unconditionally
    :return: one new :py:class:`TensorMap` per input tensor, all using the same
        (intersected) keys
    """
    # keys present in every tensor
    common_keys = tensors[0].keys
    for other in tensors[1:]:
        common_keys = common_keys.intersection(other.keys)

    n_common = common_keys.values.shape[0]
    restricted: List[TensorMap] = []
    for tensor in tensors:
        kept_blocks: List[TensorBlock] = []
        for position in range(n_common):
            entry = common_keys.entry(position)
            kept_blocks.append(tensor.block(entry).copy())
        restricted.append(TensorMap(keys=common_keys, blocks=kept_blocks))
    return restricted
def _empty_block_like(reference_block: TensorBlock, axis: str) -> TensorBlock:
    """
    Create a block with the same metadata as ``reference_block``, but with no entries
    along ``axis``.

    :param reference_block: block used as a template: its arrays are passed to
        ``_dispatch.empty_like``, and its label names, components and gradients are
        reused for the new block
    :param axis: either ``"samples"`` (the new block has zero samples) or
        ``"properties"`` (the new block has zero properties)
    :return: the new empty block. Each gradient of ``reference_block`` is reproduced
        with zero samples (``axis="samples"``) or zero properties
        (``axis="properties"``).
    :raises NotImplementedError: if ``reference_block`` has gradients of gradients
    """
    if axis == "samples":
        values = _dispatch.empty_like(
            array=reference_block.values,
            shape=(0,) + reference_block.values.shape[1:],
        )
        samples = Labels(
            names=reference_block.samples.names,
            values=_dispatch.empty_like(
                reference_block.samples.values,
                (0, len(reference_block.samples.names)),
            ),
        )
        properties = reference_block.properties
    else:
        assert axis == "properties"
        values = _dispatch.empty_like(
            array=reference_block.values,
            shape=reference_block.values.shape[:-1] + (0,),
        )
        samples = reference_block.samples
        properties = Labels(
            names=reference_block.properties.names,
            values=_dispatch.empty_like(
                reference_block.properties.values,
                (0, len(reference_block.properties.names)),
            ),
        )

    new_block = TensorBlock(
        values=values,
        samples=samples,
        components=reference_block.components,
        properties=properties,
    )

    for parameter, gradient in reference_block.gradients():
        if len(gradient.gradients_list()) != 0:
            raise NotImplementedError("gradients of gradients are not supported")

        if axis == "samples":
            values = _dispatch.empty_like(
                array=gradient.values,
                shape=(0,) + gradient.values.shape[1:],
            )
            gradient_samples = Labels(
                names=gradient.samples.names,
                values=_dispatch.empty_like(
                    gradient.samples.values, (0, len(gradient.samples.names))
                ),
            )
        else:
            values = _dispatch.empty_like(
                array=gradient.values,
                shape=gradient.values.shape[:-1] + (0,),
            )
            # the block keeps its samples, so the gradient samples stay valid
            gradient_samples = gradient.samples

        new_block.add_gradient(
            parameter=parameter,
            gradient=TensorBlock(
                values=values,
                samples=gradient_samples,
                components=gradient.components,
                # gradient properties must match the (possibly empty) block ones
                properties=properties,
            ),
        )

    return new_block


def _tensors_union(tensors: List[TensorMap], axis: str) -> List[TensorMap]:
    """
    Create a new list of tensors whose keys are the union of the keys of all tensors.

    For each tensor, every key it is missing is filled with an empty block that
    contains no labels along the ``axis`` dimension. The metadata of that empty block
    is taken from the first tensor (in ``tensors`` order) containing the key.

    :param tensors: tensors to complete; assumed non-empty since ``tensors[0]`` is
        accessed unconditionally
    :param axis: either ``"samples"`` or ``"properties"``
    :return: one new :py:class:`TensorMap` per input tensor, containing copies of
        the original blocks followed by the empty blocks for the missing keys
    :raises NotImplementedError: if a reference block has gradients of gradients
    """
    # Construct a Labels object with all keys
    all_keys = tensors[0].keys
    for tensor in tensors[1:]:
        all_keys = all_keys.union(tensor.keys)

    new_tensors: List[TensorMap] = []
    for tensor in tensors:
        # entries of `all_keys` without a match in `tensor.keys` are marked with -1
        # in `mapping`, and are the keys we need to fill with empty blocks
        _, mapping, _ = all_keys.intersection_and_mapping(tensor.keys)
        missing_keys = Labels(
            names=tensor.keys.names, values=all_keys.values[mapping == -1]
        )

        new_keys = tensor.keys.union(missing_keys)
        new_blocks = [block.copy() for block in tensor.blocks()]

        for i_key in range(missing_keys.values.shape[0]):
            key = missing_keys.entry(i_key)

            # Find a block with the missing key to use as a metadata template
            reference_block: Union[None, TensorBlock] = None
            for reference_tensor in tensors:
                if key in reference_tensor.keys:
                    reference_block = reference_tensor.block(key)
                    break

            # `key` comes from the union of all keys, so some tensor must contain it
            assert reference_block is not None

            new_blocks.append(_empty_block_like(reference_block, axis))

        new_tensors.append(TensorMap(keys=new_keys, blocks=new_blocks))

    return new_tensors
def _join_block_samples(
    blocks: List[TensorBlock],
    fname: str,
    add_dimension: Optional[str],
) -> TensorBlock:
    """
    Actual implementation of block joining along samples.

    The values of all blocks are concatenated along the first (samples) axis, in the
    order of ``blocks``. Gradients are concatenated the same way, and the first
    (``"sample"``) column of each gradient's samples is shifted so that it keeps
    pointing at the right row of the joined values.

    :param blocks: blocks to join. Their components/properties and their gradients
        are checked for compatibility with ``_check_blocks_raise`` and
        ``_check_same_gradients_raise``.
    :param fname: name of the function calling ``_join_block_samples``, to be used in
        error messages
    :param add_dimension: name of the dimension to add to the samples. If given, the
        new last column of the samples contains the position in ``blocks`` of the
        block each sample comes from.
    :return: the joined block. When there is a single block and ``add_dimension`` is
        ``None``, that block is returned as-is.
    :raises NotImplementedError: if any gradient has gradients of its own
    """
    assert len(blocks) != 0
    # only short-circuit when there is nothing to do; with `add_dimension` the single
    # block still needs the extra samples dimension
    if len(blocks) == 1 and add_dimension is None:
        return blocks[0]

    first_block = blocks[0]

    n_joined_samples = 0
    for block in blocks:
        n_joined_samples += block.values.shape[0]
        _check_blocks_raise(first_block, block, fname, ["components", "properties"])
        _check_same_gradients_raise(first_block, block, fname, ["components"])

    shape = list(first_block.values.shape)
    shape[0] = n_joined_samples
    new_values = _dispatch.empty_like(first_block.values, shape=shape)

    first_sample_size = first_block.samples.values.shape[1]
    if add_dimension is None:
        samples_size = first_sample_size
    else:
        samples_size = first_sample_size + 1

    new_samples = _dispatch.empty_like(
        first_block.samples.values,
        shape=[n_joined_samples, samples_size],
    )

    start = 0
    for block_i, block in enumerate(blocks):
        stop = start + len(block.samples)
        new_values[start:stop] = block.values[:]
        new_samples[start:stop, :first_sample_size] = block.samples.values[:, :]
        if add_dimension is not None:
            new_samples[start:stop, first_sample_size] = block_i
        start = stop

    # Build a new list rather than calling `append` on the names: if `names` hands
    # out the Labels' internal list, appending would modify the input block metadata
    new_samples_names = first_block.samples.names
    if add_dimension is not None:
        new_samples_names = new_samples_names + [add_dimension]

    new_block = TensorBlock(
        new_values,
        Labels(new_samples_names, new_samples),
        first_block.components,
        first_block.properties,
    )

    for parameter in first_block.gradients_list():
        first_gradient = first_block.gradient(parameter)

        gradients: List[TensorBlock] = []
        n_gradient_samples = 0
        for block in blocks:
            gradient = block.gradient(parameter)
            if len(gradient.gradients_list()) != 0:
                raise NotImplementedError(
                    "gradients of gradients are not yet supported"
                )
            n_gradient_samples += len(gradient.samples)
            gradients.append(gradient)

        shape = list(first_gradient.values.shape)
        shape[0] = n_gradient_samples
        new_values = _dispatch.empty_like(first_gradient.values, shape=shape)
        new_samples = _dispatch.empty_like(
            first_gradient.samples.values,
            shape=[n_gradient_samples, first_gradient.samples.values.shape[1]],
        )

        start = 0
        sample_shift = 0
        for block, gradient in zip(blocks, gradients):
            stop = start + len(gradient.samples)
            new_values[start:stop] = gradient.values[:]
            new_samples[start:stop] = gradient.samples.values[:]
            # update the "sample" dimension, matching the shift in the values
            new_samples[start:stop, 0] += sample_shift
            sample_shift += len(block.samples)
            start = stop

        new_gradient = TensorBlock(
            new_values,
            Labels(first_gradient.samples.names, new_samples),
            first_gradient.components,
            first_gradient.properties,
        )
        new_block.add_gradient(parameter, new_gradient)

    return new_block
def _join_block_properties(
    blocks: List[TensorBlock],
    fname: str,
    add_dimension: Optional[str],
) -> TensorBlock:
    """
    Actual implementation of block joining along properties.

    The values of all blocks are concatenated along the last (properties) axis, in
    the order of ``blocks``. If all blocks share the same property names, the joined
    properties keep these names (plus ``add_dimension`` if given). Otherwise, the
    joined properties are named ``["joined_index", "property"]`` and contain the
    position of the block in ``blocks`` and the index of the property inside that
    block.

    The gradient samples of the joined block are the union of the gradient samples of
    all blocks. Gradient values missing from a block are filled with zeros.

    :param blocks: blocks to join. Their samples/components and their gradients are
        checked for compatibility with ``_check_blocks_raise`` and
        ``_check_same_gradients_raise``.
    :param fname: name of the function calling ``_join_block_properties``, to be used
        in error messages
    :param add_dimension: name of the dimension to add to the properties. If given,
        the new last column of the properties contains the position in ``blocks`` of
        the block each property comes from.
    :return: the joined block. When there is a single block and ``add_dimension`` is
        ``None``, that block is returned as-is.
    :raises ValueError: if ``add_dimension`` is given while the blocks have different
        property names
    :raises NotImplementedError: if any gradient has gradients of its own
    """
    assert len(blocks) != 0
    # only short-circuit when there is nothing to do; with `add_dimension` the single
    # block still needs the extra properties dimension
    if len(blocks) == 1 and add_dimension is None:
        return blocks[0]

    first_block = blocks[0]
    property_names = first_block.properties.names
    has_different_property_names = False

    n_joined_properties = 0
    for block in blocks:
        n_joined_properties += block.values.shape[-1]
        _check_blocks_raise(first_block, block, fname, ["samples", "components"])
        _check_same_gradients_raise(first_block, block, fname, ["components"])
        if block.properties.names != property_names:
            has_different_property_names = True

    shape = list(first_block.values.shape)
    shape[-1] = n_joined_properties
    new_values = _dispatch.empty_like(first_block.values, shape=shape)

    first_properties_size = first_block.properties.values.shape[1]
    if add_dimension is None:
        properties_size = first_properties_size
    else:
        properties_size = first_properties_size + 1

    if has_different_property_names:
        if add_dimension is not None:
            raise ValueError(
                "We can not add an extra dimension to properties when the inputs "
                "have different property names"
            )
        # columns are ("joined_index", "property")
        new_properties_values = _dispatch.empty_like(
            first_block.samples.values, shape=[n_joined_properties, 2]
        )
    else:
        new_properties_values = _dispatch.empty_like(
            first_block.samples.values, shape=[n_joined_properties, properties_size]
        )

    start = 0
    for block_i, block in enumerate(blocks):
        stop = start + len(block.properties)
        new_values[..., start:stop] = block.values[..., :]

        if has_different_property_names:
            new_properties_values[start:stop, 0] = block_i
            new_properties_values[start:stop, 1] = _dispatch.int_array_like(
                list(range(len(block.properties))), block.properties.values
            )
        else:
            new_properties_values[start:stop, :first_properties_size] = (
                block.properties.values[:]
            )
            if add_dimension is not None:
                new_properties_values[start:stop, first_properties_size] = block_i

        start = stop

    # finalize the new properties
    if has_different_property_names:
        new_properties = Labels(
            names=["joined_index", "property"],
            values=new_properties_values,
        )
    else:
        # Build a new list rather than calling `append` on the names: if `names`
        # hands out the Labels' internal list, appending would modify the input block
        new_properties_names = first_block.properties.names
        if add_dimension is not None:
            new_properties_names = new_properties_names + [add_dimension]
        new_properties = Labels(new_properties_names, new_properties_values)

    new_block = TensorBlock(
        new_values,
        first_block.samples,
        first_block.components,
        new_properties,
    )

    for parameter in first_block.gradients_list():
        first_gradient = first_block.gradient(parameter)

        gradients: List[TensorBlock] = []
        joined_gradient_samples = first_gradient.samples
        for block in blocks:
            gradient = block.gradient(parameter)
            if len(gradient.gradients_list()) != 0:
                raise NotImplementedError(
                    "gradients of gradients are not yet supported"
                )
            joined_gradient_samples = joined_gradient_samples.union(gradient.samples)
            gradients.append(gradient)

        shape = list(first_gradient.values.shape)
        shape[0] = len(joined_gradient_samples)
        shape[-1] = len(new_properties)

        # we need to use `zeros_like` instead of `empty_like`, because some
        # gradients might be missing (i.e. implicitly zero) in some input blocks
        new_values = _dispatch.zeros_like(first_gradient.values, shape=shape)

        start = 0
        for gradient in gradients:
            stop = start + len(gradient.properties)

            # find where we should put the current gradients in the joined samples
            # we can not get the mapping in the first loop over gradients above since
            # `joined_gradient_samples` could still change
            _, _, mapping = joined_gradient_samples.union_and_mapping(gradient.samples)
            new_values[mapping, ..., start:stop] = gradient.values

            start = stop

        new_gradient = TensorBlock(
            new_values,
            joined_gradient_samples,
            first_gradient.components,
            new_properties,
        )
        new_block.add_gradient(parameter, new_gradient)

    return new_block
[docs]
@torch_jit_script
def join(
    tensors: List[TensorMap],
    axis: str,
    different_keys: str = "error",
    add_dimension: Optional[str] = None,
) -> TensorMap:
    """Join a sequence of :py:class:`TensorMap` with similar keys along an axis.

    The ``axis`` parameter specifies the type of joining: with ``axis='properties'`` the
    ``tensors`` will be joined along the ``properties`` and for ``axis='samples'`` they
    will be joined along the ``samples``.

    :param tensors: sequence of :py:class:`TensorMap` to join
    :param axis: Along which axis the :py:class:`TensorMap`s should be joined. This can
        be either ``'properties'`` or ``'samples'``.
    :param different_keys: Method to handle different keys between the tensors. For
        ``"error"`` keys in all tensors have to be the same. For ``"intersection"`` only
        blocks present in all tensors will be taken into account. For ``"union"``
        missing keys will be treated as if they were associated with an empty block.
    :param add_dimension: Add an extra dimension to the joined labels with the given
        name. See examples for the case where this is applicable. The dimension forms
        the last dimension of the joined labels.

    :return: The joined :py:class:`TensorMap` with more properties or samples than the
        inputs.

    :raises ValueError: if ``tensors`` is empty, if ``axis`` or ``different_keys`` is
        not one of the accepted values, or when joining along samples tensors whose
        sample names are not identical (same names in the same order)

    Examples
    --------
    The first use case for this function is when joining ``TensorMap`` with the same
    labels names (either along ``samples`` or ``properties``):

    >>> import numpy as np
    >>> import metatensor as mts
    >>> from metatensor import Labels, TensorBlock, TensorMap
    >>> values = np.array([[1.1, 2.1, 3.1]])
    >>> samples = Labels("sample", np.array([[0]]))

    Define two disjoint sets of :py:class:`Labels`.

    >>> properties_1 = Labels("n", np.array([[0], [2], [3]]))
    >>> properties_2 = Labels("n", np.array([[1], [4], [5]]))
    >>> block_1 = TensorBlock(
    ...     values=values,
    ...     samples=Labels.single(),
    ...     components=[],
    ...     properties=properties_1,
    ... )
    >>> block_2 = TensorBlock(
    ...     values=values,
    ...     samples=Labels.single(),
    ...     components=[],
    ...     properties=properties_2,
    ... )
    >>> tensor_1 = TensorMap(keys=Labels.single(), blocks=[block_1])
    >>> tensor_2 = TensorMap(keys=Labels.single(), blocks=[block_2])

    Joining along the properties leads to

    >>> joined_tensor = mts.join([tensor_1, tensor_2], axis="properties")
    >>> joined_tensor[0].properties
    Labels(
        n
        0
        2
        3
        1
        4
        5
    )

    Second, if the labels names are the same but the values are not unique, you can ask
    to add an extra dimension to the labels when joining with ``add_dimension``, thus
    creating unique values

    >>> properties_3 = Labels("n", np.array([[0], [2], [3]]))

    ``properties_3`` has the same name and also shares values with ``properties_1`` as
    defined above.

    >>> block_3 = TensorBlock(
    ...     values=values,
    ...     samples=Labels.single(),
    ...     components=[],
    ...     properties=properties_3,
    ... )
    >>> tensor_3 = TensorMap(keys=Labels.single(), blocks=[block_3])

    Joining along properties leads to

    >>> joined_tensor = mts.join(
    ...     [tensor_1, tensor_3], axis="properties", add_dimension="tensor"
    ... )
    >>> joined_tensor[0].properties
    Labels(
        n  tensor
        0    0
        2    0
        3    0
        0    1
        2    1
        3    1
    )

    Finally, when joining along properties, if different ``TensorMap`` have different
    property names, we'll re-create new properties labels containing the original tensor
    index and the corresponding property index. This does not apply when joining along
    samples.

    >>> properties_4 = Labels(["a", "b"], np.array([[0, 0], [1, 2], [1, 3]]))

    ``properties_4`` has different names compared to ``properties_1`` defined above.

    >>> block_4 = TensorBlock(
    ...     values=values,
    ...     samples=Labels.single(),
    ...     components=[],
    ...     properties=properties_4,
    ... )
    >>> tensor_4 = TensorMap(keys=Labels.single(), blocks=[block_4])

    Joining along properties leads to

    >>> joined_tensor = mts.join([tensor_1, tensor_4], axis="properties")
    >>> joined_tensor[0].properties
    Labels(
        joined_index  property
             0           0
             0           1
             0           2
             1           0
             1           1
             1           2
    )
    """
    if not torch_jit_is_scripting():
        if not isinstance(tensors, (list, tuple)):
            raise TypeError(f"`tensors` must be a list or a tuple, not {type(tensors)}")

        for tensor in tensors:
            if not isinstance_metatensor(tensor, "TensorMap"):
                raise TypeError(
                    "`tensors` elements must be metatensor TensorMap, "
                    f"not {type(tensor)}"
                )

    if len(tensors) == 0:
        raise ValueError("provide at least one `TensorMap` for joining")

    if axis not in ("samples", "properties"):
        raise ValueError(
            "Only `'properties'` or `'samples'` are "
            "valid values for the `axis` parameter."
        )

    # validate before the early return below, so invalid options are always reported
    if different_keys not in ("error", "intersection", "union"):
        raise ValueError(
            f"'{different_keys}' is not a valid option for `different_keys`. Choose "
            "either 'error', 'intersection' or 'union'."
        )

    # nothing to join; with `add_dimension` the labels still need the new dimension
    if len(tensors) == 1 and add_dimension is None:
        return tensors[0]

    if different_keys == "error":
        for ts_to_join in tensors[1:]:
            _check_same_keys_raise(tensors[0], ts_to_join, "join")
    elif different_keys == "intersection":
        tensors = _tensors_intersection(tensors)
    else:
        # different_keys == "union"
        tensors = _tensors_union(tensors, axis=axis)

    if axis == "samples":
        # Sample values are copied column by column when joining, so every tensor
        # must use exactly the same sample names in the same order; otherwise values
        # would silently end up under the wrong dimension name.
        first_sample_names = tensors[0].sample_names
        for tensor in tensors[1:]:
            if tensor.sample_names != first_sample_names:
                # It's fine to lose metadata for the properties, less so for the samples
                raise ValueError(
                    "Different tensor have different sample names in `join`. "
                    "Joining along samples with different sample names will lose "
                    "information and is not supported."
                )

    keys = tensors[0].keys
    blocks: List[TensorBlock] = []
    for i in range(len(keys)):
        key = keys[i]
        blocks_to_join: List[TensorBlock] = []
        for tensor in tensors:
            blocks_to_join.append(tensor.block(key))

        if axis == "samples":
            blocks.append(_join_block_samples(blocks_to_join, "join", add_dimension))
        else:
            blocks.append(_join_block_properties(blocks_to_join, "join", add_dimension))

    return TensorMap(keys=keys, blocks=blocks)
[docs]
@torch_jit_script
def join_blocks(
    blocks: List[TensorBlock],
    axis: str,
    add_dimension: Optional[str] = None,
) -> TensorBlock:
    """Join a sequence of :py:class:`TensorBlock` along an axis.

    The ``axis`` parameter specifies the type of joining: with ``axis='properties'`` the
    ``blocks`` will be joined along the ``properties`` and for ``axis='samples'`` they
    will be joined along the ``samples``.

    :param blocks: sequence of :py:class:`TensorBlock` to join
    :param axis: Along which axis the blocks should be joined. This can be either
        ``'properties'`` or ``'samples'``.
    :param add_dimension: Add an extra dimension to the joined labels with the given
        name. The dimension forms the last dimension of the joined labels.

    :return: The joined :py:class:`TensorBlock` with more properties or samples than the
        inputs.

    :raises ValueError: if ``blocks`` is empty, if ``axis`` is not one of the accepted
        values, or when joining along samples blocks whose sample names are not
        identical (same names in the same order)

    .. seealso::

        The examples for :py:func:`join`.
    """
    if not torch_jit_is_scripting():
        if not isinstance(blocks, (list, tuple)):
            raise TypeError(f"`blocks` must be a list or a tuple, not {type(blocks)}")

        for block in blocks:
            if not isinstance_metatensor(block, "TensorBlock"):
                raise TypeError(
                    "`blocks` elements must be metatensor TensorBlock, "
                    f"not {type(block)}"
                )

    if len(blocks) == 0:
        raise ValueError("provide at least one `TensorBlock` for joining")

    if axis not in ("samples", "properties"):
        raise ValueError(
            "Only `'properties'` or `'samples'` are valid values for the `axis` "
            "parameter"
        )

    if axis == "samples":
        # Sample values are copied column by column when joining, so every block
        # must use exactly the same sample names in the same order; otherwise values
        # would silently end up under the wrong dimension name.
        first_sample_names = blocks[0].samples.names
        for block in blocks[1:]:
            if block.samples.names != first_sample_names:
                # It's fine to lose metadata for the properties, less so for the samples!
                raise ValueError(
                    "Different blocks have different sample names in `join_blocks`. "
                    "Joining along samples with different sample names will lose "
                    "information and is not supported."
                )

        return _join_block_samples(blocks, "join_blocks", add_dimension)
    else:
        return _join_block_properties(blocks, "join_blocks", add_dimension)