Using IndexedDataset
====================

.. py:currentmodule:: metatensor.torch.learn.data

.. GENERATED FROM PYTHON SOURCE LINES 10-18

.. code-block:: Python

   import os

   import torch

   from metatensor.learn.data import DataLoader, Dataset, IndexedDataset

.. GENERATED FROM PYTHON SOURCE LINES 19-35

Review of the standard Dataset
------------------------------

The previous tutorial, :ref:`learn-tutorial-dataset-dataloader`, showed how to
define a :py:class:`Dataset` able to handle both torch tensors and metatensor
TensorMaps. We saw that in-memory, on-disk, or mixed in-memory/on-disk datasets
can be defined. DataLoaders are then defined on top of these Dataset objects.

In all cases, however, each data sample is accessed by an integer index, which
ranges from 0 to ``len(dataset) - 1``. Let's use a simple example to review this.

Again let's define some dummy data as before: our x data as a list of random
tensors, and our y data as a list of integers that enumerate the samples.

For the purposes of this tutorial, we will only focus on an in-memory dataset,
though the same principles apply to on-disk and mixed datasets.

.. GENERATED FROM PYTHON SOURCE LINES 36-43

.. code-block:: Python

   n_samples = 5
   x_data = [torch.randn(3) for _ in range(n_samples)]
   y_data = [i for i in range(n_samples)]

   dataset = Dataset(x=x_data, y=y_data)

.. GENERATED FROM PYTHON SOURCE LINES 44-49

A sample is accessed by its numeric index. As the length of the lists passed as
kwargs is 5, both for ``x`` and ``y``, the valid indices are [0, 1, 2, 3, 4].

Let's retrieve the 4th sample (index 3) and print it. The value of the "y" data
field should be 3.

.. GENERATED FROM PYTHON SOURCE LINES 50-53

.. code-block:: Python

   print(dataset[3])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sample(x=tensor([ 0.5957, 0.6598, -2.6910]), y=3)

.. GENERATED FROM PYTHON SOURCE LINES 54-62

What if we wanted to access samples by something other than an integer index
that is part of a continuous range? For instance, what if we wanted to access
samples by:

1. a string id, or other arbitrary hashable object?
2. an integer index that is not defined inside a continuous range?

In these cases, we can use an IndexedDataset instead.

.. GENERATED FROM PYTHON SOURCE LINES 65-72

IndexedDataset
--------------

First let's define a Dataset where the samples are indexed by arbitrary unique
indexes, such as strings, integers, and tuples.

Suppose the unique indexes for our 5 samples are:

.. GENERATED FROM PYTHON SOURCE LINES 73-89

.. code-block:: Python

   sample_id = [
       "cat",
       4,
       ("small", "cow"),
       "dog",
       0,
   ]

   # Build an IndexedDataset, specifying the unique sample indexes with ``sample_id``
   dataset = IndexedDataset(
       x=x_data,
       y=y_data,
       sample_id=sample_id,
   )

.. GENERATED FROM PYTHON SOURCE LINES 90-93

Now, when we access the dataset, we can access samples by their unique sample
index using the ``get_sample`` method. This method takes a single argument, the
sample index, and returns the corresponding sample.

.. GENERATED FROM PYTHON SOURCE LINES 94-99

.. code-block:: Python

   print(dataset.get_sample("dog"))
   print(dataset.get_sample(4))
   print(dataset.get_sample(("small", "cow")))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sample(sample_id='dog', x=tensor([ 0.5957, 0.6598, -2.6910]), y=3)
    Sample(sample_id=4, x=tensor([ 0.0851, 0.6831, -0.2733]), y=1)
    Sample(sample_id=('small', 'cow'), x=tensor([-1.3655, -1.9929, 0.2798]), y=2)

.. GENERATED FROM PYTHON SOURCE LINES 100-112

Note that ``__getitem__`` still uses the numeric position: ``dataset[4]``
returns the fifth sample passed to the constructor (numeric index 4), not the
sample whose ``sample_id`` is ``4``. In this case, the sample indexes map to the
numeric indices as follows:

0. ``"cat"``
1. ``4``
2. ``("small", "cow")``
3. ``"dog"``
4. ``0``

Thus, accessing the unique sample index ``"cat"`` can be done equivalently with
either of:

.. GENERATED FROM PYTHON SOURCE LINES 113-118

.. code-block:: Python

   print(dataset[0])
   print(dataset.get_sample("cat"))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sample(sample_id='cat', x=tensor([-2.5825, -1.1212, -0.7937]), y=0)
    Sample(sample_id='cat', x=tensor([-2.5825, -1.1212, -0.7937]), y=0)

.. GENERATED FROM PYTHON SOURCE LINES 119-127

Note that the named tuple returned in both cases contains the unique sample
index as the ``sample_id`` field, which precedes all other data fields. This is
in contrast to the standard Dataset, which only returns the passed data fields
and not the index.

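The relationship between numeric positions and unique sample indexes described
above can be pictured with a small pure-Python sketch. It is only a conceptual
model, assuming a dictionary that maps each sample index to its position;
``TinyIndexedDataset`` is a hypothetical name used for illustration and is not
how metatensor implements :py:class:`IndexedDataset`.

```python
from collections import namedtuple

# Hypothetical stand-in used only to illustrate the lookup; not metatensor code
Sample = namedtuple("Sample", ["sample_id", "x", "y"])


class TinyIndexedDataset:
    """Toy dataset addressable by numeric position or by unique sample ID."""

    def __init__(self, x, y, sample_id):
        if len(set(sample_id)) != len(sample_id):
            raise ValueError("sample IDs must be unique")
        self._x, self._y, self._ids = x, y, list(sample_id)
        # Map each hashable sample ID to its position in the input lists
        self._position = {sid: i for i, sid in enumerate(self._ids)}

    def __len__(self):
        return len(self._ids)

    def __getitem__(self, position):
        # Numeric access: position in the lists passed to the constructor
        return Sample(self._ids[position], self._x[position], self._y[position])

    def get_sample(self, sample_id):
        # ID access: translate the ID to a position, then reuse __getitem__
        return self[self._position[sample_id]]


toy = TinyIndexedDataset(
    x=["x0", "x1", "x2", "x3", "x4"],
    y=[0, 1, 2, 3, 4],
    sample_id=["cat", 4, ("small", "cow"), "dog", 0],
)

# Position 4 holds the sample whose ID is 0 ...
print(toy[4])
# ... while the sample whose ID is 4 sits at position 1
print(toy.get_sample(4))
```

In this sketch, an integer passed to ``__getitem__`` is always a position and
an integer passed to ``get_sample`` is always a sample index, which is why the
two calls above return different samples.
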
A :py:class:`DataLoader` can be constructed on top of an
:py:class:`IndexedDataset` in the same way as a :py:class:`Dataset`. Batches are
accessed by iterating over the :py:class:`DataLoader`, though this time the
``Batch`` named tuple returned by the data loader will contain the unique sample
indexes ``sample_id`` as the first field.

.. GENERATED FROM PYTHON SOURCE LINES 128-135

.. code-block:: Python

   dataloader = DataLoader(dataset, batch_size=2)

   # Iterate over batches
   for batch in dataloader:
       print(batch)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Batch(sample_id=('cat', 4), x=tensor([[-2.5825, -1.1212, -0.7937], [ 0.0851, 0.6831, -0.2733]]), y=(0, 1))
    Batch(sample_id=(('small', 'cow'), 'dog'), x=tensor([[-1.3655, -1.9929, 0.2798], [ 0.5957, 0.6598, -2.6910]]), y=(2, 3))
    Batch(sample_id=(0,), x=tensor([[-0.5695, -0.7348, 0.3795]]), y=(4,))

.. GENERATED FROM PYTHON SOURCE LINES 136-137

As before, we can create separate variables in the iteration pattern:

.. GENERATED FROM PYTHON SOURCE LINES 138-141

.. code-block:: Python

   for ids, x, y in dataloader:
       print(ids, x, y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ('cat', 4) tensor([[-2.5825, -1.1212, -0.7937], [ 0.0851, 0.6831, -0.2733]]) (0, 1)
    (('small', 'cow'), 'dog') tensor([[-1.3655, -1.9929, 0.2798], [ 0.5957, 0.6598, -2.6910]]) (2, 3)
    (0,) tensor([[-0.5695, -0.7348, 0.3795]]) (4,)

.. GENERATED FROM PYTHON SOURCE LINES 142-154

On-disk :py:class:`IndexedDataset` with arbitrary sample indexes
-----------------------------------------------------------------

When defining an :py:class:`IndexedDataset` with data fields on-disk, i.e. to be
loaded lazily, the sample indexes passed as the ``sample_id`` kwarg to the
constructor are used as the arguments to the load function.

To demonstrate this, as we did in the previous tutorial, let's save the ``x``
data to disk and build a mixed in-memory/on-disk :py:class:`IndexedDataset`.

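The way a lazily loaded field receives the sample index can be pictured with
the following pure-Python sketch. It is a conceptual illustration only:
``fake_load_x``, ``resolve`` and the module-level ``get_sample`` are
hypothetical names, and the sketch assumes that callable fields are called with
the sample index while list-like fields are indexed by numeric position.

```python
from collections import namedtuple

# Hypothetical stand-in for the named tuple returned by the dataset
Sample = namedtuple("Sample", ["sample_id", "x", "y"])

calls = []  # records which sample IDs were requested from the loader


def fake_load_x(sample_id):
    # Hypothetical loader: a real one would read a file named after the ID
    calls.append(sample_id)
    return f"x for {sample_id!r}"


def resolve(field, position, sample_id):
    # Callable fields are evaluated lazily with the sample ID;
    # list-like fields are indexed by numeric position
    return field(sample_id) if callable(field) else field[position]


sample_ids = ["cat", 4, ("small", "cow")]
y_values = [0, 1, 2]


def get_sample(sample_id):
    # Find the numeric position of the ID, then resolve each field
    position = sample_ids.index(sample_id)
    return Sample(
        sample_id,
        resolve(fake_load_x, position, sample_id),
        resolve(y_values, position, sample_id),
    )


sample = get_sample(("small", "cow"))
print(sample)
print(calls)
```

Only the requested sample triggers a call to the loader, so nothing is loaded
for the samples that are never accessed.
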
The code below saves the x data for each sample to disk. For instance, the data
for the sample ``"dog"`` is saved at the relative path ``"data/x_dog.pt"``.

.. GENERATED FROM PYTHON SOURCE LINES 155-162

.. code-block:: Python

   # Create a directory to save the dummy x data to disk
   os.makedirs("data", exist_ok=True)

   for i, x in zip(sample_id, x_data):
       torch.save(x, f"data/x_{i}.pt")

.. GENERATED FROM PYTHON SOURCE LINES 163-165

We can now define a load function to load data from disk. This should take the
unique sample index as a single argument, and return the corresponding data in
memory.

.. GENERATED FROM PYTHON SOURCE LINES 166-177

.. code-block:: Python

   def load_x(sample_id):
       """
       Loads the x data for the sample indexed by `sample_id` from disk and
       returns the object in memory
       """
       print(f"loading x for sample {sample_id}")
       return torch.load(f"data/x_{sample_id}.pt")

.. GENERATED FROM PYTHON SOURCE LINES 178-180

Now when we define an IndexedDataset, the 'x' data field can be passed as a
callable.

.. GENERATED FROM PYTHON SOURCE LINES 181-187

.. code-block:: Python

   mixed_dataset = IndexedDataset(x=load_x, y=y_data, sample_id=sample_id)

   print(mixed_dataset.get_sample("dog"))
   print(mixed_dataset.get_sample(("small", "cow")))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loading x for sample dog
    /home/runner/work/metatensor/metatensor/python/examples/learn/2-indexed-dataset.py:174: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
      return torch.load(f"data/x_{sample_id}.pt")
    Sample(sample_id='dog', x=tensor([ 0.5957, 0.6598, -2.6910]), y=3)
    loading x for sample ('small', 'cow')
    Sample(sample_id=('small', 'cow'), x=tensor([-1.3655, -1.9929, 0.2798]), y=2)

.. GENERATED FROM PYTHON SOURCE LINES 188-198

Using an IndexedDataset: subset integer ranges
----------------------------------------------

One could also define an IndexedDataset where the sample indices are integers
forming a possibly shuffled and non-continuous subset of a larger continuous
range of numeric indices.

For instance, imagine we have a global Dataset of 1000 samples, with indices
[0, ..., 999], but only want to build a dataset for samples with indices
[4, 7, 200, 5, 999], in that order. We can pass these indices as the
``sample_id`` kwarg.

.. GENERATED FROM PYTHON SOURCE LINES 199-204

.. code-block:: Python

   # Build an IndexedDataset, specifying the subset sample indexes in a specific order
   sample_id = [4, 7, 200, 5, 999]
   dataset = IndexedDataset(x=x_data, y=y_data, sample_id=sample_id)

.. GENERATED FROM PYTHON SOURCE LINES 205-211

Now, when we access the dataset, we can access samples by their unique sample
index using the ``get_sample`` method. This method takes a single argument, the
sample index, and returns the corresponding sample.

Again, the numeric index can be used equivalently to access the sample, and
again note that the ``Sample`` named tuple includes the ``sample_id`` field.

.. GENERATED FROM PYTHON SOURCE LINES 212-217

.. code-block:: Python

   # ``get_sample(5)`` looks up the sample with ID 5 (numeric position 3), whereas
   # ``dataset[4]`` returns the sample at numeric position 4 (ID 999)
   print(dataset.get_sample(5))
   print(dataset[4])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sample(sample_id=5, x=tensor([ 0.5957, 0.6598, -2.6910]), y=3)
    Sample(sample_id=999, x=tensor([-0.5695, -0.7348, 0.3795]), y=4)

.. GENERATED FROM PYTHON SOURCE LINES 218-219

And finally, the DataLoader behaves as expected:

.. 