Using IndexedDataset
import os
import torch
from metatensor.learn.data import DataLoader, Dataset, IndexedDataset
Review of the standard Dataset
The previous tutorial, Datasets and data loaders, showed how to define a Dataset able to handle both torch tensors and metatensor TensorMaps. We saw that in-memory, on-disk, or mixed in-memory/on-disk datasets can be defined. DataLoaders are then defined on top of these Dataset objects.
In all cases, however, each data sample is accessed by a numeric integer index, which ranges from 0 to len(dataset) - 1. Let's use a simple example to review this.
Again, let's define some dummy data as before: our x data is a list of random tensors, and our y data is a list of integers that enumerate the samples.
For the purposes of this tutorial, we will only focus on an in-memory dataset, though the same principles apply to on-disk and mixed datasets.
n_samples = 5
x_data = [torch.randn(3) for _ in range(n_samples)]
y_data = [i for i in range(n_samples)]
dataset = Dataset(x=x_data, y=y_data)
A sample is accessed by its numeric index. As the lists passed as kwargs for both x and y have length 5, the valid indices are [0, 1, 2, 3, 4].
Let’s retrieve the 4th sample (index 3) and print it. The value of the “y” data field should be 3.
print(dataset[3])
Sample(x=tensor([-1.1023, 0.5280, 0.0502]), y=3)
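To make the integer-indexing rule concrete, here is a minimal pure-Python sketch of how a dataset built from keyword arguments could pair up the i-th element of every field into one named tuple. ToyDataset is a hypothetical name used only for illustration; it is not metatensor's implementation, and it uses plain lists instead of torch tensors.

```python
from collections import namedtuple


class ToyDataset:
    """Hypothetical stand-in for an integer-indexed Dataset (illustration only)."""

    def __init__(self, **fields):
        lengths = {len(values) for values in fields.values()}
        if len(lengths) != 1:
            raise ValueError("all fields must have the same (non-zero count of) length")
        self._fields = fields
        self._len = lengths.pop()
        # One named tuple type whose field names follow the kwargs order
        self._sample_type = namedtuple("Sample", list(fields))

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # Only positions 0 .. len(self) - 1 are accepted
        if not 0 <= idx < self._len:
            raise IndexError(f"index {idx} out of range for {self._len} samples")
        return self._sample_type(*(values[idx] for values in self._fields.values()))


toy = ToyDataset(x=[[0.1], [0.2], [0.3], [0.4], [0.5]], y=[0, 1, 2, 3, 4])
print(toy[3])  # Sample(x=[0.4], y=3)
```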
What if we wanted to access samples by something other than an integer index within a continuous range? For instance, what if we wanted to access samples by:
- a string id, or another arbitrary hashable object?
- an integer index that is not part of a continuous range?
In these cases, we can use an IndexedDataset instead.
IndexedDataset
First, let's define an IndexedDataset where the samples are indexed by arbitrary unique indexes, such as strings, integers, and tuples.
Suppose the unique indexes for our 5 samples are:
sample_id = [
    "cat",
    4,
    ("small", "cow"),
    "dog",
    0,
]
# Build an IndexedDataset, specifying the unique sample indexes with ``sample_id``
dataset = IndexedDataset(
    x=x_data,
    y=y_data,
    sample_id=sample_id,
)
Now, when we access the dataset, we can access samples by their unique sample index using the get_sample method. This method takes a single argument, the sample index, and returns the corresponding sample.
print(dataset.get_sample("dog"))
print(dataset.get_sample(4))
print(dataset.get_sample(("small", "cow")))
Sample(sample_id='dog', x=tensor([-1.1023, 0.5280, 0.0502]), y=3)
Sample(sample_id=4, x=tensor([ 1.3223, 0.0085, -0.0658]), y=1)
Sample(sample_id=('small', 'cow'), x=tensor([ 1.0664, -0.1478, 0.1371]), y=2)
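One way to picture get_sample is as a dictionary lookup from each unique sample index to its position, followed by an ordinary positional lookup. The sketch below is a hypothetical pure-Python stand-in (ToyIndexedDataset, not part of metatensor), assuming, as this tutorial does, that sample indexes are hashable and unique; the uniqueness check and error messages are illustrative choices, not the library's behaviour.

```python
from collections import namedtuple


class ToyIndexedDataset:
    """Hypothetical stand-in for an IndexedDataset (illustration only)."""

    def __init__(self, sample_id, **fields):
        # Map each unique, hashable sample index to its position in the inputs
        self._position = {sid: pos for pos, sid in enumerate(sample_id)}
        if len(self._position) != len(sample_id):
            raise ValueError("sample_id entries must be unique")
        for name, values in fields.items():
            if len(values) != len(sample_id):
                raise ValueError(f"field {name!r} does not match the length of sample_id")
        self._sample_id = list(sample_id)
        self._fields = fields
        # ``sample_id`` comes first, followed by the data fields in kwargs order
        self._sample_type = namedtuple("Sample", ["sample_id", *fields])

    def __len__(self):
        return len(self._sample_id)

    def __getitem__(self, position):
        # Lookup by position in the constructor inputs, 0 .. len(self) - 1
        values = (field[position] for field in self._fields.values())
        return self._sample_type(self._sample_id[position], *values)

    def get_sample(self, sample_id):
        # Lookup by unique sample index: translate it to a position first
        return self[self._position[sample_id]]


ids = ["cat", 4, ("small", "cow"), "dog", 0]
toy = ToyIndexedDataset(sample_id=ids, x=["x0", "x1", "x2", "x3", "x4"], y=[0, 1, 2, 3, 4])

print(toy.get_sample("dog"))  # Sample(sample_id='dog', x='x3', y=3)
print(toy.get_sample(4))  # Sample(sample_id=4, x='x1', y=1)
print(toy[0] == toy.get_sample("cat"))  # True
```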
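Note that using __getitem__, i.e. dataset[4], still accesses samples by their numeric position: it returns the fifth sample passed to the constructor (position 4), not the sample whose index is 4. In this case, the numeric positions map to the sample indexes as follows:
- 0: "cat"
- 1: 4
- 2: ("small", "cow")
- 3: "dog"
- 4: 0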
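The two lookups can be spelled out with plain dictionaries. This sketch uses ordinary Python only (no metatensor calls) and builds both directions of the mapping for the sample_id list used above, which makes it easy to see why dataset[4] and dataset.get_sample(4) refer to different samples here.

```python
sample_id = ["cat", 4, ("small", "cow"), "dog", 0]

# Position -> sample index: simply the order of the list passed to the constructor
position_to_id = dict(enumerate(sample_id))

# Sample index -> position: the inverse mapping a get_sample-style lookup needs
id_to_position = {sid: pos for pos, sid in enumerate(sample_id)}

for pos, sid in position_to_id.items():
    print(f"position {pos} -> sample index {sid!r}")

# dataset[4] and dataset.get_sample(4) refer to different samples here
print(position_to_id[4])  # 0  (the sample at position 4 has sample index 0)
print(id_to_position[4])  # 1  (the sample with sample index 4 sits at position 1)
```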
Thus, the sample with the unique index "cat" can be accessed equivalently with either of:
print(dataset[0])
print(dataset.get_sample("cat"))
Sample(sample_id='cat', x=tensor([0.3149, 0.2212, 0.6026]), y=0)
Sample(sample_id='cat', x=tensor([0.3149, 0.2212, 0.6026]), y=0)
Note that the named tuple returned in both cases contains the unique sample index as the sample_id field, which precedes all other data fields. This is in contrast to the standard Dataset, which returns only the data fields that were passed, not the index.
A DataLoader can be constructed on top of an IndexedDataset in the same way as on top of a Dataset. Batches are accessed by iterating over the DataLoader, though this time the Batch named tuple returned by the data loader contains the unique sample indexes, sample_id, as its first field.
dataloader = DataLoader(dataset, batch_size=2)
# Iterate over batches
for batch in dataloader:
    print(batch)
Batch(sample_id=('cat', 4), x=tensor([[ 0.3149, 0.2212, 0.6026],
[ 1.3223, 0.0085, -0.0658]]), y=(0, 1))
Batch(sample_id=(('small', 'cow'), 'dog'), x=tensor([[ 1.0664, -0.1478, 0.1371],
[-1.1023, 0.5280, 0.0502]]), y=(2, 3))
Batch(sample_id=(0,), x=tensor([[0.9286, 0.3232, 1.1743]]), y=(4,))
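Each batch can be thought of as the samples of one chunk transposed field by field. The following hypothetical sketch (toy_collate and toy_batches are illustrative names, not metatensor functions) reproduces the grouping of sample_id and y seen above for unshuffled batches of size 2. It keeps x as a tuple of plain lists, whereas metatensor's DataLoader stacks the torch tensors in x into one tensor, as the printed batches show.

```python
from collections import namedtuple

Sample = namedtuple("Sample", ["sample_id", "x", "y"])
Batch = namedtuple("Batch", Sample._fields)


def toy_collate(samples):
    """Hypothetical collate: gather each field of the samples into a tuple."""
    return Batch(*zip(*samples))


def toy_batches(samples, batch_size):
    """Yield consecutive, unshuffled batches; the last one may be smaller."""
    for start in range(0, len(samples), batch_size):
        yield toy_collate(samples[start : start + batch_size])


samples = [
    Sample("cat", [0.31], 0),
    Sample(4, [1.32], 1),
    Sample(("small", "cow"), [1.07], 2),
    Sample("dog", [-1.10], 3),
    Sample(0, [0.93], 4),
]

for batch in toy_batches(samples, batch_size=2):
    print(batch.sample_id, batch.y)
```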
As before, we can unpack each batch into separate variables in the loop:
for ids, x, y in dataloader:
    print(ids, x, y)
('cat', 4) tensor([[ 0.3149, 0.2212, 0.6026],
[ 1.3223, 0.0085, -0.0658]]) (0, 1)
(('small', 'cow'), 'dog') tensor([[ 1.0664, -0.1478, 0.1371],
[-1.1023, 0.5280, 0.0502]]) (2, 3)
(0,) tensor([[0.9286, 0.3232, 1.1743]]) (4,)
On-disk IndexedDataset with arbitrary sample indexes
When defining an IndexedDataset with data fields on disk, i.e. to be loaded lazily, the sample indexes passed as the sample_id kwarg to the constructor are used as the arguments to the load function.
To demonstrate this, as in the previous tutorial, let's save the x data to disk and build a mixed in-memory/on-disk IndexedDataset. For instance, the code below saves the x data for the sample "dog" at the relative path "data/x_dog.pt".
# Create a directory to save the dummy x data to disk
os.makedirs("data", exist_ok=True)
for i, x in zip(sample_id, x_data):
    torch.save(x, f"data/x_{i}.pt")
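The file names are produced by formatting each sample index with an f-string, so a tuple index such as ("small", "cow") ends up as its string form, quotes and spaces included. A quick plain-Python check with a hypothetical helper, toy_path, mirroring that f-string also shows a caveat of this naming scheme: distinct indexes with the same string form, such as the integer 4 and the string "4", would share one file.

```python
def toy_path(sample_id):
    """Hypothetical helper reproducing the f-string naming used above."""
    return f"data/x_{sample_id}.pt"


print(toy_path(("small", "cow")))  # data/x_('small', 'cow').pt
print(toy_path(4) == toy_path("4"))  # True: int 4 and str "4" collide

# A collision check one might run before saving (illustration only)
ids = ["cat", 4, ("small", "cow"), "dog", 0]
paths = [toy_path(i) for i in ids]
assert len(set(paths)) == len(paths), "two sample indexes share a file name"
```

If such collisions are possible in your data, formatting with {sample_id!r} in both the saving loop and the load function would keep them apart, at the cost of quote characters in the file names.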
We can now define a load function to load data from disk. This should take the unique sample index as a single argument, and return the corresponding data in memory.
def load_x(sample_id):
    """
    Loads the x data for the sample indexed by `sample_id` from disk and
    returns the object in memory
    """
    print(f"loading x for sample {sample_id}")
    return torch.load(f"data/x_{sample_id}.pt")
Now when we define an IndexedDataset, the ‘x’ data field can be passed as a callable.
mixed_dataset = IndexedDataset(x=load_x, y=y_data, sample_id=sample_id)
print(mixed_dataset.get_sample("dog"))
print(mixed_dataset.get_sample(("small", "cow")))
loading x for sample dog
Sample(sample_id='dog', x=tensor([-1.1023, 0.5280, 0.0502]), y=3)
loading x for sample ('small', 'cow')
Sample(sample_id=('small', 'cow'), x=tensor([ 1.0664, -0.1478, 0.1371]), y=2)
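The mixed dataset calls load_x only when a sample is requested, and passes it the sample index rather than the position. The sketch below illustrates that dispatch in plain Python; resolve_field and fake_load_x are hypothetical names, and the real IndexedDataset may organise this differently.

```python
def resolve_field(field, sample_id, position):
    """Hypothetical: fetch one sample's value from an in-memory or lazy field."""
    if callable(field):
        # Lazy/on-disk field: the sample index is passed to the load function
        return field(sample_id)
    # In-memory field: the value is looked up by position
    return field[position]


calls = []


def fake_load_x(sample_id):
    """Stand-in for load_x that records its calls instead of reading a file."""
    calls.append(sample_id)
    return f"x for {sample_id!r}"


sample_id = ["cat", 4, ("small", "cow"), "dog", 0]
y_data = [0, 1, 2, 3, 4]
pos = sample_id.index("dog")

print(resolve_field(fake_load_x, "dog", pos))  # x for 'dog'
print(resolve_field(y_data, "dog", pos))  # 3
print(calls)  # ['dog']
```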
Using an IndexedDataset: subset integer ranges
One could also define an IndexedDataset where the sample indices are integers forming a possibly shuffled and non-continuous subset of a larger continuous range of numeric indices.
For instance, imagine we have a global Dataset of 1000 samples, with indices [0, …, 999], but only want to build a dataset for the samples with indices [4, 7, 200, 5, 999], in that order. We can pass these indices via the sample_id kwarg.
# Build an IndexedDataset, specifying the subset sample indexes in a specific order
sample_id = [4, 7, 200, 5, 999]
dataset = IndexedDataset(x=x_data, y=y_data, sample_id=sample_id)
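A subset like this could, for instance, be gathered from a larger, globally indexed collection. In this plain-Python sketch, global_y is a hypothetical stand-in for one field of 1000 global samples; local position k then holds the global sample subset_ids[k], so sample index 5 ends up at position 3.

```python
# Hypothetical global collection: entry i belongs to global sample index i
global_y = [f"y_{i}" for i in range(1000)]

# Subset of global indexes, in the order we want them in the local dataset
subset_ids = [4, 7, 200, 5, 999]

# Local position k holds the global sample subset_ids[k]
local_y = [global_y[i] for i in subset_ids]
id_to_position = {sid: pos for pos, sid in enumerate(subset_ids)}

print(local_y)  # ['y_4', 'y_7', 'y_200', 'y_5', 'y_999']
print(id_to_position[5], local_y[id_to_position[5]])  # 3 y_5
```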
As before, we can access samples by their unique sample index using the get_sample method.
Again, the numeric position can be used equivalently to access the same sample: sample index 5 sits at position 3. Note again that the Sample named tuple includes the sample_id field.
# These return the same sample
print(dataset.get_sample(5))
print(dataset[3])
Sample(sample_id=5, x=tensor([-1.1023, 0.5280, 0.0502]), y=3)
Sample(sample_id=5, x=tensor([-1.1023, 0.5280, 0.0502]), y=3)
And finally, the DataLoader behaves as expected:
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
Batch(sample_id=(4, 7), x=tensor([[ 0.3149, 0.2212, 0.6026],
[ 1.3223, 0.0085, -0.0658]]), y=(0, 1))
Batch(sample_id=(200, 5), x=tensor([[ 1.0664, -0.1478, 0.1371],
[-1.1023, 0.5280, 0.0502]]), y=(2, 3))
Batch(sample_id=(999,), x=tensor([[0.9286, 0.3232, 1.1743]]), y=(4,))