Dataloader for JAX
Overview
`jax_dataloader` brings a pytorch-like dataloader API to jax. It supports:
- 4 dataset types to download and pre-process data:
  - jax dataset (`jdl.Dataset`)
  - huggingface datasets
  - pytorch Dataset
  - tensorflow dataset
- 3 backends to iteratively load batches:
  - jax dataloader
  - pytorch dataloader
  - tensorflow dataset
A minimal `jax-dataloader` example:
import jax_dataloader as jdl
jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility
dataloader = jdl.DataLoader(
dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
backend='jax', # Use 'jax' backend for loading data
batch_size=32, # Batch size
shuffle=True, # Shuffle the dataloader every iteration or not
drop_last=False, # Drop the last batch or not
)
batch = next(iter(dataloader)) # iterate next batch
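To make the `batch_size`, `shuffle`, and `drop_last` arguments concrete, here is a minimal pure-Python sketch of the index batching that a pytorch-style loader performs. It is an illustration only, not `jax_dataloader`'s implementation: `batch_indices` is a hypothetical helper, and it seeds the standard-library `random` module rather than going through `jdl.manual_seed`.

```python
import random

def batch_indices(n, batch_size, shuffle=False, drop_last=False, seed=None):
    """Return a list of index batches over a dataset of length `n`.

    Hypothetical helper for illustration; not part of jax_dataloader.
    """
    indices = list(range(n))
    if shuffle:
        # A seeded generator makes the shuffled order reproducible.
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # Discard the final, incomplete batch.
        batches.pop()
    return batches

# 100 samples with batch_size=32: three full batches plus one of 4 samples.
kept = batch_indices(100, 32, shuffle=True, seed=1234)
dropped = batch_indices(100, 32, shuffle=True, drop_last=True, seed=1234)
print([len(b) for b in kept])     # [32, 32, 32, 4]
print([len(b) for b in dropped])  # [32, 32, 32]
```

With `drop_last=False` every sample is visited once per pass; with `drop_last=True` all batches have the same size, which can help avoid recompilation when the batch feeds a jitted function.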
Installation
The latest `jax-dataloader` release can be installed directly from PyPI:
pip install jax-dataloader
or installed directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
Note
We keep the dependencies of `jax-dataloader` minimal: installing it only pulls in `jax` and `plum-dispatch` (for backend dispatching). If you wish to use the integration with `pytorch`, huggingface `datasets`, or `tensorflow`, we highly recommend installing those dependencies manually. You can also run `pip install jax-dataloader[all]` to install everything (not recommended).
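Because the optional integrations are not installed by default, it can be handy to check which ones are importable before choosing a backend. The following is a small sketch using only the standard library; `OPTIONAL_PACKAGES` and `available_integrations` are hypothetical names introduced here, and the import names (`torch`, `datasets`, `tensorflow`) are the usual ones for these packages.

```python
from importlib.util import find_spec

# Import names of the optional integrations mentioned above.
OPTIONAL_PACKAGES = {
    "pytorch": "torch",
    "huggingface datasets": "datasets",
    "tensorflow": "tensorflow",
}

def available_integrations():
    """Map each optional integration to True if its package is importable."""
    return {
        name: find_spec(module) is not None
        for name, module in OPTIONAL_PACKAGES.items()
    }

for name, installed in available_integrations().items():
    print(f"{name}: {'installed' if installed else 'missing'}")
```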
Usage
`jax_dataloader.core.DataLoader` follows an API similar to the pytorch dataloader.
- The `dataset` should be an instance of a subclass of `jax_dataloader.core.Dataset`, `torch.utils.data.Dataset`, (the huggingface) `datasets.Dataset`, or `tf.data.Dataset`.
- The `backend` should be one of `"jax"`, `"pytorch"`, or `"tensorflow"`. This argument specifies which backend dataloader is used to load batches.
Note that not every dataset is compatible with every backend. See the compatibility table below:
| | `jdl.Dataset` | `torch_data.Dataset` | `tf.data.Dataset` | `datasets.Dataset` |
|---|---|---|---|---|
| `"jax"` | ✅ | ❌ | ❌ | ✅ |
| `"pytorch"` | ✅ | ✅ | ❌ | ✅ |
| `"tensorflow"` | ✅ | ❌ | ✅ | ✅ |
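The compatibility table can also be written down as a lookup, which is convenient when choosing a backend programmatically. This is a sketch that simply transcribes the table; `COMPATIBILITY` and `compatible_backends` are hypothetical names and are not part of `jax_dataloader`.

```python
# Compatibility table transcribed from above: backend -> supported dataset types.
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch_data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}

def compatible_backends(dataset_type):
    """Return the sorted list of backends that can load `dataset_type`."""
    return sorted(
        backend for backend, kinds in COMPATIBILITY.items() if dataset_type in kinds
    )

print(compatible_backends("jdl.Dataset"))         # ['jax', 'pytorch', 'tensorflow']
print(compatible_backends("torch_data.Dataset"))  # ['pytorch']
print(compatible_backends("tf.data.Dataset"))     # ['tensorflow']
```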
Using ArrayDataset
The `jax_dataloader.core.ArrayDataset` is an easy way to wrap multiple `jax.numpy.array` objects into one Dataset. For example, we can create an `ArrayDataset` as follows:
import jax.numpy as jnp

# Create features `X` (10 samples, 10 features each) and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
This `arr_ds` can be loaded by every backend.
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
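Conceptually, wrapping several arrays into one dataset means that indexing the dataset returns the matching row from each array. The sketch below illustrates that idea in pure Python, assuming a map-style interface (`__len__` and `__getitem__`); `TinyArrayDataset` is a hypothetical stand-in and not the actual `jdl.ArrayDataset` implementation.

```python
class TinyArrayDataset:
    """Hypothetical stand-in for an array-wrapping dataset (not jdl.ArrayDataset)."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all arrays must share the same first dimension")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # One item is a tuple holding the index-th row of every array.
        return tuple(a[index] for a in self.arrays)

# Plain nested lists mirroring the shapes of `X` (10, 10) and `y` (10,) above.
X = [[10 * row + col for col in range(10)] for row in range(10)]
y = list(range(10))
ds = TinyArrayDataset(X, y)
features, label = ds[3]
print(len(ds), features[0], label)  # 10 30 3
```

Checking the lengths up front surfaces mismatched features and labels at construction time rather than in the middle of an epoch.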
Using Huggingface Datasets
The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. `jax_dataloader` supports passing huggingface datasets directly.
from datasets import load_dataset
For example, we load the `"squad"` dataset from `datasets`:
hf_ds = load_dataset("squad")
Then, we can use `jax_dataloader` to load batches of `hf_ds`.
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
Using Pytorch Datasets
The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. `jax_dataloader` supports passing a pytorch Dataset directly.
Note
Unfortunately, the pytorch Dataset only works with `backend='pytorch'`. See the example below.
from torchvision.datasets import MNIST
import numpy as np
We load the MNIST dataset from `torchvision`. The `transform` function converts each image to a `numpy.array`.
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
This `pt_ds` can only be loaded via the `"pytorch"` backend.
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
Using Tensorflow Datasets
`jax_dataloader` supports passing tensorflow datasets directly.
import tensorflow_datasets as tfds
import tensorflow as tf
For instance, we can load the MNIST dataset from `tensorflow_datasets`:
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
and use `jax_dataloader` to iterate over the dataset.
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)