Controlling Randomness¤
jax-dataloader provides flexible mechanisms to manage the pseudo-random number generation used during data loading, which is essential for reproducibility, especially when shuffling data. This tutorial outlines the two primary ways to control randomness:
- Setting a global seed
- Assigning specific seed generators to individual dataloaders
Prerequisites¤
Let’s set up the necessary imports and a simple dataset for our examples:
import jax_dataloader as jdl
import jax
import jax.numpy as jnp
import torch
# Sample dataset
data = jnp.arange(20).reshape(10, 2)
labels = jnp.arange(10)
ds = jdl.ArrayDataset(data, labels)
Method 1: Setting the Global Seed¤
The simplest way to control randomness across all jax-dataloader
instances is by setting a global seed. This affects all dataloaders
created after the seed is set, unless they have their own specific
generator specified.
Use the jax_dataloader.manual_seed() function:
# Set the global seed for all subsequent dataloaders
jdl.manual_seed(1234)
# Both dataloaders below will use the same underlying seed sequence
# resulting in identical shuffling order if other parameters are the same.
dl_1 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)
dl_2 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)
# Iterate through dl_1 and dl_2 to observe the same order
print("DataLoader 1 first batch:", next(iter(dl_1)))
print("DataLoader 2 first batch:", next(iter(dl_2)))
DataLoader 1 first batch: (array([[2, 3],
[4, 5]], dtype=int32), array([1, 2], dtype=int32))
DataLoader 2 first batch: (array([[2, 3],
[4, 5]], dtype=int32), array([1, 2], dtype=int32))
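To check reproducibility, re-apply the same seed before constructing another dataloader; it should iterate in the same order as dl_1 and dl_2 above. A minimal sketch reusing only the API shown so far (dl_3 is an illustrative name):
# Re-seeding with the same value restores the same shuffle sequence
jdl.manual_seed(1234)
dl_3 = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)
# dl_3 should yield the same batch order as dl_1 and dl_2
print("DataLoader 3 first batch:", next(iter(dl_3)))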
Method 2: Setting Per-Dataloader Seed Generators¤
For more fine-grained control, assign a specific seed generator to individual DataLoader instances using the generator argument. This overrides any global seed for that specific dataloader.
jax-dataloader accepts three kinds of generators: its own jdl.Generator, a jax.random.PRNGKey, and a torch.Generator.
1. Using jdl.Generator¤
Create and seed a jdl.Generator object and pass it to the jdl.DataLoader.
# Create a specific generator with its own seed
g1 = jdl.Generator().manual_seed(4321)
# This dataloader will use g1, overriding any global seed
dl_jdl_gen = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=g1)
print("DataLoader with jdl.Generator first batch:", next(iter(dl_jdl_gen)))
DataLoader with jdl.Generator first batch: (array([[ 6, 7],
[10, 11]], dtype=int32), array([3, 5], dtype=int32))
2. Using jax.random.PRNGKey¤
Directly use a jax.random.PRNGKey as the generator.
# Create a JAX PRNGKey
key = jax.random.PRNGKey(4321)
# This dataloader will use the JAX key, overriding any global seed
# jax-dataloader handles the key internally for reproducible iteration.
dl_jax_key = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=key)
print("DataLoader with JAX PRNGKey first batch:", next(iter(dl_jax_key)))
DataLoader with JAX PRNGKey first batch: (array([[ 6, 7],
[10, 11]], dtype=int32), array([3, 5], dtype=int32))
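In a larger JAX program, it is common (standard JAX key handling, not specific to jax-dataloader) to split one root key so that data shuffling and other random operations draw from independent streams. A sketch with illustrative variable names:
# Split one root key so data loading and, e.g., parameter initialization
# use independent random streams
root_key = jax.random.PRNGKey(0)
data_key, init_key = jax.random.split(root_key)
dl_split = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True, generator=data_key)
init_params = jax.random.normal(init_key, (2, 2))  # illustrative use of the second key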
3. Using torch.Generator¤
When using the 'pytorch' backend, you can use a torch.Generator.
# Create a PyTorch generator
g3 = torch.Generator().manual_seed(5678)
# This dataloader uses the 'torch' backend and the PyTorch generator
dl_torch_gen = jdl.DataLoader(ds, backend='pytorch', batch_size=2, shuffle=True, generator=g3)
print("DataLoader with torch.Generator first batch:", next(iter(dl_torch_gen)))
DataLoader with torch.Generator first batch: [array([[ 0, 1],
[14, 15]], dtype=int32), array([0, 7], dtype=int32)]
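Regardless of which generator type is used, batches unpack the same way in a training loop. A minimal sketch (the jnp.asarray conversion is only needed if your downstream code expects JAX arrays):
# Iterate one epoch; each batch unpacks into (data, labels)
for batch_data, batch_labels in dl_torch_gen:
    batch_data = jnp.asarray(batch_data)      # convert to JAX arrays if needed
    batch_labels = jnp.asarray(batch_labels)
    # ... perform a training step here ...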
Trade-offs: Global Seed vs. Per-Dataloader Generators¤
Consider these trade-offs when deciding how to manage randomness.
Global Seed (jdl.manual_seed())¤
- Simplicity: Very easy to implement with one line for basic reproducibility.
- Implicit Consistency: Automatically ensures dataloaders created subsequently (without their own generator) share the same base randomness, useful for simple synchronization.
Per-Dataloader Generator (generator=...)¤
- Fine-grained Control: Allows independent and precise randomness management for each dataloader.
- Isolation: Prevents randomness in one dataloader from affecting others (see the sketch after this list).
- Integration: Works naturally with JAX keys or PyTorch generators.
- Modularity: Better suited for complex applications or libraries where components need self-contained randomness.
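As a quick illustration of the isolation point above, the two approaches can be mixed in one script: a dataloader pinned to its own generator is unaffected by the global seed (a sketch assuming the override behavior described in Method 2):
# The global seed governs dl_global; dl_isolated is pinned to its own generator,
# so its shuffle order does not depend on the global seed
jdl.manual_seed(1234)
dl_global = jdl.DataLoader(ds, backend='jax', batch_size=2, shuffle=True)
dl_isolated = jdl.DataLoader(
    ds, backend='jax', batch_size=2, shuffle=True,
    generator=jdl.Generator().manual_seed(4321),
)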