Source code for metatensor.learn.data.dataloader
"""
Module containing the DataLoader.
"""
import torch
from .collate import group_and_join
[docs]
class DataLoader(torch.utils.data.DataLoader):
    """
    Class for loading data from an :py:class:`torch.utils.data.Dataset` object with a
    default collate function (:py:func:`metatensor.learn.data.group_and_join`) that
    supports :py:class:`torch.Tensor`, :py:class:`metatensor.torch.atomistic.System`,
    :py:class:`metatensor.TensorMap`, and :py:class:`metatensor.torch.TensorMap`.

    The dataset will typically be :py:class:`metatensor.learn.Dataset` or
    :py:class:`metatensor.learn.IndexedDataset`.

    Any argument accepted by the torch :py:class:`torch.utils.data.DataLoader` parent
    class is supported. Note that, unlike in the parent class, the second positional
    parameter here is ``collate_fn``. All other options (``batch_size``,
    ``shuffle``, ...) must therefore be passed as keyword arguments.

    :param dataset: the dataset to load data from.
    :param collate_fn: function used to merge a list of samples into a batch. Defaults
        to :py:func:`metatensor.learn.data.group_and_join`.
    :param kwargs: any other keyword arguments, forwarded unchanged to
        :py:class:`torch.utils.data.DataLoader`.
    """

    def __init__(self, dataset, collate_fn=group_and_join, **kwargs):
        # ``collate_fn`` is always forwarded explicitly, so the metatensor-aware
        # default replaces torch's own default unless the caller overrides it.
        super().__init__(dataset, collate_fn=collate_fn, **kwargs)