Pytorch-backed Dataloader¤

Uses PyTorch to load batches. PyTorch must be installed.
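
Because this loader relies on an optional dependency, calling code may want to check for PyTorch before selecting this backend. The following is a minimal sketch using only the standard library; `has_module` and `require_torch` are hypothetical helper names and are not part of jax_dataloader.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module called `name` can be located."""
    return importlib.util.find_spec(name) is not None


def require_torch() -> None:
    """Raise a descriptive error when PyTorch cannot be imported (hypothetical helper)."""
    if not has_module("torch"):
        raise ModuleNotFoundError(
            "The PyTorch-backed dataloader needs PyTorch; try `pip install torch`."
        )
```

Calling `require_torch()` before constructing the loader turns a late import failure into an early, readable error.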

to_torch_dataset¤

to_torch_dataset (dataset:jax_dataloader.datasets.Dataset)
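
The signature does not show what the conversion involves. One plausible reading, offered here as an assumption rather than the library's actual code, is a thin adapter: PyTorch's map-style datasets only need `__len__` and `__getitem__`, so a wrapper can forward both to the wrapped dataset. The sketch below uses plain Python so that it runs without PyTorch; `MapStyleAdapter` and `PairDataset` are hypothetical names, and in real use the adapter would subclass `torch.utils.data.Dataset`.

```python
class PairDataset:
    """Toy stand-in for an array dataset: indexes features and labels together."""

    def __init__(self, feats, labels):
        assert len(feats) == len(labels), "features and labels must align"
        self.feats, self.labels = feats, labels

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, idx):
        return self.feats[idx], self.labels[idx]


class MapStyleAdapter:
    """Hypothetical wrapper exposing the map-style protocol of the wrapped dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        # Forward the length so a sampler knows how many indices exist.
        return len(self.dataset)

    def __getitem__(self, idx):
        # Forward indexing; the loader collates the returned items into batches.
        return self.dataset[idx]


wrapped = MapStyleAdapter(PairDataset([[0, 0], [1, 1], [2, 2]], [0, 1, 2]))
```

Here `len(wrapped)` and `wrapped[i]` behave exactly like the underlying dataset, which is all a map-style loader requires.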

DataLoaderPytorch¤

DataLoaderPytorch (dataset:Union[jax_dataloader.datasets.Dataset,
                                 torch.utils.data.dataset.Dataset,
                                 Annotated[Union[datasets.arrow_dataset.Dataset,
                                                 datasets.dataset_dict.DatasetDict,
                                                 datasets.dataset_dict.IterableDatasetDict,
                                                 datasets.iterable_dataset.IterableDataset],
                                           beartype.vale.Is[lambda _: hf_datasets is not None]]],
                   batch_size:int=1, shuffle:bool=False, drop_last:bool=False,
                   generator:Union[jax_dataloader.utils.Generator,
                                   jax.Array,
                                   ForwardRef('torch.Generator'),
                                   NoneType]=None,
                   **kwargs)

Pytorch Dataloader

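| Parameter  | Type  | Default | Details |
|------------|-------|---------|---------|
| dataset    | Union |         |         |
| batch_size | int   | 1       | Batch size |
| shuffle    | bool  | False   | If true, the dataloader shuffles before sampling each batch |
| drop_last  | bool  | False   | Whether to drop the last batch when it is incomplete |
| generator  | Union | None    |         |
| kwargs     |       |         |         |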
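
# Import paths below are assumed; adjust them to your installed jax_dataloader version.
import numpy as np
import torch
import torch.utils.data as torch_data
from jax_dataloader.datasets import ArrayDataset
from jax_dataloader.loaders.torch import DataLoaderPytorch

samples = 1280
batch_size = 12
# Each row i of feats holds the value i repeated 10 times; labels hold i.
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)

# The same data as a PyTorch dataset and as a jax_dataloader ArrayDataset.
ds_torch = torch_data.TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
ds_array = ArrayDataset(feats, labels)
dl_1 = DataLoaderPytorch(ds_torch, batch_size=batch_size, shuffle=True)

# Iterate for several epochs; batches come back as NumPy arrays.
for _ in range(10):
    for (x, y) in dl_1:
        assert isinstance(x, np.ndarray)

# An ArrayDataset input also yields NumPy batches.
dl_2 = DataLoaderPytorch(ds_array, batch_size=batch_size, shuffle=True)
for (x, y) in dl_2:
    assert isinstance(x, np.ndarray)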
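
With samples = 1280 and batch_size = 12, the dataset does not divide evenly, so `drop_last` changes how many batches an epoch yields. The arithmetic can be sketched as follows, assuming the loader follows the usual `torch.utils.data.DataLoader` convention for map-style datasets; `num_batches` is a hypothetical helper, not a library function.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch under the usual DataLoader convention (assumed here)."""
    if drop_last:
        # Only full batches are kept.
        return n_samples // batch_size
    # Ceiling division: a trailing partial batch counts as one more batch.
    return (n_samples + batch_size - 1) // batch_size


samples, batch_size = 1280, 12
kept = num_batches(samples, batch_size, drop_last=False)     # 107
dropped = num_batches(samples, batch_size, drop_last=True)   # 106
remainder = samples % batch_size                              # 8 samples in the partial batch
```

So with the default `drop_last=False` the final batch would carry 8 samples instead of 12, while `drop_last=True` would skip those 8 samples in that epoch.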