Data utilities

Dataset

class metatensor.learn.Dataset(size: int | None = None, **kwargs)[source]

Defines a PyTorch-compatible torch.utils.data.Dataset class for various named data fields.

The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field, and the value is either a list of data objects or a callable.

size specifies the number of samples in the dataset. This only needs to be passed if all data fields are passed as callables. Otherwise, the size is inferred from the length of the data fields passed as lists.

Every sample in the dataset is assigned a numeric ID from 0 to size - 1. This ID can be used to access the corresponding sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset.

A data field passed as a list must contain objects of the same type, and its length must match the size argument (if specified) and the length of every other data field passed as a list.

A data field passed as a callable must take a single argument, the numeric sample ID, and return the data object for that sample ID. The data for a given sample is then loaded into memory lazily, only when the Dataset.__getitem__() method is called.

>>> # create a Dataset with only lists
>>> dataset = Dataset(num=[1, 2, 3], string=["a", "b", "c"])
>>> dataset[0]
Sample(num=1, string='a')
>>> dataset[2]
Sample(num=3, string='c')
>>> # create a Dataset with callables for lazy loading of data
>>> def call_me(sample_id: int):
...     # this could also read from a file
...     return f"compute something with sample {sample_id}"
>>> dataset = Dataset(num=[1, 2, 3], call_me=call_me)
>>> dataset[0]
Sample(num=1, call_me='compute something with sample 0')
>>> dataset[2]
Sample(num=3, call_me='compute something with sample 2')
>>> # iterating over a dataset
>>> for num, called in dataset:
...     print(num, " -- ", called)
1  --  compute something with sample 0
2  --  compute something with sample 1
3  --  compute something with sample 2
>>> for sample in dataset:
...     print(sample.num, " -- ", sample.call_me)
1  --  compute something with sample 0
2  --  compute something with sample 1
3  --  compute something with sample 2
Parameters:
  • size (int | None) – Optional, an integer indicating the size of the dataset, i.e. the number of samples. This only needs to be specified if all the fields are callable.

  • kwargs – List or Callable. Keyword arguments specifying the data fields for the dataset.
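
The rules for size and for lazily evaluated fields can be sketched in plain Python. This is not the metatensor implementation: MiniDataset and read_value are hypothetical names used only for illustration, and the sketch assumes the behaviour is exactly as described above (size inferred from list fields, required when every field is a callable, callables evaluated only on item access).

```python
from collections import namedtuple
from typing import Callable, List, Optional, Union

Field = Union[List, Callable[[int], object]]


class MiniDataset:
    """Hypothetical stand-in that mimics the documented behaviour."""

    def __init__(self, size: Optional[int] = None, **fields: Field):
        # Only list-like fields carry a length; callables do not.
        lengths = {len(v) for v in fields.values() if not callable(v)}
        if len(lengths) > 1:
            raise ValueError("list fields must all have the same length")
        if lengths:
            inferred = lengths.pop()
            if size is not None and size != inferred:
                raise ValueError("size is inconsistent with the list fields")
            size = inferred
        elif size is None:
            raise ValueError("size is required when every field is a callable")
        self._size = size
        self._fields = fields
        self._sample = namedtuple("Sample", list(fields))

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, idx: int):
        if not 0 <= idx < self._size:
            raise IndexError(idx)
        # Callables are evaluated here, i.e. only when a sample is requested.
        values = [f(idx) if callable(f) else f[idx] for f in self._fields.values()]
        return self._sample(*values)


loaded = []  # records which sample IDs were actually computed


def read_value(sample_id: int) -> str:
    loaded.append(sample_id)
    return f"value {sample_id}"


dataset = MiniDataset(size=3, value=read_value)  # all fields callable: size given
first = dataset[0]  # only now is read_value(0) called
```

In this sketch, constructing the dataset does not call read_value at all; the call happens on dataset[0], which is the sense in which the field is loaded lazily.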

classmethod from_dict(data: Dict[str, List | Callable], size: int | None = None) Dataset[source]

Create a Dataset from the given data. This function behaves like Dataset(size=size, **data), but allows field names that would not be valid in the main constructor.

>>> dataset = Dataset.from_dict(
...     {
...         "valid": [0, 0, 0],
...         "with space": [-1, 2, -3],
...         "with/slash": ["a", "b", "c"],
...     }
... )
>>> sample = dataset[1]
>>> sample
Sample(valid=0, 'with space'=2, 'with/slash'='b')
>>> # fields for which the name is a valid identifier can be accessed as usual
>>> sample.valid
0
>>> # fields which are not valid identifiers can be accessed like this
>>> sample["with space"]
2
>>> sample["with/slash"]
'b'
Parameters:
  • data (Dict[str, List | Callable]) – Dictionary of List or Callable containing the data. This will behave as if all the entries of the dictionary were passed as keyword arguments to __init__.

  • size (int | None) – Optional, an integer indicating the size of the dataset, i.e. the number of samples. This only needs to be specified if all the fields are callable.

Return type:

Dataset
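
As background for why some field names need from_dict, the sketch below uses only the Python standard library to show which of the names from the example are valid identifiers, and that the standard collections.namedtuple rejects the others. It illustrates Python's naming rules only and makes no claim about how metatensor builds its samples internally; is_plain_identifier is a hypothetical helper.

```python
import keyword
from collections import namedtuple


def is_plain_identifier(name: str) -> bool:
    """True if `name` can be used as a keyword argument or attribute name."""
    return name.isidentifier() and not keyword.iskeyword(name)


names = ["valid", "with space", "with/slash"]

# Names usable with attribute access, e.g. sample.valid
attribute_names = [n for n in names if is_plain_identifier(n)]

# The standard namedtuple refuses field names that are not identifiers.
try:
    namedtuple("Sample", names)
    rejected = False
except ValueError:
    rejected = True
```

Only "valid" passes the check here, which matches the example above where the other two fields are read with sample["with space"] and sample["with/slash"].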

__getitem__(idx: int) NamedTuple[source]

Returns the data for each field corresponding to the internal index idx.

Each item can be accessed with self[idx]. Returned is a named tuple with fields corresponding to those passed (in order) to the constructor upon class initialization.

Parameters:

idx (int)

Return type:

NamedTuple

class metatensor.learn.IndexedDataset(sample_id: List, **kwargs)[source]

Defines a PyTorch-compatible torch.utils.data.Dataset class for various named data fields, with samples indexed by a list of unique sample IDs.

The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field, and the value is either a list of data objects or a callable.

sample_id must be a list of unique hashable objects. Each sample ID is assigned an internal numeric index from 0 to len(sample_id) - 1. This index is used internally and can also be used to access a given sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset, i.e. the one whose sample ID is sample_id[0]. To access a sample by its ID, use the IndexedDataset.get_sample() method.

A data field passed as a list must contain objects of the same type, and its length must match the length of sample_id and the length of every other data field passed as a list.

A data field passed as a callable must take a single argument, the unique sample ID (i.e. one of those passed in sample_id), and return the data object for that sample ID. The data for a given sample is then loaded into memory lazily, only when IndexedDataset.__getitem__() or IndexedDataset.get_sample() is called.

>>> # create an IndexedDataset with lists
>>> dataset = IndexedDataset(sample_id=["cat", "bird", "dog"], y=[11, 22, 33])
>>> dataset[0]
Sample(sample_id='cat', y=11)
>>> dataset[2]
Sample(sample_id='dog', y=33)
>>> # create an IndexedDataset with callables for lazy loading of data
>>> def call_me(sample_id: str):
...     # this could also read from a file
...     return f"compute something with sample {sample_id}"
>>> dataset = IndexedDataset(sample_id=["cat", "bird", "dog"], call_me=call_me)
>>> dataset[0]
Sample(sample_id='cat', call_me='compute something with sample cat')
>>> dataset[2]
Sample(sample_id='dog', call_me='compute something with sample dog')
>>> dataset.get_sample("bird")
Sample(sample_id='bird', call_me='compute something with sample bird')
Parameters:
  • sample_id (List) – A list of unique IDs for each sample in the dataset.

  • kwargs – Keyword arguments specifying the data fields for the dataset.
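
The relationship between sample IDs and internal indices can be sketched with a dictionary. This is an illustration under the assumptions stated above (unique, hashable IDs mapped to positions 0 to len(sample_id) - 1), not metatensor's code; build_id_index is a hypothetical helper.

```python
from typing import Dict, Hashable, List


def build_id_index(sample_id: List[Hashable]) -> Dict[Hashable, int]:
    """Map each unique sample ID to its internal numeric index."""
    index = {sid: i for i, sid in enumerate(sample_id)}
    if len(index) != len(sample_id):
        # Duplicate IDs collapse to a single dictionary key.
        raise ValueError("sample IDs must be unique")
    return index


ids = ["cat", "bird", "dog"]
y = [11, 22, 33]
index = build_id_index(ids)

by_position = y[0]          # like dataset[0]: first sample, ID "cat"
by_id = y[index["bird"]]    # like dataset.get_sample("bird")
```

Access by position and access by ID then reach the same underlying data; the ID lookup simply goes through the mapping first.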

classmethod from_dict(data: Dict[str, Any], sample_id: List) IndexedDataset[source]

Create an IndexedDataset from the given data. This function behaves like IndexedDataset(sample_id=sample_id, **data), but allows field names that would not be valid in the main constructor.

>>> dataset = IndexedDataset.from_dict(
...     {
...         "valid": [0, 0, 0],
...         "with space": [-1, 2, -3],
...         "with/slash": ["a", "b", "c"],
...     },
...     sample_id=[11, 22, 33],
... )
>>> sample = dataset[1]
>>> sample
Sample(sample_id=22, valid=0, 'with space'=2, 'with/slash'='b')
>>> # fields for which the name is a valid identifier can be accessed as usual
>>> sample.valid
0
>>> # fields which are not valid identifiers can be accessed like this
>>> sample["with space"]
2
>>> sample["with/slash"]
'b'
Parameters:
  • data (Dict[str, Any]) – Dictionary of List or Callable containing the data. This will behave as if all the entries of the dictionary were passed as keyword arguments to __init__.

  • sample_id (List) – A list of unique IDs for each sample in the dataset.

Return type:

IndexedDataset

__getitem__(idx: int) NamedTuple[source]

Returns the data for each field corresponding to the internal index idx.

Each item can be accessed with self[idx]. Returned is a named tuple whose first field is the sample ID, and whose remaining fields correspond to those passed (in order) to the constructor upon class initialization.

Parameters:

idx (int)

Return type:

NamedTuple

get_sample(sample_id) NamedTuple[source]

Returns a named tuple for the sample corresponding to the given sample_id.

Return type:

NamedTuple

DataLoader

class metatensor.learn.DataLoader(dataset, collate_fn=<function group_and_join>, **kwargs)[source]

Class for loading data from a torch.utils.data.Dataset object with a default collate function (metatensor.learn.data.group_and_join()) that supports torch.Tensor, metatensor.torch.atomistic.System, metatensor.TensorMap, and metatensor.torch.TensorMap.

The dataset will typically be a metatensor.learn.Dataset or a metatensor.learn.IndexedDataset.

Any argument accepted by the parent class, torch.utils.data.DataLoader, is supported.

Parameters:

dataset (Dataset[_T_co])

Collating data

metatensor.learn.data.group(batch: List[NamedTuple]) NamedTuple[source]

Collates a minibatch by grouping the data for each data field and returning a named tuple.

batch is a list of named tuples. Each is composed of a number of named data fields, which are arbitrary objects, such as torch.Tensor, atomistic.Systems, or TensorMap.

Returned is a new named tuple with the same named data fields as each sample in the batch, but with the sample data collated for each respective field. The indices of the samples in the minibatch belong to the first field unpacked from the named tuple, i.e. “sample_indices”.

Parameters:

batch (List[NamedTuple]) – list of named tuples for each sample in the minibatch.

Returns:

a named tuple, with the named fields the same as in the original samples in the batch, but with the samples grouped by data field.

Return type:

NamedTuple
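
The grouping step amounts to transposing a list of per-sample tuples into one tuple per field. The sketch below shows the idea with the standard library only; group_sketch is a hypothetical function, and both the use of tuples for the grouped values and the reuse of the input tuple type are assumptions of this sketch rather than documented behaviour of group().

```python
from collections import namedtuple
from typing import List


def group_sketch(batch: List[tuple]):
    """Collect the values of each field across all samples in the batch."""
    sample_type = type(batch[0])
    # zip(*batch) transposes [(id0, y0), (id1, y1)] into (id0, id1), (y0, y1)
    return sample_type(*zip(*batch))


Sample = namedtuple("Sample", ["sample_id", "y"])
batch = [Sample("cat", 11), Sample("bird", 22), Sample("dog", 33)]
grouped = group_sketch(batch)
```

After grouping, grouped.sample_id holds the IDs of all three samples and grouped.y holds their values, in the same order as the samples appeared in the batch.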

metatensor.learn.data.group_and_join(batch: List[NamedTuple], fields_to_join: List[str] | None = None, join_kwargs: dict | None = None) NamedTuple[source]

Collates a minibatch by grouping the data for each field, joining tensors along the samples axis, and returning a named tuple.

Similar in functionality to the generic group(), except that data fields that are torch.Tensor objects are vertically stacked, and TensorMap objects are joined along the samples axis.

batch is a list of named tuples. Each has a number of fields that correspond to different named data fields. These data fields can be arbitrary objects, such as torch.Tensor, atomistic.Systems, or TensorMap.

For each data field, the data object from each sample in the batch is collated into a list, except where the data field is a list of torch.Tensor or TensorMap objects. In this case, the tensors are joined along the samples axis. For torch tensors, all must be of the same size. For TensorMaps, the union of the sparse keys is taken.

Returned is a new named tuple with the same fields as each sample in the batch, but with the sample data collated for each respective field. The sample indices in the minibatch are in the first field of the named tuple under “sample_indices”.

Parameters:
  • batch (List[NamedTuple]) – list of named tuples for each sample in the batch.

  • fields_to_join (List[str] | None) – list of data field names to join. If None, all fields that can be joined are joined, i.e. those comprised of torch.Tensor or TensorMap objects. Any names passed that are either invalid or are names of fields that aren’t these types will be silently ignored.

  • join_kwargs (dict | None) – keyword arguments passed to the metatensor.join() function, to be used when joining data fields comprised of TensorMap objects. If None, the defaults are used; see the function documentation for details. The axis="samples" argument is set by default.

Returns:

a named tuple, with the named fields the same as in the original samples in the batch, but with the samples collated for each respective field. If the data fields are torch.Tensor or TensorMap objects, they are joined along the samples axis.

Return type:

NamedTuple
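
The interplay between grouping, joining, and fields_to_join can be sketched without torch or metatensor. In the sketch below, a plain list of rows stands in for a 2D torch.Tensor and row concatenation stands in for vertical stacking; group_and_join_sketch and is_rows are hypothetical names, TensorMap joining is not modelled, and the only behaviour taken from the text above is that unjoinable or unknown field names are silently ignored.

```python
from collections import namedtuple
from typing import List, Optional


def is_rows(value) -> bool:
    """True for a list of rows, the stand-in for a 2D tensor in this sketch."""
    return isinstance(value, list) and all(isinstance(r, list) for r in value)


def group_and_join_sketch(batch: List[tuple], fields_to_join: Optional[List[str]] = None):
    sample_type = type(batch[0])
    collated = []
    for name, values in zip(sample_type._fields, zip(*batch)):
        joinable = all(is_rows(v) for v in values)
        wanted = fields_to_join is None or name in fields_to_join
        if joinable and wanted:
            # "Vertical stacking": concatenate the rows of every sample.
            collated.append([row for v in values for row in v])
        else:
            # Everything else is only grouped, as in the generic grouping step.
            collated.append(values)
    return sample_type(*collated)


Sample = namedtuple("Sample", ["sample_id", "x"])
batch = [Sample("a", [[1, 2]]), Sample("b", [[3, 4], [5, 6]])]

joined = group_and_join_sketch(batch)
# A name that matches no field is ignored, so nothing is joined here.
not_joined = group_and_join_sketch(batch, fields_to_join=["missing"])
```

With the default fields_to_join=None, the x field becomes a single list of three rows, while sample_id is only grouped. Passing a list that names no joinable field leaves x grouped per sample instead.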