Core API
DataloaderBackends

DataloaderBackends (pytorch:BaseDataLoader=<class 'jax_dataloader.loaders.torch.DataLoaderPytorch'>, tensorflow:BaseDataLoader=<class 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow'>, merlin:BaseDataLoader=None)
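To illustrate, the sketch below inspects this registry; it assumes only what the signature shows, i.e. that `DataloaderBackends` is a dataclass-like container whose fields hold the loader class for each backend (or `None` when the backend ships without one), and the import path is an assumption:

```python
from jax_dataloader.core import DataloaderBackends  # import path assumed

backends = DataloaderBackends()  # defaults taken from the signature above
print(backends.pytorch)     # <class 'jax_dataloader.loaders.torch.DataLoaderPytorch'>
print(backends.tensorflow)  # <class 'jax_dataloader.loaders.tensorflow.DataLoaderTensorflow'>
print(backends.merlin)      # None: no merlin loader is registered by default
```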
get_backend_compatibilities

get_backend_compatibilities ()

Return a list of supported dataloader backends for each dataset type.
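For example, one might list the compatibilities like so (a sketch; the return structure, a mapping from dataset type to supported backends, is an assumption based on the docstring, as is the import path):

```python
from jax_dataloader.core import get_backend_compatibilities  # import path assumed

# Assumed return shape: {dataset_type: [supported_backend, ...]}
compat = get_backend_compatibilities()
for dataset_type, backends in compat.items():
    print(f"{dataset_type}: {backends}")
```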
DataLoader

DataLoader (dataset, backend:"Literal['jax','pytorch','tensorflow']", batch_size:int=1, shuffle:bool=False, drop_last:bool=False, **kwargs)

Main DataLoader class for loading NumPy data batches.
|  | Type | Default | Details |
|---|---|---|---|
| dataset |  |  | Dataset from which to load the data |
| backend | Literal['jax', 'pytorch', 'tensorflow'] |  | Dataloader backend to load the dataset |
| batch_size | int | 1 | How many samples per batch to load |
| shuffle | bool | False | If true, the dataloader reshuffles the data every epoch |
| drop_last | bool | False | If true, drop the last incomplete batch |
| kwargs |  |  |  |
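Putting the parameters above together, a typical construction and iteration looks like this (a small sketch using `ArrayDataset` from the example below; the top-level imports are assumed):

```python
import numpy as np
from jax_dataloader import ArrayDataset, DataLoader  # top-level exports assumed

X = np.arange(30).reshape(10, 3)   # 10 samples, 3 features
y = np.arange(10).reshape(10, 1)
dataset = ArrayDataset(X, y)

# batch_size=4 with drop_last=True yields two full batches and discards the
# trailing 2-sample batch; shuffle=True reshuffles the data every epoch.
dataloader = DataLoader(dataset, 'jax', batch_size=4, shuffle=True, drop_last=True)
for x, y_batch in dataloader:
    print(x.shape, y_batch.shape)  # (4, 3) (4, 1), twice
```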
A Minimal Example of Using DataLoader
We showcase how to use `DataLoader` for training a simple regression model.
```python
import jax
import jax.numpy as jnp
import jax.random as jrand
import numpy as np
import optax
import haiku as hk
from sklearn.datasets import make_regression
from jax_dataloader import ArrayDataset, DataLoader

X, y = make_regression(n_samples=500, n_features=20)  # noiseless toy regression data
dataset = ArrayDataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)  # reproducible sequence of PRNG keys
```
Define `loss`, `step`, `train`, and an `eval` helper:
```python
def loss(w, x, y):
    return jnp.mean(jax.vmap(optax.l2_loss)(x @ w.T, y))

def step(w, x, y):
    lr = 0.1
    grad = jax.grad(loss)(w, x, y)
    w -= lr * grad
    return w

def train(dataloader: DataLoader, key: jax.random.PRNGKey):
    w = jrand.normal(key, shape=(1, 20))
    n_epochs = 10
    for _ in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    err = []
    for x, y in dataloader:
        err.append(loss(w, x, y))
    return np.mean(err)
```
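As an aside (not part of the original example), `step` is a pure function, so it can also be jit-compiled so the gradient computation is traced once and reused across batches; a minimal sketch:

```python
@jax.jit  # compiled on first call, reused for every subsequent batch
def jit_step(w, x, y):
    lr = 0.1
    return w - lr * jax.grad(loss)(w, x, y)
```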
Train this linear regression model via `DataLoaderJAX` (the `'jax'` backend):
```python
dataloader = DataLoader(dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)

dataloader = DataLoader(dataset, 'jax', batch_size=200, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
```
Train this linear regression model via the `pytorch` backend:
```python
dataloader = DataLoader(dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
```
Train this linear regression model via the `tensorflow` backend:
```python
dataloader = DataLoader(dataset, 'tensorflow', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
# assert np.allclose(eval(dataloader, w), 0.)
```
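Whichever backend is used, the fit can be checked with the `eval` helper defined above, mirroring the commented-out assertions (since `make_regression` generates noiseless data by default, the error should be near zero):

```python
print(eval(dataloader, w))  # mean l2 loss over all batches; expect ~0.
```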