Controlling Randomness

jax-dataloader provides flexible mechanisms for managing the pseudo-random number generation used during data loading. Controlling this randomness is essential for reproducibility, especially when data is shuffled. This tutorial covers the two primary ways to control it:

  • Setting a global seed
  • Assigning specific seed generators to individual dataloaders

Prerequisites

Let’s set up the necessary imports and a simple dataset for our examples:

import jax_dataloader as jdl
import jax
import jax.numpy as jnp
import torch

# Sample dataset
data = jnp.arange(20).reshape(10, 2)
labels = jnp.arange(10)
ds = jdl.ArrayDataset(data, labels)

Method 1: Setting the Global Seed

The simplest way to control randomness across all jax-dataloader instances is to set a global seed. The seed applies to every dataloader created after it is set, except those given their own generator.

Use the jax_dataloader.manual_seed() function:

# Set the global seed for all subsequent dataloaders
jdl.manual_seed(1234)

# Both dataloaders below use the same underlying seed sequence, so they
# shuffle identically as long as their other parameters are the same.
dl_1 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)
dl_2 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)

# Draw the first batch from each dataloader to confirm they match
print("DataLoader 1 first batch:", next(iter(dl_1)))
print("DataLoader 2 first batch:", next(iter(dl_2)))
DataLoader 1 first batch: (array([[2, 3],
       [4, 5]], dtype=int32), array([1, 2], dtype=int32))
DataLoader 2 first batch: (array([[2, 3],
       [4, 5]], dtype=int32), array([1, 2], dtype=int32))
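
To see why two loaders created after the same jdl.manual_seed(1234) call yield the same batches, here is a minimal NumPy sketch of seeded shuffling. It assumes that a loader shuffles by drawing one permutation of the row indices from a seeded RNG and then slicing that permutation into batches. The helper `shuffled_batches` and the `toy_*` arrays are hypothetical and are not jax-dataloader's internal code, so the batches they produce will not match the output above.

```python
import numpy as np

# Same layout as the tutorial dataset: row i is [2i, 2i + 1] and its label is i.
toy_data = np.arange(20).reshape(10, 2)
toy_labels = np.arange(10)

def shuffled_batches(seed, batch_size=2):
    """Hypothetical helper: shuffle row indices with a seeded RNG, then batch them."""
    rng = np.random.default_rng(seed)          # RNG state fully determined by `seed`
    perm = rng.permutation(len(toy_data))      # one shuffled index order
    return [
        (toy_data[perm[i:i + batch_size]], toy_labels[perm[i:i + batch_size]])
        for i in range(0, len(toy_data), batch_size)
    ]

# Two "loaders" built from the same seed produce the same batch order.
batches_1 = shuffled_batches(1234)
batches_2 = shuffled_batches(1234)
print("Loader 1 first batch:", batches_1[0])
print("Loader 2 first batch:", batches_2[0])
```

Because the permutation is a pure function of the seed, any two consumers that start from the same seed and use the same batch size see the same order.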

Method 2: Setting Per-Dataloader Seed Generators

For finer-grained control, pass a seed generator to an individual DataLoader through the generator argument. That generator overrides the global seed for that dataloader only.

jax-dataloader accepts three kinds of generator: jdl.Generator, jax.random.PRNGKey, and torch.Generator.
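
The precedence rule described above (an explicit per-loader generator beats the global seed, and the global seed only affects loaders created after it is set) can be modeled in a few lines. Everything in this sketch is hypothetical: `_GLOBAL_SEED`, `manual_seed`, and `ToyLoader` only imitate the behavior this tutorial describes, use a plain integer seed in place of a real generator object, and are not jax-dataloader's implementation.

```python
import numpy as np

_GLOBAL_SEED = 0  # hypothetical module-level default seed

def manual_seed(seed):
    """Hypothetical stand-in for a global seeding function."""
    global _GLOBAL_SEED
    _GLOBAL_SEED = seed

class ToyLoader:
    """Hypothetical loader: an explicit `generator` seed overrides the global one."""

    def __init__(self, n, generator=None):
        # The global seed is read at construction time, so changing it later
        # does not affect loaders that already exist.
        seed = _GLOBAL_SEED if generator is None else generator
        self.order = np.random.default_rng(seed).permutation(n)

manual_seed(1234)
a = ToyLoader(10)                  # uses the global seed 1234
b = ToyLoader(10)                  # also 1234, so same order as `a`
c = ToyLoader(10, generator=4321)  # explicit seed wins over the global seed
d = ToyLoader(10, generator=4321)  # same explicit seed, same order as `c`
print(a.order, c.order)
```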

1. Using jdl.Generator

Create and seed a jdl.Generator object and pass it to the jdl.DataLoader.

# Create a specific generator with its own seed
g1 = jdl.Generator().manual_seed(4321)

# This dataloader will use g1, overriding any global seed
dl_jdl_gen = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=g1)

print("DataLoader with jdl.Generator first batch:", next(iter(dl_jdl_gen)))
DataLoader with jdl.Generator first batch: (array([[ 6,  7],
       [10, 11]], dtype=int32), array([3, 5], dtype=int32))

2. Using jax.random.PRNGKey

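You can also pass a jax.random.PRNGKey directly as the generator. The key below uses the same seed (4321) as the jdl.Generator example, and in this run it yields the same first batch.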

# Create a JAX PRNGKey
key = jax.random.PRNGKey(4321)

# This dataloader will use the JAX key, overriding any global seed
# jax-dataloader handles the key internally for reproducible iteration.
dl_jax_key = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=key)

print("DataLoader with JAX PRNGKey first batch:", next(iter(dl_jax_key)))
DataLoader with JAX PRNGKey first batch: (array([[ 6,  7],
       [10, 11]], dtype=int32), array([3, 5], dtype=int32))
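
The comment above says jax-dataloader handles the key internally, but this tutorial does not say how. A common JAX pattern, shown here as an assumption rather than a description of the library, is to split the key once per epoch: each epoch gets a fresh shuffle, yet the whole sequence of shuffles is reproducible from the initial key. The helper `epoch_orders` is hypothetical.

```python
import jax
import jax.numpy as jnp

def epoch_orders(key, n, num_epochs):
    """Hypothetical helper: derive one shuffled index order per epoch from `key`."""
    orders = []
    for _ in range(num_epochs):
        # Split off a fresh subkey for this epoch; JAX keys should not be
        # reused for more than one random draw.
        key, subkey = jax.random.split(key)
        orders.append(jax.random.permutation(subkey, n))
    return orders

key = jax.random.PRNGKey(4321)
run_1 = epoch_orders(key, n=10, num_epochs=3)
run_2 = epoch_orders(key, n=10, num_epochs=3)  # same starting key, same orders
print("Epoch 0 order:", run_1[0])
```

Because JAX keys are immutable values rather than stateful objects, passing the same key twice replays the same sequence, which is what makes this pattern reproducible.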

3. Using torch.Generator

With the 'pytorch' backend, you can pass a torch.Generator instead.

# Create a PyTorch generator
g3 = torch.Generator().manual_seed(5678)

# This dataloader uses the 'pytorch' backend and the PyTorch generator
dl_torch_gen = jdl.DataLoader(ds, backend='pytorch', batch_size=2, shuffle=True, generator=g3)

print("DataLoader with torch.Generator first batch:", next(iter(dl_torch_gen)))
DataLoader with torch.Generator first batch: [array([[ 0,  1],
       [14, 15]], dtype=int32), array([0, 7], dtype=int32)]
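
As a rough illustration of why a seeded torch.Generator gives a reproducible order, the sketch below assumes shuffling reduces to a seeded index permutation, similar to how PyTorch's RandomSampler builds its order with torch.randperm. The helper `torch_batches` is hypothetical and will not reproduce the exact batches printed above.

```python
import torch

def torch_batches(seed, n=10, batch_size=2):
    """Hypothetical helper: batch row indices from a seeded torch permutation."""
    g = torch.Generator().manual_seed(seed)   # manual_seed returns the generator itself
    perm = torch.randperm(n, generator=g)     # same seed -> same permutation
    return [perm[i:i + batch_size].tolist() for i in range(0, n, batch_size)]

print(torch_batches(5678))
```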

Trade-offs: Global Seed vs. Per-Dataloader Generators

Consider these trade-offs when deciding how to manage randomness.

Global Seed (jdl.manual_seed())

  • Simplicity: One line of code provides basic reproducibility.
  • Implicit Consistency: Every dataloader created afterwards without its own generator starts from the same seed, which is useful for simple synchronization.

Per-Dataloader Generator (generator=...)

  • Fine-grained Control: Allows independent and precise randomness management for each dataloader.
  • Isolation: Prevents randomness in one dataloader from affecting others.
  • Integration: Works naturally with JAX keys or PyTorch generators.
  • Modularity: Better suited for complex applications or libraries where components need self-contained randomness.
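
The Isolation point is easiest to see by contrast. In this NumPy sketch, two hypothetical loaders A and B are represented only by the RNGs they draw from; none of this is jax-dataloader code. When A and B share one RNG, every draw A makes advances the shared state and changes what B receives. When each has its own seeded RNG, B's order does not depend on A at all.

```python
import numpy as np

n = 10

# Shared RNG: loader B's order depends on whether loader A drew first.
shared = np.random.default_rng(1234)
b_first = np.random.default_rng(1234).permutation(n)  # what B would get drawing first
a_order = shared.permutation(n)    # A takes the shared RNG's first draw...
b_second = shared.permutation(n)   # ...so B receives the second draw instead

# Separate RNGs: B's order is unaffected by how much randomness A consumes.
rng_a = np.random.default_rng(1234)
rng_b = np.random.default_rng(4321)
for _ in range(3):
    rng_a.permutation(n)           # A consumes randomness from its own RNG
b_isolated = rng_b.permutation(n)  # B still gets its first draw

print("B after A (shared):", b_second)
print("B isolated:        ", b_isolated)
```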