Source code for metatensor.learn.data.dataloader
"""
Module containing the DataLoader.
"""
from typing import Callable
import torch
from .collate import group_and_join
[docs]
class DataLoader(torch.utils.data.DataLoader):
"""
Class for loading data from an `:py:class:`torch.utils.data.Dataset` object
with a default collate function that supports :py:class:`torch.Tensor`,
:py:class:`atomistic.Systems`, or :py:class:`TensorMap`.
Any argument as accepted by the torch
:py:class:`torch.utils.data.DataLoader` parent class is supported. Please
refer to
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
"""
def __init__(
self,
dataset: torch.utils.data.Dataset,
collate_fn: Callable = group_and_join,
**kwargs,
):
super().__init__(dataset, collate_fn=collate_fn, **kwargs)