Skip to content

BaseDataLoader

jax_dataloader.loaders.base.BaseDataLoader

Dataloader Interface

Source code in jax_dataloader/loaders/base.py (lines 11-32):
class BaseDataLoader:
    """Dataloader interface.

    Declares the constructor signature and the protocol hooks
    (``__len__``, ``__iter__``, ``__next__``) that concrete dataloaders
    provide. The base constructor stores none of its arguments, and every
    hook raises :class:`NotImplementedError`. Subclasses must therefore
    supply both the state and the behavior.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        num_workers: int = 0,
        drop_last: bool = False,
        **kwargs,
    ):
        """Accept the common dataloader options without storing them.

        Args:
            dataset: Source of the samples to load.
            batch_size: Batch size.
            shuffle: If true, the dataloader shuffles before sampling each
                batch.
            num_workers: How many subprocesses to use for data loading.
            drop_last: Conventionally, whether to discard a trailing
                incomplete batch. The base class does not act on it.
            **kwargs: Additional implementation-specific options. Ignored
                here.
        """

    def __iter__(self):
        """Iterator-protocol hook; concrete dataloaders must override it."""
        raise NotImplementedError()

    def __next__(self):
        """Return the next item (presumably a batch); must be overridden."""
        raise NotImplementedError()

    def __len__(self):
        """Report the loader's length; must be overridden."""
        raise NotImplementedError()

Functions

__init__(dataset, batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, drop_last: bool = False, **kwargs)

Source code in jax_dataloader/loaders/base.py (lines 14-23):
def __init__(
    self,
    dataset,
    batch_size: int = 1,
    shuffle: bool = False,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs,
):
    """Accept the common dataloader options without storing them.

    Args:
        dataset: Source of the samples to load.
        batch_size: Batch size.
        shuffle: If true, the dataloader shuffles before sampling each batch.
        num_workers: How many subprocesses to use for data loading.
        drop_last: Conventionally, whether to discard a trailing incomplete
            batch. Not acted on here.
        **kwargs: Additional implementation-specific options. Ignored here.
    """

__iter__()

Source code in jax_dataloader/loaders/base.py (lines 31-32):
def __iter__(self):
    """Iterator-protocol hook; concrete dataloaders must override it."""
    raise NotImplementedError()

__len__()

Source code in jax_dataloader/loaders/base.py (lines 25-26):
def __len__(self):
    """Report the loader's length; concrete dataloaders must override it."""
    raise NotImplementedError()

__next__()

Source code in jax_dataloader/loaders/base.py (lines 28-29):
def __next__(self):
    """Return the next item (presumably a batch); must be overridden."""
    raise NotImplementedError()