Source code for metatensor.learn.data.collate

"""
Module containing a collate function for use in a Dataloader.
"""

from typing import List, NamedTuple, Optional, Union

import torch

import metatensor
from metatensor import TensorMap

from ._namedtuple import namedtuple


[docs] def group(batch: List[NamedTuple]) -> NamedTuple: """ Collates a minibatch by grouping the data for each data field and returning a named tuple. `batch` is a list of named tuples. Each is composed of a number of named data fields, which are arbitrary objects, such as :py:class:`torch.Tensor`, :py:class:`atomistic.Systems`, or :py:class:`TensorMap`. Returned is a new named tuple with the same named data fields as the each sample in the `batch`, but with the sample data collated for each respective field. The indices of the samples in the minibatch belong to the first field unpacked from the named tuple, i.e. "sample_indices". :param batch: list of named tuples for each sample in the minibatch. :return: a named tuple, with the named fields the same as in the original samples in the batch, but with the samples grouped by data field. """ return namedtuple("Batch", batch[0]._fields)(*list(zip(*batch)))
[docs] def group_and_join( batch: List[NamedTuple], fields_to_join: Optional[List[str]] = None, join_kwargs: Optional[dict] = None, ) -> NamedTuple: """ Collates a minibatch by grouping the data for each fields, joining tensors along the samples axis, and returning a named tuple. Similar in functionality to the generic :py:func:`group`, but instead data fields that are :py:class:`torch.Tensor` objects are vertically stacked, and :py:class:`TensorMap` objects are joined along the samples axis. `batch` is a list of named tuples. Each has a number of fields that correspond to different named data fields. These data fields can be arbitrary objects, such as :py:class:`torch.Tensor`, :py:class:`atomistic.Systems`, or :py:class:`TensorMap`. For each data field, the data object from each sample in the batch is collated into a list, except where the data field is a list of :py:class:`torch.Tensor` or :py:class:`TensorMap` objects. In this case, the tensors are joined along the samples axis. If torch tensors, all must be of the same size. In the case of TensorMaps, the union of sparse keys are taken. Returned is a new named tuple with the same fields as the each sample in the `batch`, but with the sample data collated for each respective field. The sample indices in the minibatch are in the first field of the named tuple under "sample_indices". :param batch: list of named tuples for each sample in the batch. :param fields_to_join: list of data field names to join. If None, all fields that can be joined are joined, i.e. those comprised of :py:class:`torch.Tensor` or :py:class:`TensorMap` objects. Any names passed that are either invalid or are names of fields that aren't these types will be silently ignored. :param join_kwargs: keyword arguments passed to the :py:func:`metatensor.join` function, to be used when joining data fields comprised of :py:class:`TensorMap` objects. If none, the defaults are used - see the function documentation for details. The `axis="samples"` arg is set by default. :return: a named tuple, with the named fields the same as in the original samples in the batch, but with the samples collated for each respective field. If the data fields are :py:class:`torch.Tensor` or :py:class:`TensorMap` objects, they are joined along the samples axis. """ data: List[Union[TensorMap, torch.Tensor]] = [] names = batch[0]._fields if fields_to_join is None: fields_to_join = names if join_kwargs is None: join_kwargs = {} for name, field in zip(names, list(zip(*batch))): if name == "sample_id": # special case, keep as is data.append(field) continue if name in fields_to_join: # Join tensors if requested if isinstance(field[0], TensorMap): # metatensor.TensorMap type data.append(metatensor.join(field, axis="samples", **join_kwargs)) elif isinstance(field[0], torch.ScriptObject) and field[0]._has_method( "keys_to_properties" ): # inferred metatensor.torch.TensorMap type data.append(metatensor.torch.join(field, axis="samples", **join_kwargs)) elif isinstance(field[0], torch.Tensor): # torch.Tensor type data.append(torch.vstack(field)) else: data.append(field) else: # otherwise just keep as a list data.append(field) return namedtuple("Batch", names)(*data)