Core API



DataloaderBackends

 DataloaderBackends (
     pytorch: BaseDataLoader = jax_dataloader.loaders.torch.DataLoaderPytorch,
     tensorflow: BaseDataLoader = jax_dataloader.loaders.tensorflow.DataLoaderTensorflow,
     merlin: BaseDataLoader = None,
 )


get_backend_compatibilities

 get_backend_compatibilities ()

Returns the list of supported dataloader backends for each dataset type.
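A hedged sketch of how such a compatibility mapping might be consumed. The dictionary layout and its entries below are assumptions for illustration, not the library's documented return value:

```python
# Hypothetical compatibility mapping: dataset type -> supported backends.
# The entries here are assumed examples, not the library's actual output.
compat = {
    "ArrayDataset": ["jax", "pytorch", "tensorflow"],
    "TensorflowDataset": ["tensorflow"],
}

def pick_backend(dataset_type: str) -> str:
    """Pick the first backend that supports the given dataset type."""
    backends = compat.get(dataset_type, [])
    if not backends:
        raise ValueError(f"no supported backend for {dataset_type}")
    return backends[0]
```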



DataLoader

 DataLoader (dataset, backend:"Literal['jax','pytorch','tensorflow']",
             batch_size:int=1, shuffle:bool=False, drop_last:bool=False,
             **kwargs)

Main DataLoader class for loading NumPy data in batches.

| | Type | Default | Details |
|---|---|---|---|
| dataset | Dataset | | Dataset from which to load the data |
| backend | Literal['jax', 'pytorch', 'tensorflow'] | | Dataloader backend to load the dataset |
| batch_size | int | 1 | How many samples per batch to load |
| shuffle | bool | False | If true, the dataloader reshuffles the data every epoch |
| drop_last | bool | False | If true, drop the last incomplete batch |
| kwargs | | | |
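The interplay of `batch_size` and `drop_last` determines how many batches an epoch yields. A minimal sketch of that arithmetic (illustrative only, not the library's implementation):

```python
import math

def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """How many batches one epoch yields for a dataset of n_samples."""
    if drop_last:
        # Only full batches are yielded; the remainder is discarded.
        return n_samples // batch_size
    # The final, possibly smaller, batch is kept.
    return math.ceil(n_samples / batch_size)

# 500 samples with batch_size=128: batches of 128, 128, 128, and 116.
assert num_batches(500, 128, drop_last=False) == 4
# With drop_last=True the trailing 116-sample batch is discarded.
assert num_batches(500, 128, drop_last=True) == 3
```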

A Minimal Example of Using DataLoader

We showcase how to use DataLoader to train a simple regression model.

import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jrand
from jax import vmap
import optax
import haiku as hk
from sklearn.datasets import make_regression
from jax_dataloader import ArrayDataset, DataLoader

# Generate a toy regression dataset and wrap it in an ArrayDataset
X, y = make_regression(n_samples=500, n_features=20)
dataset = ArrayDataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)

Define loss, step, train:

def loss(w, x, y):
    # Mean l2 loss between predictions x @ w.T and targets y
    return jnp.mean(vmap(optax.l2_loss)(x @ w.T, y))

def step(w, x, y):
    # One gradient-descent step on the batch (x, y)
    lr = 0.1
    grad = jax.grad(loss)(w, x, y)
    w -= lr * grad
    return w

def train(dataloader: DataLoader, key: jax.random.PRNGKey):
    # Initialize the weights randomly, then iterate over mini-batches
    w = jrand.normal(key, shape=(1, 20))
    n_epochs = 10
    for _ in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    # Average the loss over all batches
    err = []
    for x, y in dataloader:
        err.append(loss(w, x, y))
    return np.mean(err)
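For intuition about what `jax.grad` computes in `step`, here is a hedged NumPy-only sketch of the same update, using the analytic gradient of the mean l2 loss (`optax.l2_loss(pred, y)` is `0.5 * (pred - y) ** 2`). This is illustrative, not the library's code:

```python
import numpy as np

def step_np(w, x, y, lr=0.1):
    """One SGD step with the analytic gradient of mean 0.5*(x@w.T - y)**2."""
    pred = x @ w.T                     # predictions, shape (n, 1)
    grad = (pred - y).T @ x / len(x)   # gradient w.r.t. w, shape (1, d)
    return w - lr * grad
```

Running a few of these steps on a toy batch should steadily decrease the loss, mirroring what the JAX `step` does inside the training loop.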

Train this linear regression model via the jax backend:

dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via pytorch backend:

dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via the tensorflow backend:

dataloader = DataLoader(
    dataset, 'tensorflow', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)