Core API¤


source

DataloaderBackends¤

DataloaderBackends (pytorch:jax_dataloader.loaders.base.BaseDataLoader=<class 'jax_dataloader.loaders.torch.DataLoaderPytorch'>, tensorflow:jax_dataloader.loaders.base.BaseDataLoader=<class 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow'>, merlin:jax_dataloader.loaders.base.BaseDataLoader=None)


source

get_backend_compatibilities¤

get_backend_compatibilities ()

Return list of supported dataloader backends for each dataset type
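The exact mapping returned depends on which optional backends (e.g. torch, tensorflow) are installed in your environment. As an illustration only — the dataset-type names and entries below are hypothetical, not the library's actual output — the result pairs each supported dataset type with the backends that can load it:

```python
# Illustrative sketch: the real mapping returned by
# get_backend_compatibilities() varies with the optional
# dependencies installed; these entries are hypothetical.
compat = {
    "ArrayDataset": ["jax", "pytorch", "tensorflow"],
    "TorchDataset": ["pytorch"],
}

for dataset_type, backends in compat.items():
    print(f"{dataset_type}: {', '.join(backends)}")
```

Checking such a mapping before constructing a `DataLoader` is a cheap way to fail fast when a backend's optional dependency is missing.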


source

DataLoader¤

DataLoader (dataset, backend:Literal['jax','pytorch','tensorflow'], batch_size:int=1, shuffle:bool=False, drop_last:bool=False, generator:Union[jax_dataloader.utils.Generator,jax.Array,ForwardRef('torch.Generator'),NoneType]=None, **kwargs)

Main DataLoader class for loading NumPy data in batches

| | Type | Default | Details |
|---|---|---|---|
| dataset | | | Dataset from which to load the data |
| backend | Literal['jax', 'pytorch', 'tensorflow'] | | Dataloader backend to load the dataset |
| batch_size | int | 1 | How many samples per batch to load |
| shuffle | bool | False | If true, the dataloader reshuffles the data every epoch |
| drop_last | bool | False | If true, drop the last incomplete batch |
| generator | Optional[GeneratorType] | None | Random seed generator |
| kwargs | | | |
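The batch_size, shuffle, and drop_last arguments follow the usual dataloader semantics. A minimal NumPy sketch of that batching logic (for illustration only, not the library's implementation):

```python
import numpy as np

def iter_batches(X, y, batch_size, shuffle=False, drop_last=False, seed=0):
    """Sketch of standard dataloader batching semantics."""
    idx = np.arange(len(X))
    if shuffle:
        # a fresh permutation of indices each epoch
        np.random.default_rng(seed).shuffle(idx)
    # with drop_last, stop before the final incomplete batch
    stop = len(X) - (len(X) % batch_size) if drop_last else len(X)
    for start in range(0, stop, batch_size):
        sel = idx[start:start + batch_size]
        yield X[sel], y[sel]

X, y = np.arange(10).reshape(10, 1), np.arange(10)
# 10 samples, batch_size=4 -> batch sizes 4, 4, 2 (last batch kept)
print([len(xb) for xb, _ in iter_batches(X, y, 4)])                  # [4, 4, 2]
# drop_last=True drops the incomplete final batch -> 4, 4
print([len(xb) for xb, _ in iter_batches(X, y, 4, drop_last=True)])  # [4, 4]
```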

A Minimal Example of Using DataLoader¤

We showcase how to use DataLoader to train a simple linear regression model.

import jax
import jax.numpy as jnp
import jax.random as jrand
import numpy as np
from jax import vmap
from sklearn.datasets import make_regression
import optax
import haiku as hk
from jax_dataloader import ArrayDataset, DataLoader

X, y = make_regression(n_samples=500, n_features=20)
dataset = ArrayDataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)

Define loss, step, train:

def loss(w, x, y):
    return jnp.mean(vmap(optax.l2_loss)(x @ w.T, y))

def step(w, x, y):
    lr = 0.1
    grad = jax.grad(loss)(w, x, y)
    w -= lr * grad
    return w
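step is plain full-batch gradient descent on the l2 loss. The same update can be checked without JAX by writing out the gradient of the mean l2 loss by hand (a sketch using NumPy in place of jax.grad; data and learning rate are illustrative):

```python
import numpy as np

def l2_loss(w, x, y):
    # optax.l2_loss(pred, target) is 0.5 * (pred - target) ** 2
    return np.mean(0.5 * (x @ w.T - y) ** 2)

def grad_l2(w, x, y):
    # hand-derived gradient of mean 0.5 * (x @ w.T - y)^2 w.r.t. w
    n = x.shape[0]
    return ((x @ w.T - y).T @ x) / n

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 20))
w_true = rng.normal(size=(1, 20))
y = x @ w_true.T                      # noiseless linear targets

w = rng.normal(size=(1, 20))
before = l2_loss(w, x, y)
w = w - 0.1 * grad_l2(w, x, y)        # same update rule as step()
after = l2_loss(w, x, y)
assert after < before                 # one gradient step decreases the loss
```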

def train(dataloader: DataLoader, key: jax.Array):
    w = jrand.normal(key, shape=(1, 20))
    n_epochs = 10
    for _ in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    err = []
    for x, y in dataloader:
        err.append(loss(w, x, y))
    return np.mean(err)

Train this linear regression model via the jax backend (DataLoaderJAX):

dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via the pytorch backend:

dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model via the tensorflow backend:

dataloader = DataLoader(
    dataset, 'tensorflow', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
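For noiseless data such as the output of make_regression, full-batch gradient descent should recover the least-squares solution, which is why the commented assertions above expect a near-zero evaluation loss. A NumPy sketch of that sanity check (synthetic data, no DataLoader involved):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 20))
w_true = rng.normal(size=(1, 20))
y = X @ w_true.T                          # noiseless targets

w = rng.normal(size=(1, 20))
for _ in range(200):                      # full-batch gradient descent
    grad = ((X @ w.T - y).T @ X) / len(X)
    w -= 0.1 * grad

# gradient descent converges to the closed-form least-squares solution
w_lstsq = np.linalg.lstsq(X, y, rcond=None)[0].T
assert np.allclose(w, w_lstsq, atol=1e-3)
```

With mini-batches and shuffling, the iterates only approach this solution up to gradient noise, so exact equality checks should be replaced with a tolerance, as the commented np.allclose assertions do.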