Data utilities
Dataset
- class metatensor.learn.Dataset(size: int | None = None, **kwargs)
Defines a PyTorch-compatible torch.utils.data.Dataset class for various named data fields.
The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field and the value is either a list of data objects or a callable.
size specifies the number of samples in the dataset. It only needs to be passed if all data fields are passed as callables; otherwise, the size is inferred from the length of the data fields passed as lists.
Every sample in the dataset is assigned a numeric ID from 0 to size - 1. This ID can be used to access the corresponding sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset.
A data field kwarg passed as a list must contain objects of a single type, and its length must be consistent with the size argument (if specified) and with the length of all other data fields passed as lists.
A data field kwarg passed as a callable must instead take a single argument, the numeric sample ID, and return the data object for that sample ID. The data for that field is then loaded lazily into memory, only when the Dataset.__getitem__() method is called.
>>> # create a Dataset with only lists
>>> dataset = Dataset(num=[1, 2, 3], string=["a", "b", "c"])
>>> dataset[0]
Sample(num=1, string='a')
>>> dataset[2]
Sample(num=3, string='c')
>>> # create a Dataset with callables for lazy loading of data
>>> def call_me(sample_id: int):
...     # this could also read from a file
...     return f"compute something with sample {sample_id}"
>>> dataset = Dataset(num=[1, 2, 3], call_me=call_me)
>>> dataset[0]
Sample(num=1, call_me='compute something with sample 0')
>>> dataset[2]
Sample(num=3, call_me='compute something with sample 2')
>>> # iterating over a dataset
>>> for num, called in dataset:
...     print(num, " -- ", called)
1  --  compute something with sample 0
2  --  compute something with sample 1
3  --  compute something with sample 2
>>> for sample in dataset:
...     print(sample.num, " -- ", sample.call_me)
1  --  compute something with sample 0
2  --  compute something with sample 1
3  --  compute something with sample 2
- Parameters:
size (int | None) – Optional, an integer indicating the size of the dataset, i.e. the number of samples. This only needs to be specified if all the fields are callable.
kwargs – List or Callable. Keyword arguments specifying the data fields for the dataset.
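The sizing rules above (explicit size only needed when every field is a callable, list lengths must agree with each other and with size) can be sketched in plain Python. This is not the metatensor implementation: infer_size is a hypothetical helper written only to make the rules concrete.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# a data field is either a list of per-sample objects or a callable taking the sample ID
Field = Union[List[Any], Callable[[int], Any]]


def infer_size(size: Optional[int], fields: Dict[str, Field]) -> int:
    """Hypothetical helper that applies the documented sizing rules."""
    # callables have no length, so only list fields take part in the inference
    lengths = {name: len(value) for name, value in fields.items() if not callable(value)}

    if not lengths:
        # every field is a callable: the size must be given explicitly
        if size is None:
            raise ValueError("size must be given when all fields are callables")
        return size

    distinct = set(lengths.values())
    if len(distinct) != 1:
        raise ValueError(f"list fields have inconsistent lengths: {lengths}")

    inferred = distinct.pop()
    if size is not None and size != inferred:
        raise ValueError(f"size={size} does not match the list length {inferred}")
    return inferred


print(infer_size(None, {"num": [1, 2, 3], "call_me": lambda i: f"sample {i}"}))  # 3
print(infer_size(5, {"call_me": lambda i: f"sample {i}"}))  # 5
```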
- classmethod from_dict(data: Dict[str, List | Callable], size: int | None = None) → Dataset
Create a Dataset from the given data. This function behaves like Dataset(size=size, **data), but allows field names that would not be valid in the main constructor.
>>> dataset = Dataset.from_dict(
...     {
...         "valid": [0, 0, 0],
...         "with space": [-1, 2, -3],
...         "with/slash": ["a", "b", "c"],
...     }
... )
>>> sample = dataset[1]
>>> sample
Sample(valid=0, 'with space'=2, 'with/slash'='b')
>>> # fields for which the name is a valid identifier can be accessed as usual
>>> sample.valid
0
>>> # fields which are not valid identifiers can be accessed like this
>>> sample["with space"]
2
>>> sample["with/slash"]
'b'
- Parameters:
data (Dict[str, List | Callable]) – Dictionary of lists or callables containing the data. It behaves as if all the entries of the dictionary were passed as keyword arguments to __init__.
size (int | None) – Optional, an integer indicating the size of the dataset, i.e. the number of samples. This only needs to be specified if all the fields are callable.
- Return type:
Dataset
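Why from_dict exists: field names written as keyword arguments in a call such as Dataset(valid=[...]) must be valid, non-reserved Python identifiers, so names like "with space" or "with/slash" cannot be spelled that way. The sketch below is hypothetical and is not the metatensor implementation: needs_from_dict classifies names with str.isidentifier() and keyword.iskeyword(), and Record is an illustrative stand-in for the returned sample type, showing how attribute access and string-key access can coexist.

```python
import keyword


def needs_from_dict(name: str) -> bool:
    """Hypothetical check: True if name cannot be written as a keyword argument."""
    return not name.isidentifier() or keyword.iskeyword(name)


class Record:
    """Hypothetical stand-in for a sample whose field names may not be identifiers."""

    def __init__(self, fields: dict):
        # dicts keep insertion order, so positional access follows field order
        self._fields = dict(fields)

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._fields[key]  # record["with space"]
        return list(self._fields.values())[key]  # record[0]

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails, e.g. record.valid
        fields = self.__dict__.get("_fields", {})
        if name in fields:
            return fields[name]
        raise AttributeError(name)


for name in ["valid", "with space", "with/slash", "class"]:
    print(repr(name), needs_from_dict(name))

record = Record({"valid": 0, "with space": 2, "with/slash": "b"})
print(record.valid, record["with space"], record["with/slash"])  # 0 2 b
```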
- __getitem__(idx: int) → NamedTuple
Returns the data for each field corresponding to the internal index idx.
Each item can be accessed with self[idx]. The result is a named tuple whose fields correspond to those passed (in order) to the constructor upon class initialization.
- Parameters:
idx (int)
- Return type:
NamedTuple
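To make the lazy-loading behaviour concrete, here is a minimal stand-in built on collections.namedtuple. LazyDataset is hypothetical, not the metatensor class, and requires an explicit size for simplicity: list fields are indexed directly, while callable fields are only invoked, with the numeric sample ID, when an item is requested.

```python
from collections import namedtuple


class LazyDataset:
    """Hypothetical minimal dataset: list fields are indexed, callables run lazily."""

    def __init__(self, size, **fields):
        self._size = size
        self._fields = fields
        # field order follows the keyword arguments, like the documented Sample tuple
        self._sample_type = namedtuple("Sample", list(fields))

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        if not 0 <= idx < self._size:
            raise IndexError(idx)  # also ends `for sample in dataset` loops
        values = [
            value(idx) if callable(value) else value[idx]
            for value in self._fields.values()
        ]
        return self._sample_type(*values)


calls = []


def load(sample_id):
    calls.append(sample_id)  # record when the callable actually runs
    return f"compute something with sample {sample_id}"


dataset = LazyDataset(3, num=[1, 2, 3], call_me=load)
print(calls)  # [] (nothing loaded yet)
print(dataset[2])  # Sample(num=3, call_me='compute something with sample 2')
print(calls)  # [2]
```

Because __getitem__ raises IndexError past the end, the object is also iterable through Python's sequence protocol, which is what the tuple-unpacking loop in the examples relies on.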
- class metatensor.learn.IndexedDataset(sample_id: List, **kwargs)
Defines a PyTorch-compatible torch.utils.data.Dataset class for various named data fields, with samples indexed by a list of unique sample IDs.
The data fields are specified as keyword arguments to the constructor, where the keyword is the name of the field and the value is either a list of data objects or a callable.
sample_id must be a list of unique, hashable objects. Each sample ID is assigned an internal numeric index from 0 to len(sample_id) - 1. This index is used internally to index the dataset, and can be used to access a given sample. For instance, dataset[0] returns a named tuple of all data fields for the first sample in the dataset, i.e. the one whose unique sample ID is sample_id[0]. To access a sample by its ID, use the IndexedDataset.get_sample() method.
A data field kwarg passed as a list must contain objects of a single type, and its length must be consistent with the length of sample_id and with the length of all other data fields passed as lists.
A data field kwarg passed as a callable must instead take a single argument, the unique sample ID (i.e. one of the entries of sample_id), and return the data object for that sample ID. The data for that field is then loaded lazily into memory, only when the IndexedDataset.__getitem__() or IndexedDataset.get_sample() method is called.
>>> # create an IndexedDataset with lists
>>> dataset = IndexedDataset(sample_id=["cat", "bird", "dog"], y=[11, 22, 33])
>>> dataset[0]
Sample(sample_id='cat', y=11)
>>> dataset[2]
Sample(sample_id='dog', y=33)
>>> # create an IndexedDataset with callables for lazy loading of data
>>> def call_me(sample_id: str):
...     # this could also read from a file
...     return f"compute something with sample {sample_id}"
>>> dataset = IndexedDataset(sample_id=["cat", "bird", "dog"], call_me=call_me)
>>> dataset[0]
Sample(sample_id='cat', call_me='compute something with sample cat')
>>> dataset[2]
Sample(sample_id='dog', call_me='compute something with sample dog')
>>> dataset.get_sample("bird")
Sample(sample_id='bird', call_me='compute something with sample bird')
- Parameters:
sample_id (List) – A list of unique IDs for each sample in the dataset.
kwargs – Keyword arguments specifying the data fields for the dataset.
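The relationship between unique sample IDs and internal indices can be sketched with a dictionary lookup. LazyIndexedDataset below is hypothetical and not the metatensor class: it prepends the sample ID as the first field, passes the sample ID (not the numeric index) to callable fields, and implements get_sample by mapping the ID back to its internal index.

```python
from collections import namedtuple


class LazyIndexedDataset:
    """Hypothetical stand-in: samples addressed by internal index or by unique ID."""

    def __init__(self, sample_id, **fields):
        if len(set(sample_id)) != len(sample_id):
            raise ValueError("sample_id must contain unique, hashable IDs")
        self._ids = list(sample_id)
        # internal index of each sample ID: 0 .. len(sample_id) - 1
        self._index = {sid: i for i, sid in enumerate(self._ids)}
        self._fields = fields
        # the sample ID is always the first field of the returned tuple
        self._sample_type = namedtuple("Sample", ["sample_id", *fields])

    def __len__(self):
        return len(self._ids)

    def __getitem__(self, idx):
        sid = self._ids[idx]
        values = [
            value(sid) if callable(value) else value[idx]  # callables get the ID
            for value in self._fields.values()
        ]
        return self._sample_type(sid, *values)

    def get_sample(self, sample_id):
        return self[self._index[sample_id]]


dataset = LazyIndexedDataset(
    ["cat", "bird", "dog"], y=[11, 22, 33], call_me=lambda sid: f"sample {sid}"
)
print(dataset[0])  # Sample(sample_id='cat', y=11, call_me='sample cat')
print(dataset.get_sample("dog"))  # Sample(sample_id='dog', y=33, call_me='sample dog')
```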
- classmethod from_dict(data: Dict[str, Any], sample_id: List) → IndexedDataset
Create an IndexedDataset from the given data. This function behaves like IndexedDataset(sample_id=sample_id, **data), but allows field names that would not be valid in the main constructor.
>>> dataset = IndexedDataset.from_dict(
...     {
...         "valid": [0, 0, 0],
...         "with space": [-1, 2, -3],
...         "with/slash": ["a", "b", "c"],
...     },
...     sample_id=[11, 22, 33],
... )
>>> sample = dataset[1]
>>> sample
Sample(sample_id=22, valid=0, 'with space'=2, 'with/slash'='b')
>>> # fields for which the name is a valid identifier can be accessed as usual
>>> sample.valid
0
>>> # fields which are not valid identifiers can be accessed like this
>>> sample["with space"]
2
>>> sample["with/slash"]
'b'
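- Parameters:
data (Dict[str, Any]) – Dictionary containing the data fields. It behaves as if all the entries of the dictionary were passed as keyword arguments to __init__.
sample_id (List) – A list of unique IDs for each sample in the dataset.
- Return type:
IndexedDataset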
- __getitem__(idx: int) → NamedTuple
Returns the data for each field corresponding to the internal index idx.
Each item can be accessed with self[idx]. The result is a named tuple whose first field is the sample ID, and whose remaining fields correspond to those passed (in order) to the constructor upon class initialization.
- Parameters:
idx (int)
- Return type:
NamedTuple
- get_sample(sample_id) → NamedTuple
Returns a named tuple for the sample corresponding to the given sample_id.
- Return type:
NamedTuple
Dataloader
- class metatensor.learn.DataLoader(dataset, collate_fn=<function group_and_join>, **kwargs)
Class for loading data from a torch.utils.data.Dataset object, with a default collate function (metatensor.learn.data.group_and_join()) that supports torch.Tensor, metatensor.torch.atomistic.System, metatensor.TensorMap, and metatensor.torch.TensorMap.
The dataset will typically be a metatensor.learn.Dataset or a metatensor.learn.IndexedDataset.
Any argument accepted by the parent class torch.utils.data.DataLoader is supported.
- Parameters:
dataset (Dataset[_T_co])
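The loader's part in this is to split the dataset into batches and pass each batch, a list of samples, to collate_fn. The generator below, simple_loader, is a hypothetical and heavily simplified stand-in for torch.utils.data.DataLoader with shuffling disabled; it ignores samplers, worker processes and the other features of the real class, and only shows where collate_fn is called.

```python
from collections import namedtuple
from typing import Callable, Iterator, Sequence


def simple_loader(dataset: Sequence, batch_size: int, collate_fn: Callable) -> Iterator:
    """Hypothetical sequential loader: chunk the dataset, then collate each chunk."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield collate_fn(batch)
            batch = []
    if batch:  # keep the last, possibly smaller, batch
        yield collate_fn(batch)


Sample = namedtuple("Sample", ["num", "string"])
samples = [Sample(1, "a"), Sample(2, "b"), Sample(3, "c")]

# identity collate: each batch is just the list of samples
for batch in simple_loader(samples, batch_size=2, collate_fn=lambda b: b):
    print(batch)
```

With metatensor.learn.DataLoader, the collate function defaults to group_and_join rather than the identity function used in this sketch.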
Collating data
- metatensor.learn.data.group(batch: List[NamedTuple]) → NamedTuple
Collates a minibatch by grouping the data for each data field and returning a named tuple.
batch is a list of named tuples. Each is composed of a number of named data fields, which are arbitrary objects, such as torch.Tensor, atomistic.System, or TensorMap.
Returned is a new named tuple with the same named data fields as each sample in the batch, but with the sample data collated for each respective field. The indices of the samples in the minibatch belong to the first field unpacked from the named tuple, i.e. “sample_indices”.
- Parameters:
batch (List[NamedTuple]) – list of named tuples for each sample in the minibatch.
- Returns:
a named tuple with the same named fields as the original samples in the batch, but with the samples grouped by data field.
- Return type:
NamedTuple
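The grouping step amounts to transposing a list of samples into one named tuple of per-field groups. The sketch below assumes each field is collected into a tuple, since the container is not specified above; group_sketch is a hypothetical name and not the metatensor implementation.

```python
from collections import namedtuple
from typing import List, NamedTuple


def group_sketch(batch: List[NamedTuple]) -> NamedTuple:
    """Hypothetical: turn a list of samples into one tuple of per-field groups."""
    sample_type = type(batch[0])
    # zip(*batch) yields one tuple per field, in field order
    return sample_type(*zip(*batch))


Sample = namedtuple("Sample", ["sample_id", "y"])
batch = [Sample("cat", 11), Sample("bird", 22), Sample("dog", 33)]
grouped = group_sketch(batch)
print(grouped)  # Sample(sample_id=('cat', 'bird', 'dog'), y=(11, 22, 33))
```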
- metatensor.learn.data.group_and_join(batch: List[NamedTuple], fields_to_join: List[str] | None = None, join_kwargs: dict | None = None) → NamedTuple
Collates a minibatch by grouping the data for each field, joining tensors along the samples axis, and returning a named tuple.
Similar in functionality to the generic group(), except that data fields that are torch.Tensor objects are vertically stacked, and TensorMap objects are joined along the samples axis.
batch is a list of named tuples. Each has a number of fields that correspond to different named data fields. These data fields can be arbitrary objects, such as torch.Tensor, atomistic.System, or TensorMap.
For each data field, the data object from each sample in the batch is collated into a list, except where the data field is a list of torch.Tensor or TensorMap objects. In this case, the tensors are joined along the samples axis. Torch tensors must all be of the same size; for TensorMaps, the union of the sparse keys is taken.
Returned is a new named tuple with the same fields as each sample in the batch, but with the sample data collated for each respective field. The sample indices in the minibatch are in the first field of the named tuple under “sample_indices”.
- Parameters:
batch (List[NamedTuple]) – list of named tuples for each sample in the batch.
fields_to_join (List[str] | None) – list of data field names to join. If None, all fields that can be joined are joined, i.e. those made up of torch.Tensor or TensorMap objects. Any names passed that are invalid, or that refer to fields not of these types, are silently ignored.
join_kwargs (dict | None) – keyword arguments passed to the metatensor.join() function, used when joining data fields made up of TensorMap objects. If None, the defaults are used; see the metatensor.join() documentation for details. The axis="samples" argument is set by default.
- Returns:
a named tuple with the same named fields as the original samples in the batch, but with the samples collated for each respective field. If the data fields are torch.Tensor or TensorMap objects, they are joined along the samples axis.
- Return type:
NamedTuple
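The field-selection rules for fields_to_join can be sketched without torch or metatensor. In this hypothetical group_and_join_sketch, a small Rows class (also hypothetical) stands in for a 2-D tensor whose first axis is the samples axis, and joining means concatenating rows, i.e. vertical stacking. Fields that are not selected, or not made of Rows objects, are collected into a list, and unknown names in fields_to_join are ignored, as documented above. TensorMap joining and join_kwargs are not modelled.

```python
from collections import namedtuple
from dataclasses import dataclass
from typing import List, NamedTuple, Optional


@dataclass
class Rows:
    """Hypothetical stand-in for a 2-D tensor: a list of equal-length rows."""

    data: List[List[float]]


def join_rows(parts: List[Rows]) -> Rows:
    # vertical stacking needs every row to have the same length
    widths = {len(row) for part in parts for row in part.data}
    if len(widths) > 1:
        raise ValueError("all rows must have the same length to be stacked")
    return Rows([row for part in parts for row in part.data])


def group_and_join_sketch(
    batch: List[NamedTuple], fields_to_join: Optional[List[str]] = None
) -> NamedTuple:
    """Hypothetical: group every field, then join the selected Rows fields."""
    sample_type = type(batch[0])
    names = sample_type._fields
    # None means "every field that can be joined"; unknown names are dropped
    wanted = set(names) if fields_to_join is None else set(fields_to_join) & set(names)
    values = []
    for name, group in zip(names, zip(*batch)):
        if name in wanted and all(isinstance(v, Rows) for v in group):
            values.append(join_rows(list(group)))
        else:
            values.append(list(group))  # non-joined fields are collected into a list
    return sample_type(*values)


Sample = namedtuple("Sample", ["sample_id", "x"])
batch = [Sample("cat", Rows([[1.0, 2.0]])), Sample("dog", Rows([[3.0, 4.0]]))]

joined = group_and_join_sketch(batch)
print(joined.sample_id)  # ['cat', 'dog']
print(joined.x.data)  # [[1.0, 2.0], [3.0, 4.0]]

kept = group_and_join_sketch(batch, fields_to_join=["not_a_field"])
print(len(kept.x))  # 2: x was collected into a list, not joined
```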