Main DataLoader class to load NumPy data batches.
Source code in jax_dataloader/core.py (lines 89–117)
class DataLoader:
    """Main DataLoader class to load NumPy data batches.

    A thin facade over a backend-specific loader. ``__init__`` asks
    ``_dispatch_dataloader`` for the loader class matching ``backend``,
    instantiates it and stores it in ``self.dataloader``. ``len()``,
    ``iter()`` and ``next()`` on this object are forwarded to that
    instance unchanged.
    """

    def __init__(
        self,
        dataset,  # Dataset from which to load the data
        backend: Literal['jax', 'pytorch', 'tensorflow'],  # Dataloader backend to load the dataset
        batch_size: int = 1,  # How many samples per batch to load
        shuffle: bool = False,  # If true, dataloader reshuffles every epoch
        drop_last: bool = False,  # If true, drop the last incomplete batch
        **kwargs
    ):
        """Create the backend-specific loader and store it on ``self.dataloader``.

        Args:
            dataset: Passed through unchanged to the backend loader.
            backend: Handed to ``_dispatch_dataloader`` to choose the loader
                class; the annotation lists 'jax', 'pytorch' and 'tensorflow'.
            batch_size: Forwarded to the backend loader.
            shuffle: Forwarded to the backend loader.
            drop_last: Forwarded to the backend loader.
            **kwargs: Any extra keyword arguments, forwarded to the backend
                loader as-is.

        NOTE(review): how an unsupported ``backend`` is reported is decided by
        ``_dispatch_dataloader``, which is not shown here.
        """
        dl_cls = _dispatch_dataloader(backend)
        # All options are passed by keyword, so every backend loader class
        # must accept these exact parameter names.
        self.dataloader = dl_cls(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            **kwargs,
        )

    def __len__(self):
        """Return ``len()`` of the backend loader."""
        return len(self.dataloader)

    def __next__(self):
        """Return ``next()`` of the backend loader.

        NOTE(review): ``next()`` only works if ``self.dataloader`` is itself
        an iterator; a backend that is merely iterable makes this raise
        ``TypeError``. Confirm per backend.
        """
        return next(self.dataloader)

    def __iter__(self):
        """Return ``iter()`` of the backend loader, so ``for batch in loader`` works."""
        return iter(self.dataloader)
Attributes
dataloader = dl_cls(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **kwargs)
instance-attribute
Functions
__init__(dataset, backend: Literal['jax', 'pytorch', 'tensorflow'], batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, **kwargs)
Source code in jax_dataloader/core.py (lines 92–108)
def __init__(
    self,
    dataset,  # Dataset from which to load the data
    backend: Literal['jax', 'pytorch', 'tensorflow'],  # Dataloader backend to load the dataset
    batch_size: int = 1,  # How many samples per batch to load
    shuffle: bool = False,  # If true, dataloader reshuffles every epoch
    drop_last: bool = False,  # If true, drop the last incomplete batch
    **kwargs
):
    """Create the backend-specific loader and store it on ``self.dataloader``.

    Args:
        dataset: Passed through unchanged to the backend loader.
        backend: Handed to ``_dispatch_dataloader`` to choose the loader
            class; the annotation lists 'jax', 'pytorch' and 'tensorflow'.
        batch_size: Forwarded to the backend loader.
        shuffle: Forwarded to the backend loader.
        drop_last: Forwarded to the backend loader.
        **kwargs: Any extra keyword arguments, forwarded to the backend
            loader as-is.

    NOTE(review): how an unsupported ``backend`` is reported is decided by
    ``_dispatch_dataloader``, which is not shown here.
    """
    dl_cls = _dispatch_dataloader(backend)
    # All options are passed by keyword, so every backend loader class
    # must accept these exact parameter names.
    self.dataloader = dl_cls(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        **kwargs,
    )
__iter__()
Source code in jax_dataloader/core.py
def __iter__(self):
    """Return ``iter()`` of the backend loader, so ``for batch in loader`` works."""
    return iter(self.dataloader)
__len__()
Source code in jax_dataloader/core.py
def __len__(self):
    """Return ``len()`` of the backend loader."""
    return len(self.dataloader)
__next__()
Source code in jax_dataloader/core.py
def __next__(self):
    """Return ``next()`` of the backend loader.

    NOTE(review): ``next()`` only works if ``self.dataloader`` is itself an
    iterator; a backend that is merely iterable makes this raise
    ``TypeError``. Confirm per backend.
    """
    return next(self.dataloader)