JAX Dataloader

EpochIterator

 EpochIterator (data, batch_size:int, indices:Sequence[int])
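
The signature suggests an iterator that walks a fixed sequence of indices in chunks of `batch_size`. The following is a minimal sketch of that idea under stated assumptions, not the library's implementation: `SketchEpochIterator` is a hypothetical name, and it works on plain Python sequences rather than JAX arrays.

```python
from typing import List, Sequence


class SketchEpochIterator:
    """Hypothetical stand-in: walks `indices` in chunks of `batch_size`."""

    def __init__(self, data: Sequence, batch_size: int, indices: Sequence[int]):
        self.data = data
        self.batch_size = batch_size
        self.indices = indices
        self.pos = 0  # position of the next unread index

    def __iter__(self) -> "SketchEpochIterator":
        return self

    def __next__(self) -> List:
        if self.pos >= len(self.indices):
            raise StopIteration
        # Take the next chunk of indices; the final chunk may be shorter
        chunk = self.indices[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return [self.data[i] for i in chunk]


# Iterate over seven items in reverse order, three at a time
batches = list(SketchEpochIterator(list("abcdefg"), batch_size=3,
                                   indices=[6, 5, 4, 3, 2, 1, 0]))
```

Passing the indices in separately means the caller decides the visiting order (shuffled or not) while the iterator only handles chunking.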

to_jax_dataset

 to_jax_dataset (dataset:jax_dataloader.datasets.Dataset)

DataLoaderJAX

 DataLoaderJAX (dataset:Union[jax_dataloader.datasets.Dataset,
                Annotated[Union[datasets.arrow_dataset.Dataset,
                                datasets.dataset_dict.DatasetDict,
                                datasets.dataset_dict.IterableDatasetDict,
                                datasets.iterable_dataset.IterableDataset],
                          beartype.vale.Is[lambda _: hf_datasets is not None]]],
                batch_size:int=1, shuffle:bool=False, num_workers:int=0,
                drop_last:bool=False, **kwargs)

Dataloader Interface

| | Type | Default | Details |
|---|---|---|---|
| dataset | typing.Union[jax_dataloader.datasets.Dataset, typing.Annotated[typing.Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict, datasets.dataset_dict.IterableDatasetDict, datasets.iterable_dataset.IterableDataset], beartype.vale.Is[lambda _: hf_datasets is not None]]] | | |
| batch_size | int | 1 | batch size |
| shuffle | bool | False | if true, dataloader shuffles before sampling each batch |
| num_workers | int | 0 | how many subprocesses to use for data loading. Ignored. |
| drop_last | bool | False | |
| kwargs | | | |
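
To make the `shuffle` and `drop_last` parameters concrete, here is a minimal sketch of how index batches could be formed. It is an illustration only: `sketch_batch_indices` and its `seed` argument are hypothetical, and it assumes the common convention that `drop_last=True` discards a final batch smaller than `batch_size` and that shuffling permutes the indices once per epoch.

```python
import random
from typing import List, Optional


def sketch_batch_indices(n: int, batch_size: int, shuffle: bool = False,
                         drop_last: bool = False,
                         seed: Optional[int] = None) -> List[List[int]]:
    """Hypothetical helper: split range(n) into batches of indices."""
    indices = list(range(n))
    if shuffle:
        # Assumption: a single permutation of all indices per epoch
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    # Assumption: drop_last removes a trailing batch that is not full
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


in_order = sketch_batch_indices(10, batch_size=4)
trimmed = sketch_batch_indices(10, batch_size=4, drop_last=True)
shuffled = sketch_batch_indices(10, batch_size=4, shuffle=True, seed=0)
```

In this sketch, shuffling changes only the order in which samples are visited; every index still appears exactly once per epoch.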
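
# 1280 samples do not divide evenly into batches of 12
samples = 1280
batch_size = 12
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True)
# 106 full batches plus one partial batch of 8 samples
assert len(dl) == 1280 // 12 + 1
# The loader tracks one index per sample
assert len(dl.indices) == 1280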

# When batch_size equals the dataset size there is exactly one full batch,
# so drop_last does not change the number of batches
samples = 128
batch_size = 128
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
ds = ArrayDataset(feats, labels)
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=True)
assert len(dl) == 1
dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=False)
assert len(dl) == 1

test_dataloader(DataLoaderJAX, samples=1280, batch_size=10)
281 ms ± 27.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)
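
The length assertions in the examples above are consistent with a simple batch-count rule. The sketch below writes that rule out; `sketch_num_batches` is a hypothetical helper, not part of the library, and it assumes `drop_last=True` discards a trailing partial batch.

```python
def sketch_num_batches(n_samples: int, batch_size: int,
                       drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches in one pass over the data."""
    full, remainder = divmod(n_samples, batch_size)
    # A partial batch counts only if it exists and is not dropped
    if drop_last or remainder == 0:
        return full
    return full + 1


uneven = sketch_num_batches(1280, 12)                    # 106 full + 1 partial
exact_drop = sketch_num_batches(128, 128, drop_last=True)
exact_keep = sketch_num_batches(128, 128, drop_last=False)
```

Note that `n // batch_size + 1` matches this rule only when the sample count is not a multiple of the batch size, which is the case for 1280 and 12.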