

PyTorch-backed Dataloader

Use PyTorch to load batches. This requires PyTorch to be installed.



to_torch_dataset

    to_torch_dataset(dataset: jax_dataloader.datasets.Dataset)


DataLoaderPytorch

    DataLoaderPytorch(
        dataset: Union[
            jax_dataloader.datasets.Dataset,
            torch.utils.data.dataset.Dataset,
            Annotated[
                Union[
                    datasets.arrow_dataset.Dataset,
                    datasets.dataset_dict.DatasetDict,
                    datasets.dataset_dict.IterableDatasetDict,
                    datasets.iterable_dataset.IterableDataset,
                ],
                beartype.vale.Is[lambda _: hf_datasets is not None],
            ],
        ],
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        **kwargs,
    )

PyTorch Dataloader

| Parameter | Type | Default | Details |
|---|---|---|---|
| dataset | Union[jax_dataloader.datasets.Dataset, torch.utils.data.dataset.Dataset, Annotated[Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict, datasets.dataset_dict.IterableDatasetDict, datasets.iterable_dataset.IterableDataset], beartype.vale.Is[lambda _: hf_datasets is not None]]] | | Dataset to load from; the Hugging Face `datasets` types are accepted only when that package is available (`hf_datasets is not None`) |
| batch_size | int | 1 | Batch size |
| shuffle | bool | False | If True, the dataloader shuffles before sampling each batch |
| drop_last | bool | False | If True, drop the last batch when it is smaller than batch_size |
| kwargs | | | |
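How batch_size and drop_last interact can be worked out before iterating. The sketch below assumes DataLoaderPytorch follows PyTorch's usual batching rule for map-style datasets; `batch_layout` is a hypothetical helper written for illustration, not part of jax_dataloader.

```python
from typing import Tuple


def batch_layout(num_samples: int, batch_size: int, drop_last: bool = False) -> Tuple[int, int]:
    """Return (number of batches per epoch, size of the final batch).

    Assumes the standard PyTorch rule for map-style datasets: with
    drop_last=False a smaller final batch is kept; with drop_last=True it
    is discarded. Returns (0, 0) when drop_last=True and no full batch exists.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    full, remainder = divmod(num_samples, batch_size)
    if remainder == 0 or drop_last:
        # Every emitted batch is full (or nothing is emitted at all)
        return full, (batch_size if full > 0 else 0)
    # Keep one extra, smaller batch holding the leftover samples
    return full + 1, remainder


print(batch_layout(1280, 12))                  # (107, 8)
print(batch_layout(1280, 12, drop_last=True))  # (106, 12)
```

For the 1280 samples and batch size 12 used in the example below, that gives 107 batches per epoch with a final batch of 8 samples, or 106 full batches when drop_last=True. shuffle changes only the order of the samples, not the number of batches.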
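import numpy as np
import torch
import torch.utils.data as torch_data
from jax_dataloader.datasets import ArrayDataset
from jax_dataloader.loaders.torch import DataLoaderPytorch  # assumed module path; may differ between versions

samples = 1280
batch_size = 12
# Row i of `feats` holds ten copies of i; `labels` holds i itself
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)

# The same data as a PyTorch TensorDataset and as a jax_dataloader ArrayDataset
ds_torch = torch_data.TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
ds_array = ArrayDataset(feats, labels)

# Loading from a PyTorch dataset: batches still come out as NumPy arrays
dl_1 = DataLoaderPytorch(ds_torch, batch_size=batch_size, shuffle=True)
for _ in range(10):  # iterate over several epochs
    for x, y in dl_1:
        assert isinstance(x, np.ndarray)

# Loading from a jax_dataloader dataset behaves the same way
dl_2 = DataLoaderPytorch(ds_array, batch_size=batch_size, shuffle=True)
for x, y in dl_2:
    assert isinstance(x, np.ndarray)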
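The assertions above check that batches arrive as NumPy arrays even when the source is a PyTorch TensorDataset. One common way to get that behaviour from PyTorch's DataLoader is a custom collate function, similar to the numpy_collate recipe in the JAX documentation. The sketch below is a hypothetical illustration of that idea; it is not necessarily how DataLoaderPytorch is implemented.

```python
import numpy as np


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays rather than torch tensors.

    A sample is either array-like (stacked along a new leading axis) or a
    tuple/list of fields, which are collated field by field.
    """
    first = batch[0]
    if isinstance(first, (tuple, list)):
        # [(x0, y0), (x1, y1), ...] -> [collate([x0, x1, ...]), collate([y0, y1, ...])]
        return [numpy_collate(list(field)) for field in zip(*batch)]
    # np.asarray also accepts CPU torch tensors, which implement __array__
    return np.stack([np.asarray(sample) for sample in batch])


batch = [(np.full(3, 0), np.array([0])), (np.full(3, 1), np.array([1]))]
x, y = numpy_collate(batch)
print(type(x).__name__, x.shape, y.shape)  # ndarray (2, 3) (2, 1)
```

With plain PyTorch, such a function would be passed as `collate_fn`, for example `torch.utils.data.DataLoader(ds_torch, batch_size=12, shuffle=True, collate_fn=numpy_collate)`. Tensors on a GPU would first need to be moved to the CPU before `np.asarray` can convert them.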