JAX Dataloader
EpochIterator
EpochIterator (data, batch_size:int, indices:Sequence[int])
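The signature suggests an iterator that walks a fixed sequence of indices in chunks of `batch_size`. A minimal sketch of that idea in plain Python follows. `SketchEpochIterator` is a hypothetical name, not the library's implementation (which is not shown here), and it assumes `data` supports integer indexing.

```python
from typing import Any, Iterator, List, Sequence


class SketchEpochIterator:
    """Hypothetical stand-in: yields one batch per step over a fixed index order."""

    def __init__(self, data: Sequence[Any], batch_size: int, indices: Sequence[int]):
        self.data = data
        self.batch_size = batch_size
        self.indices = indices
        self.pos = 0  # position of the next unread index

    def __iter__(self) -> Iterator[List[Any]]:
        return self

    def __next__(self) -> List[Any]:
        if self.pos >= len(self.indices):
            raise StopIteration
        # Take the next slice of indices; the final slice may be shorter.
        batch_idx = self.indices[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return [self.data[i] for i in batch_idx]


data = ["a", "b", "c", "d", "e"]
# Passing a reversed index order mimics what a shuffled epoch would supply.
batches = list(SketchEpochIterator(data, batch_size=2, indices=[4, 3, 2, 1, 0]))
```

In this sketch the order of samples is decided entirely by `indices`, so shuffling amounts to passing a permuted index sequence for each epoch.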
to_jax_dataset
to_jax_dataset (dataset:jax_dataloader.datasets.Dataset)
DataLoaderJAX
DataLoaderJAX (dataset:Union[jax_dataloader.datasets.Dataset,Annotated[Union[datasets.arrow_dataset.Dataset,datasets.dataset_dict.DatasetDict,datasets.dataset_dict.IterableDatasetDict,datasets.iterable_dataset.IterableDataset],beartype.vale.Is[lambda _: hf_datasets is not None]]], batch_size:int=1, shuffle:bool=False, num_workers:int=0, drop_last:bool=False, **kwargs)
Dataloader Interface
| | Type | Default | Details |
|---|---|---|---|
| dataset | typing.Union[jax_dataloader.datasets.Dataset, typing.Annotated[typing.Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict, datasets.dataset_dict.IterableDatasetDict, datasets.iterable_dataset.IterableDataset], beartype.vale.Is[lambda _: hf_datasets is not None]]] | | |
| batch_size | int | 1 | batch size |
| shuffle | bool | False | if true, dataloader shuffles before sampling each batch |
| num_workers | int | 0 | how many subprocesses to use for data loading. Ignored. |
| drop_last | bool | False | |
| kwargs | | | |
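How `batch_size` and `drop_last` determine the number of batches per epoch can be sketched with plain arithmetic. The helper `num_batches` below is hypothetical. It assumes the common convention that `drop_last=True` discards a trailing partial batch; the table leaves this parameter undocumented, so treat that as an assumption rather than the library's confirmed behavior.

```python
import math


def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: batches per epoch under the assumed drop_last convention."""
    if drop_last:
        return n_samples // batch_size          # keep full batches only
    return math.ceil(n_samples / batch_size)    # keep the trailing partial batch


full = num_batches(1280, 12)                      # partial last batch kept
dropped = num_batches(1280, 12, drop_last=True)   # partial last batch dropped
exact = num_batches(128, 128, drop_last=True)     # one exact batch either way
```

With 1280 samples and a batch size of 12, the last batch holds only 8 samples, so the two settings differ by exactly one batch. When the sample count is a multiple of the batch size, both settings give the same count.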
# 1280 samples do not divide evenly into batches of 12
samples = 1280
batch_size = 12
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True)
# 106 full batches plus one partial batch of 8 samples
assert len(dl) == 1280 // 12 + 1
# every sample index is kept
assert len(dl.indices) == 1280
# The dataset holds exactly one full batch, so drop_last does not change the batch count
samples = 128
batch_size = 128
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=True)
assert len(dl) == 1
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=False)
assert len(dl) == 1
test_dataloader(DataLoaderJAX, samples=1280, batch_size=10)
281 ms ± 27.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)