Data utilities#

Dataset#

class metatensor.learn.data.Dataset(size: int | None = None, **kwargs)[source]#

Defines a dataset class for various named data fields.

The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field, and the value is either a list of data objects or a callable.

size specifies the number of samples in the dataset. This only needs to be passed if all data fields are passed as callables. Otherwise, the size is inferred from the length of the data fields passed as lists.

Every sample in the dataset is assigned a numeric ID from 0 to size - 1. This ID can be used to access the corresponding sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset.

A data field passed as a list must contain objects of the same type, and its length must be consistent with the size argument (if specified) and with the length of all other data fields passed as lists.

A data field passed as a callable must take a single argument, the numeric sample ID, and return the data object for that sample ID. The data for a given sample is then loaded into memory lazily, only when the Dataset.__getitem__ method is called.

Parameters:
  • size (int | None) – Optional, an integer indicating the size of the dataset, i.e. the number of samples. This only needs to be specified if all data fields are passed as callables.

  • kwargs – List or Callable. Keyword arguments specifying the data fields for the dataset.
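The behaviour described above can be illustrated with a small stand-in written in plain Python. This is only a sketch of the documented semantics (size inference, consistency checks, lazy callables), not the metatensor implementation; the names ToyDataset, load_energy, and loaded are hypothetical and exist only for this example.

```python
from collections import namedtuple


class ToyDataset:
    """Stand-in mimicking the documented behaviour; NOT metatensor's code."""

    def __init__(self, size=None, **fields):
        # All fields passed as lists must have the same length.
        lengths = {len(v) for v in fields.values() if not callable(v)}
        if len(lengths) > 1:
            raise ValueError("list fields have inconsistent lengths")
        if lengths:
            inferred = lengths.pop()
            if size is not None and size != inferred:
                raise ValueError("size disagrees with the list fields")
            size = inferred
        elif size is None:
            raise ValueError("size is required when every field is a callable")
        self._size = size
        self._fields = fields
        self._sample = namedtuple("Sample", list(fields))

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        if not 0 <= idx < self._size:
            raise IndexError(idx)
        # Callables are evaluated only here (lazy loading); lists are indexed.
        return self._sample(
            *(f(idx) if callable(f) else f[idx] for f in self._fields.values())
        )


loaded = []  # records which sample IDs the callable was asked for


def load_energy(sample_id):
    loaded.append(sample_id)
    return 10.0 + sample_id


# One field given as a list, one as a callable; size is inferred from the list.
dataset = ToyDataset(positions=[[0.0], [1.0], [2.0]], energy=load_energy)
first = dataset[0]  # only now is load_energy called, and only for ID 0
```

In this sketch, constructing the dataset never calls load_energy; the call happens per sample on indexing, which is the lazy-loading behaviour the text describes.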

class metatensor.learn.data.IndexedDataset(sample_id: List, **kwargs)[source]#

Defines a dataset class for various named data fields, with samples indexed by a list of unique sample IDs.

The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field, and the value is either a list of data objects or a callable.

sample_id must be a list of unique hashable objects. Each sample ID is assigned an internal numeric index from 0 to len(sample_id) - 1. This index is used internally to index the dataset, and can also be used to access a given sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset, i.e. the one whose unique sample ID is sample_id[0]. To access a sample by its ID, use the IndexedDataset.get_sample method.

A data field passed as a list must contain objects of the same type, and its length must be consistent with the length of sample_id and with the length of all other data fields passed as lists.

A data field passed as a callable must take a single argument, the unique sample ID (i.e. one of those passed in sample_id), and return the data object for that sample ID. The data for a given sample is then loaded into memory lazily, only when the IndexedDataset.__getitem__ or IndexedDataset.get_sample methods are called.

Parameters:
  • sample_id (List) – A list of unique IDs for each sample in the dataset.

  • kwargs – Keyword arguments specifying the data fields for the dataset.

get_sample(sample_id) → NamedTuple[source]#

Returns a named tuple for the sample corresponding to the given sample_id.

Return type:

NamedTuple
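The relationship between numeric indices and unique sample IDs can be sketched in plain Python as follows. This is an illustrative stand-in under the semantics described above, not the metatensor implementation; ToyIndexedDataset and the field names n_atoms and label are hypothetical.

```python
from collections import namedtuple


class ToyIndexedDataset:
    """Stand-in mimicking the documented behaviour; NOT metatensor's code."""

    def __init__(self, sample_id, **fields):
        if len(set(sample_id)) != len(sample_id):
            raise ValueError("sample IDs must be unique")
        for name, value in fields.items():
            if not callable(value) and len(value) != len(sample_id):
                raise ValueError(f"field {name!r} has the wrong length")
        self._ids = list(sample_id)
        # Map each hashable ID to its internal numeric index.
        self._index_of = {sid: i for i, sid in enumerate(self._ids)}
        self._fields = fields
        self._sample = namedtuple("Sample", list(fields))

    def __len__(self):
        return len(self._ids)

    def __getitem__(self, idx):
        sid = self._ids[idx]
        # Callables receive the unique sample ID; lists use the numeric index.
        return self._sample(
            *(f(sid) if callable(f) else f[idx] for f in self._fields.values())
        )

    def get_sample(self, sample_id):
        # Translate the unique ID to the numeric index, then reuse __getitem__.
        return self[self._index_of[sample_id]]


dataset = ToyIndexedDataset(
    sample_id=["water", "methane", "ammonia"],
    n_atoms=[3, 5, 4],                # list field, aligned with sample_id
    label=lambda sid: sid.upper(),    # callable field, evaluated lazily
)
by_index = dataset[1]
by_id = dataset.get_sample("methane")
```

Both access paths return the same named tuple here, because the sample with ID "methane" sits at numeric index 1.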

Dataloader#

class metatensor.learn.data.DataLoader(dataset: torch.utils.data.Dataset, collate_fn: Callable = <function group_and_join>, **kwargs)[source]#

Class for loading data from a torch.utils.data.Dataset object, with a default collate function that supports torch.Tensor, atomistic.Systems, or TensorMap.

Any argument accepted by the parent class torch.utils.data.DataLoader is supported. Please refer to https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader for details.
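To make the role of collate_fn concrete, here is a minimal plain-Python sketch of a loader that cuts a dataset into minibatches and passes each one to a collate function. It does not use torch and is not how DataLoader is implemented; toy_loader and collate_by_field are hypothetical names used only for illustration.

```python
from collections import namedtuple


def toy_loader(dataset, batch_size, collate_fn):
    """Yield collated minibatches (illustrative stand-in for a data loader)."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        # The collate function receives a list of samples (named tuples).
        yield collate_fn([dataset[i] for i in range(start, stop)])


def collate_by_field(batch):
    # Grouping collate: one tuple of values per named field.
    return type(batch[0])(*zip(*batch))


Sample = namedtuple("Sample", ["x", "y"])
samples = [Sample(x=i, y=i * i) for i in range(5)]  # a list acts as the dataset
batches = list(toy_loader(samples, batch_size=2, collate_fn=collate_by_field))
```

With five samples and a batch size of two, the sketch yields three minibatches, the last one holding a single sample.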

Parameters:

Collating data#

metatensor.learn.data.group(batch: List[NamedTuple]) → NamedTuple[source]#

Collates a minibatch by grouping the data for each data field and returning a named tuple.

batch is a list of named tuples. Each is composed of a number of named data fields, which are arbitrary objects, such as torch.Tensor, atomistic.Systems, or TensorMap.

Returned is a new named tuple with the same named data fields as each sample in the batch, but with the sample data collated for each respective field. The indices of the samples in the minibatch belong to the first field unpacked from the named tuple, i.e. “sample_indices”.

Parameters:

batch (List[NamedTuple]) – list of named tuples for each sample in the minibatch.

Returns:

a named tuple, with the named fields the same as in the original samples in the batch, but with the samples grouped by data field.

Return type:

NamedTuple
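The grouping step amounts to transposing a list of named tuples into a single named tuple of per-field collections. The following plain-Python sketch shows the idea; toy_group is a hypothetical stand-in, not the library function, and the fields x and y are made up for the example.

```python
from collections import namedtuple


def toy_group(batch):
    """Group a list of named tuples field by field (illustrative stand-in)."""
    sample_type = type(batch[0])
    # zip(*batch) transposes [(x0, y0), (x1, y1)] into [(x0, x1), (y0, y1)].
    return sample_type(*zip(*batch))


Sample = namedtuple("Sample", ["x", "y"])
batch = [Sample(x=[1, 2], y="a"), Sample(x=[3, 4], y="b")]
grouped = toy_group(batch)
```

The result keeps the field names of the input samples, and each field holds the values from all samples in batch order. The objects themselves are left untouched, which is why arbitrary types can be grouped this way.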

metatensor.learn.data.group_and_join(batch: List[NamedTuple], fields_to_join: List[str] | None = None, join_kwargs: dict | None = None) → NamedTuple[source]#

Collates a minibatch by grouping the data for each field, joining tensors along the samples axis, and returning a named tuple.

Similar in functionality to the generic group(), but instead data fields that are torch.Tensor objects are vertically stacked, and TensorMap objects are joined along the samples axis.

batch is a list of named tuples. Each has a number of fields that correspond to different named data fields. These data fields can be arbitrary objects, such as torch.Tensor, atomistic.Systems, or TensorMap.

For each data field, the data objects from the samples in the batch are collated into a list, except where the data field consists of torch.Tensor or TensorMap objects. In this case, the objects are joined along the samples axis. Torch tensors must all be of the same size. For TensorMaps, the union of the sparse keys is taken.

Returned is a new named tuple with the same fields as each sample in the batch, but with the sample data collated for each respective field. The sample indices in the minibatch are in the first field of the named tuple under “sample_indices”.

Parameters:
  • batch (List[NamedTuple]) – list of named tuples for each sample in the batch.

  • fields_to_join (List[str] | None) – list of data field names to join. If None, all fields that can be joined are joined, i.e. those consisting of torch.Tensor or TensorMap objects. Names that are invalid, or that refer to fields not of these types, are silently ignored.

  • join_kwargs (dict | None) – keyword arguments passed to the metatensor.join() function, used when joining data fields consisting of TensorMap objects. If None, the defaults are used; see the function documentation for details. The axis="samples" argument is set by default.

Returns:

a named tuple, with the named fields the same as in the original samples in the batch, but with the samples collated for each respective field. If the data fields are torch.Tensor or TensorMap objects, they are joined along the samples axis.

Return type:

NamedTuple
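The difference between grouping and joining can be sketched without torch by letting a list of rows (a list of lists) stand in for a tensor, so that joining means concatenating rows, loosely analogous to vertical stacking. This is an assumption-laden illustration, not the library code: toy_group_and_join, the joinable check, and the field names are hypothetical, and TensorMap joining and join_kwargs are not modelled.

```python
from collections import namedtuple


def toy_group_and_join(batch, fields_to_join=None):
    """Group by field, then join the joinable fields (illustrative stand-in)."""
    sample_type = type(batch[0])
    grouped = {
        name: [getattr(sample, name) for sample in batch]
        for name in sample_type._fields
    }

    def joinable(values):
        # Stand-in for "is a tensor": every value is a list of rows.
        return all(
            isinstance(v, list) and all(isinstance(row, list) for row in v)
            for v in values
        )

    if fields_to_join is None:
        fields_to_join = list(sample_type._fields)
    for name in fields_to_join:
        # Unknown names and non-joinable fields are skipped silently.
        if name in grouped and joinable(grouped[name]):
            grouped[name] = [row for value in grouped[name] for row in value]
    return sample_type(**grouped)


Sample = namedtuple("Sample", ["features", "name"])
batch = [
    Sample(features=[[1.0, 2.0]], name="a"),
    Sample(features=[[3.0, 4.0]], name="b"),
]
joined = toy_group_and_join(batch)
only_grouped = toy_group_and_join(batch, fields_to_join=["name", "missing"])
```

In the first call the features field is joined into one two-row table while name stays a plain list. In the second call nothing is joined, since name is not joinable and "missing" is not a field, and both are skipped without an error.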