Dataloader

Supports various dataloader backends for loading data in batches.

Dataset


DATASET

CLASS relax.data.loader.Dataset ()

A PyTorch-like abstract Dataset class.
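Since the interface mirrors PyTorch's `Dataset`, a subclass is expected to implement `__len__` and `__getitem__`. A minimal sketch (`SquaresDataset` is an illustrative name, not part of the library):

```python
# A toy Dataset-style class implementing the assumed interface:
# __len__ reports the number of samples, __getitem__ returns one sample.
class SquaresDataset:
    """Yields (i, i**2) pairs for i in range(n)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx, idx ** 2

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```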


ARRAYDATASET

CLASS relax.data.loader.ArrayDataset (*arrays)

Dataset wrapping tensors.

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
X = jax.random.normal(key, shape=(10, 10))
y = jax.random.normal(key, shape=(10,))
ds = ArrayDataset(X, y)

assert len(ds) == 10

Index the Dataset using ds[idx]:

x1, y1 = ds[1]
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])

Dataloader


BASEDATALOADER

CLASS relax.data.loader.BaseDataLoader (dataset, backend, batch_size=1, shuffle=False, drop_last=False, **kwargs)

Dataloader Interface

Parameters:
  • dataset
  • backend (str)
  • batch_size (int, default=1) – Batch size
  • shuffle (bool, default=False) – If true, dataloader shuffles before sampling each batch
  • drop_last (bool, default=False) – Whether to drop the last incomplete batch
  • kwargs
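As a backend-independent sketch of how `batch_size`, `shuffle`, and `drop_last` interact (`iter_batches` is illustrative, not a library function):

```python
import numpy as np

# Sketch of minibatch iteration: shuffle permutes sample indices before
# batching; drop_last discards a trailing batch smaller than batch_size.
def iter_batches(n, batch_size, shuffle=False, drop_last=False, seed=0):
    idx = np.arange(n)
    if shuffle:
        np.random.default_rng(seed).shuffle(idx)
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        yield idx[start:start + batch_size]

# 10 samples, batch_size=4: the final batch has 2 samples unless dropped.
print([len(b) for b in iter_batches(10, 4)])                  # [4, 4, 2]
print([len(b) for b in iter_batches(10, 4, drop_last=True)])  # [4, 4]
```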

JAXDATALOADER

CLASS relax.data.loader.JaxDataloader (dataset, backend='jax', batch_size=1, shuffle=False, drop_last=False, **kwargs)

Dataloader in vanilla JAX

Parameters:
  • dataset (Dataset)
  • backend (str, default=jax) – Positional argument
  • batch_size (int, default=1) – Batch size
  • shuffle (bool, default=False) – If true, dataloader shuffles before sampling each batch
  • drop_last (bool, default=False) – Whether to drop the last incomplete batch
  • kwargs

TORCHDATALOADER

CLASS relax.data.loader.TorchDataloader (dataset, backend='pytorch', batch_size=1, shuffle=False, num_workers=0, drop_last=False, **kwargs)

Uses PyTorch to load batches. Requires PyTorch to be installed.

Parameters:
  • dataset (Dataset)
  • backend (str, default=pytorch) – positional argument
  • batch_size (int, default=1) – batch size
  • shuffle (bool, default=False) – if true, dataloader shuffles before sampling each batch
  • num_workers (int, default=0) – number of workers
  • drop_last (bool, default=False) – Whether to drop the last incomplete batch
  • kwargs

Main Dataloader Class


DATALOADERBACKENDS

CLASS relax.data.loader.DataloaderBackends (jax=<class '__main__.JaxDataloader'>, pytorch=<class '__main__.TorchDataloader'>, tensorflow=None, merlin=None)

Maps each backend name to its Dataloader class; None marks a backend that is not supported.
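A rough sketch of this registry pattern, assuming each field maps a backend name to a loader class with None marking unsupported backends (the `Backends`, `FakeLoader`, and `supported` names below are illustrative, not the library's):

```python
from dataclasses import dataclass
from typing import Optional, Type

# Registry sketch: one field per backend, holding the loader class
# (or None when that backend is unavailable).
@dataclass
class Backends:
    jax: Optional[Type] = None
    pytorch: Optional[Type] = None
    tensorflow: Optional[Type] = None
    merlin: Optional[Type] = None

    @property
    def supported(self):
        # Dataclass fields preserve declaration order in __dict__.
        return [name for name, cls in self.__dict__.items() if cls is not None]

class FakeLoader:
    pass

backends = Backends(jax=FakeLoader, pytorch=FakeLoader)
print(backends.supported)  # ['jax', 'pytorch']
```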


_DISPATCH_DATALOADER

relax.data.loader._dispatch_dataloader (backend)

Return the Dataloader class for the given backend

Parameters:
  • backend (str) – dataloader backend
Returns:

    (BaseDataLoader)

assert _dispatch_dataloader('jax') == JaxDataloader
assert _dispatch_dataloader('pytorch') == TorchDataloader
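The dispatch pattern can be sketched as a registry lookup that fails loudly for unsupported backends (the loader classes and `dispatch` name below are stand-ins, not the library's):

```python
# Stand-in loader classes for the sketch.
class JaxLoader: ...
class TorchLoader: ...

_BACKENDS = {'jax': JaxLoader, 'pytorch': TorchLoader}

def dispatch(backend: str):
    # Look up the backend name; unsupported names raise a clear error.
    if backend not in _BACKENDS:
        raise ValueError(
            f"backend should be one of {sorted(_BACKENDS)}, got '{backend}'")
    return _BACKENDS[backend]

assert dispatch('jax') is JaxLoader
```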

DATALOADER

CLASS relax.data.loader.DataLoader (dataset, backend, batch_size=1, shuffle=False, num_workers=0, drop_last=False, **kwargs)

Main Dataloader class for loading NumPy data in batches

Parameters:
  • dataset (Dataset)
  • backend (str) – Dataloader backend; Currently supports ['jax', 'pytorch']
  • batch_size (int, default=1) – batch size
  • shuffle (bool, default=False) – if true, dataloader shuffles before sampling each batch
  • num_workers (int, default=0) – number of workers
  • drop_last (bool, default=False) – Whether to drop the last incomplete batch
  • kwargs

A Minimal Example of Using DataLoader

We showcase how to use DataLoader to train a simple regression model.

import haiku as hk, optax
import jax, jax.numpy as jnp, numpy as np
from jax import random, vmap
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10000, n_features=20)
dataset = ArrayDataset(X, y.reshape(-1, 1))
keys = hk.PRNGSequence(0)

Define loss, step, train:

def loss(w, x, y):
    return jnp.mean(vmap(optax.l2_loss)(x @ w.T, y))

def step(w, x, y):
    lr = 0.1
    grad = jax.grad(loss)(w, x, y)
    w -= lr * grad
    return w

def train(dataloader: DataLoader, key: random.PRNGKey):
    w = jax.random.normal(key, shape=(1, 20))
    n_epochs = 10
    for _ in range(n_epochs):
        for x, y in dataloader:
            w = step(w, x, y)
    return w

def eval(dataloader: DataLoader, w):
    err = []
    for x, y in dataloader:
        err.append(loss(w, x, y))
    return np.mean(err)
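Since optax.l2_loss(pred, target) is 0.5 * (pred - target)**2, the gradient of the mean loss with respect to w is err.T @ X / n, which step computes via jax.grad. A pure-NumPy sketch of the same update rule, as a sanity check on the math (data sizes here are illustrative):

```python
import numpy as np

# NumPy analogue of loss/step/train above: for the mean of
# 0.5 * (x @ w.T - y)**2, the gradient w.r.t. w is err.T @ X / n.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 5))
true_w = rng.normal(size=(1, 5))
y = X @ true_w.T                       # noiseless targets

w = rng.normal(size=(1, 5))            # random init, as in train()
for _ in range(500):
    err = X @ w.T - y                  # (n, 1) residuals
    grad = err.T @ X / len(X)          # gradient of the mean l2 loss
    w -= 0.1 * grad                    # lr matches step()

print(np.allclose(w, true_w, atol=1e-3))  # True: recovers true weights
```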

Train this linear regression model with the 'jax' backend (JaxDataloader):

dataloader = DataLoader(
    dataset, 'jax', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)

Train this linear regression model with the 'pytorch' backend (TorchDataloader):

dataloader = DataLoader(
    dataset, 'pytorch', batch_size=128, shuffle=True)
w = train(dataloader, next(keys)).block_until_ready()
assert np.allclose(eval(dataloader, w), 0.)