Dataloader for JAX¤

Overview¤

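jax_dataloader brings a pytorch-like dataloader API to jax. It supports four kinds of datasets (jdl.Dataset, pytorch Dataset, huggingface datasets.Dataset, and tf.data.Dataset) and three backends for loading batches ("jax", "pytorch", and "tensorflow").

A minimal jax-dataloader example: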

import jax_dataloader as jdl

jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size 
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last batch or not
)

batch = next(iter(dataloader)) # iterate next batch
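
To make the batch_size, shuffle, and drop_last arguments concrete, here is a minimal pure-Python sketch of the batching logic they describe. The helper iter_batches and its seed parameter are hypothetical names used only for illustration; this is not jax_dataloader's implementation, just the general idea of splitting shuffled indices into batches.

```python
import random

def iter_batches(n_samples, batch_size, shuffle=False, drop_last=False, seed=None):
    """Yield lists of sample indices, one list per batch (illustrative helper)."""
    indices = list(range(n_samples))
    if shuffle:
        # A seeded generator makes the shuffled order reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_samples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final, incomplete batch
        yield batch

# 100 samples with batch_size=32 give 3 full batches plus one batch of 4.
kept = list(iter_batches(100, 32, drop_last=False))
dropped = list(iter_batches(100, 32, drop_last=True))
print(len(kept), len(dropped))  # 4 3
print(len(kept[-1]))            # 4
```

With drop_last=False the final batch is simply smaller than batch_size, which matters if downstream code assumes a fixed batch shape.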

Installation¤

The latest jax-dataloader release can directly be installed from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

We keep jax-dataloader's dependencies minimal: installing it only pulls in jax and plum-dispatch (for backend dispatching). If you wish to use the pytorch, huggingface datasets, or tensorflow integrations, we highly recommend installing those dependencies manually.

You can also run pip install jax-dataloader[all] to install everything (not recommended).

Usage¤

jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.

  • The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, torch.utils.data.Dataset, (the huggingface) datasets.Dataset, or tf.data.Dataset.
  • The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend dataloader is used to load batches.

Note that not every dataset is compatible with every backend. See the compatibility table below:

               jdl.Dataset   torch_data.Dataset   tf.data.Dataset   datasets.Dataset
"jax"          yes           no                   no                yes
"pytorch"      yes           yes                  no                yes
"tensorflow"   yes           no                   yes               yes
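
The dataset/backend pairings can also be checked in code before constructing a loader. The sketch below records only the pairings that the usage sections of this document demonstrate or state. The dictionary keys are informal labels and is_demonstrated is a hypothetical helper; neither is part of jax_dataloader. For tf.data.Dataset, only the "tensorflow" backend is shown in this document, so the other backends are assumed here to be unsupported.

```python
# Backends that this document shows (or states) as working for each dataset type.
# Keys are informal labels chosen for this sketch, not library identifiers.
DEMONSTRATED_BACKENDS = {
    "jdl.Dataset": {"jax", "pytorch", "tensorflow"},
    "datasets.Dataset": {"jax", "pytorch", "tensorflow"},
    "torch.utils.data.Dataset": {"pytorch"},
    "tf.data.Dataset": {"tensorflow"},  # assumption: other backends unsupported
}

def is_demonstrated(dataset_kind, backend):
    """Return True if the pairing appears in the mapping above."""
    return backend in DEMONSTRATED_BACKENDS.get(dataset_kind, set())

print(is_demonstrated("jdl.Dataset", "jax"))               # True
print(is_demonstrated("torch.utils.data.Dataset", "jax"))  # False
```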

Using ArrayDataset¤

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into one Dataset. For example, we can create an ArrayDataset as follows:

import jax.numpy as jnp

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backend.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
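
To illustrate the idea of wrapping several arrays into one dataset, here is a conceptual pure-Python sketch. The class PairedArrays is a hypothetical stand-in, not the library's ArrayDataset, and the real class may behave differently. It only shows the general pattern: arrays of equal length are indexed together, so one index returns one sample from each array.

```python
class PairedArrays:
    """Hypothetical stand-in for the idea behind an array dataset."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("need at least one array")
        n = len(arrays[0])
        if any(len(a) != n for a in arrays):
            raise ValueError("all arrays must have the same length")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        # Return the idx-th entry of every wrapped array, in order.
        return tuple(a[idx] for a in self.arrays)

# Same values as jnp.arange(100).reshape(10, 10) and jnp.arange(10), as plain lists.
X = [[10 * r + c for c in range(10)] for r in range(10)]
y = list(range(10))
ds = PairedArrays(X, y)
features, label = ds[3]
print(len(ds), features[:3], label)  # 10 [30, 31, 32] 3
```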

Using Huggingface Datasets¤

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)

Using Pytorch Datasets¤

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

Note

Unfortunately, a pytorch Dataset only works with backend='pytorch'. See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The transform converts each image to a numpy.array.

pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

Using Tensorflow Datasets¤

jax_dataloader supports passing tensorflow datasets directly.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets:

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader for iterating the dataset.

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)