Skip to content

DataLoader

`jax_dataloader.core.DataLoader`

Main DataLoader class for loading NumPy data batches.

Source code in `jax_dataloader/core.py` (lines 89–117):
class DataLoader:
    """Main DataLoader class to load NumPy data batches.

    A thin facade over a backend-specific loader. The constructor resolves the
    requested ``backend`` to a concrete loader class via ``_dispatch_dataloader``
    and stores an instance of it on ``self.dataloader``. ``iter()``, ``next()``
    and ``len()`` on this object are forwarded to that instance.
    """

    def __init__(
        self,
        dataset, # Dataset from which to load the data
        backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset
        batch_size: int = 1,  # How many samples per batch to load
        shuffle: bool = False,  # If true, dataloader reshuffles every epoch
        drop_last: bool = False, # If true, drop the last incomplete batch
        **kwargs
    ):
        # Map the backend name onto its concrete loader class.
        loader_cls = _dispatch_dataloader(backend)
        # Options passed to the loader regardless of backend; any extra keyword
        # arguments are forwarded unchanged after them.
        common_options = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        self.dataloader = loader_cls(**common_options, **kwargs)

    def __iter__(self):
        """Return ``iter()`` of the wrapped backend loader."""
        inner = self.dataloader
        return iter(inner)

    def __next__(self):
        """Return ``next()`` of the wrapped backend loader."""
        inner = self.dataloader
        return next(inner)

    def __len__(self):
        """Return ``len()`` of the wrapped backend loader."""
        inner = self.dataloader
        return len(inner)

Attributes

`dataloader = dl_cls(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **kwargs)` (instance attribute)

Functions

`__init__(dataset, backend: Literal['jax', 'pytorch', 'tensorflow'], batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, **kwargs)`

Source code in `jax_dataloader/core.py` (lines 92–108):
def __init__(
    self,
    dataset, # Dataset from which to load the data
    backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset
    batch_size: int = 1,  # How many samples per batch to load
    shuffle: bool = False,  # If true, dataloader reshuffles every epoch
    drop_last: bool = False, # If true, drop the last incomplete batch
    **kwargs
):
    # Map the backend name onto its concrete loader class.
    loader_cls = _dispatch_dataloader(backend)
    # Options passed to the loader regardless of backend; any extra keyword
    # arguments are forwarded unchanged after them.
    common_options = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    self.dataloader = loader_cls(**common_options, **kwargs)

`__iter__()`

Source code in `jax_dataloader/core.py` (lines 116–117):
def __iter__(self):
    """Return ``iter()`` of the wrapped backend loader."""
    inner = self.dataloader
    return iter(inner)

`__len__()`

Source code in `jax_dataloader/core.py` (lines 110–111):
def __len__(self):
    """Return ``len()`` of the wrapped backend loader."""
    inner = self.dataloader
    return len(inner)

`__next__()`

Source code in `jax_dataloader/core.py` (lines 113–114):
def __next__(self):
    """Return ``next()`` of the wrapped backend loader."""
    inner = self.dataloader
    return next(inner)