# PyTorch-backed Dataloader

Use PyTorch to load batches. Requires `torch` to be installed.
## to_torch_dataset

    to_torch_dataset (dataset: jax_dataloader.datasets.Dataset)
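`to_torch_dataset` adapts a `jax_dataloader` `Dataset` so PyTorch's loader can consume it. PyTorch's map-style dataset protocol only requires `__len__` and `__getitem__`; the numpy-only sketch below (class name hypothetical, not this library's implementation) shows the shape of such an adapter:

```python
import numpy as np

class ArrayBackedDataset:
    """Hypothetical sketch of a map-style dataset: only __len__ and __getitem__."""

    def __init__(self, *arrays):
        # All arrays must agree on the number of samples along axis 0.
        assert all(len(a) == len(arrays[0]) for a in arrays)
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        # One sample: the idx-th row of each backing array.
        return tuple(a[idx] for a in self.arrays)

feats = np.arange(50).reshape(10, 5)
labels = np.arange(10)
ds = ArrayBackedDataset(feats, labels)
x, y = ds[3]
```

Any object satisfying this two-method protocol can be indexed and batched by a PyTorch-style loader.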
## DataLoaderPytorch

    DataLoaderPytorch (dataset: Union[jax_dataloader.datasets.Dataset,
                                torch.utils.data.dataset.Dataset,
                                Annotated[Union[datasets.arrow_dataset.Dataset,
                                         datasets.dataset_dict.DatasetDict,
                                         datasets.dataset_dict.IterableDatasetDict,
                                         datasets.iterable_dataset.IterableDataset],
                                    beartype.vale.Is[lambda _: hf_datasets is not None]]],
                       batch_size: int = 1, shuffle: bool = False,
                       drop_last: bool = False, **kwargs)

PyTorch Dataloader
| | Type | Default | Details |
|---|---|---|---|
| dataset | typing.Union[jax_dataloader.datasets.Dataset, torch.utils.data.dataset.Dataset, typing.Annotated[typing.Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict, datasets.dataset_dict.IterableDatasetDict, datasets.iterable_dataset.IterableDataset], beartype.vale.Is[lambda _: hf_datasets is not None]]] | | |
| batch_size | int | 1 | Batch size |
| shuffle | bool | False | If true, the dataloader shuffles before sampling each batch |
| drop_last | bool | False | Whether to drop the last incomplete batch |
| kwargs | | | |
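The `shuffle` and `drop_last` parameters follow the usual PyTorch `DataLoader` semantics. A numpy-only sketch of that sampling behavior (not this library's actual implementation; the `seed` argument is hypothetical) makes the two flags concrete:

```python
import numpy as np

def iter_batches(n, batch_size, shuffle=False, drop_last=False, seed=0):
    """Sketch: yield index batches the way shuffle/drop_last are described above."""
    idx = np.arange(n)
    if shuffle:
        # Shuffle once per pass, then sample batches in order.
        np.random.default_rng(seed).shuffle(idx)
    # drop_last=True truncates to a whole number of batches.
    end = n - n % batch_size if drop_last else n
    for start in range(0, end, batch_size):
        yield idx[start:start + batch_size]

batches = list(iter_batches(10, 3, shuffle=True, drop_last=False))
batches_dropped = list(iter_batches(10, 3, shuffle=True, drop_last=True))
```

With `n=10` and `batch_size=3`, `drop_last=False` yields four batches (the last holding one sample) and `drop_last=True` yields three full batches.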
```python
import numpy as np
import torch
import torch.utils.data as torch_data
# ArrayDataset and DataLoaderPytorch come from this library.

samples = 1280
batch_size = 12
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)

# The same data, wrapped two ways: a native PyTorch dataset and an ArrayDataset.
ds_torch = torch_data.TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
ds_array = ArrayDataset(feats, labels)

# Batches come back as numpy arrays regardless of the dataset type.
dl_1 = DataLoaderPytorch(ds_torch, batch_size=batch_size, shuffle=True)
for _ in range(10):
    for (x, y) in dl_1:
        assert isinstance(x, np.ndarray)

dl_2 = DataLoaderPytorch(ds_array, batch_size=batch_size, shuffle=True)
for (x, y) in dl_2:
    assert isinstance(x, np.ndarray)
```
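Since 1280 is not evenly divisible by 12, the example above exercises an incomplete final batch. The batch counts follow directly from the sizes:

```python
import math

samples, batch_size = 1280, 12
n_batches = math.ceil(samples / batch_size)   # with drop_last=False (the default)
n_full = samples // batch_size                # with drop_last=True
last = samples - n_full * batch_size          # size of the final partial batch
```

So with the defaults each epoch yields 107 batches, the last containing only 8 samples; setting `drop_last=True` would yield 106 full batches instead.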