Dataset


Dataset

 Dataset ()

A PyTorch-like Dataset class.
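
As a rough illustration of what "PyTorch-like" usually means, a map-style dataset is an object that reports its size through `__len__` and returns a sample through `__getitem__`. The sketch below assumes that convention. `SquaresDataset` is a hypothetical name, and it deliberately does not subclass the library's `Dataset`, whose exact interface is not shown on this page.

```python
class SquaresDataset:
    """Hypothetical map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, index):
        # Return one sample; reject out-of-range indices like a sequence would.
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index ** 2


squares = SquaresDataset(5)
samples = [squares[i] for i in range(len(squares))]
```

Anything that follows this two-method protocol can be measured with `len()` and indexed with square brackets, which is the behavior the examples on this page rely on.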


ArrayDataset

 ArrayDataset (*arrays:jax.Array, asnumpy:bool=True)

Dataset wrapping numpy arrays.

This is similar to torch.utils.data.TensorDataset, but it wraps numpy arrays instead of tensors.

import jax.numpy as jnp
import numpy as np

X = jnp.arange(10000).reshape(1000, 10)
y = jnp.arange(1000)
ds = ArrayDataset(X, y)
assert len(ds) == 1000

Indexing a dataset with ds[index] indexes every wrapped array along its first dimension and returns one result per array.

x1, y1 = ds[1] # get the sample at index 1 (the second sample)
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])

x10, y10 = ds[:10]
assert jnp.array_equal(x10, X[:10])
assert jnp.array_equal(y10, y[:10])
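
The indexing behavior above can be pictured with a small NumPy-only sketch. `TupleOfArrays` is a hypothetical stand-in, not the `ArrayDataset` implementation. It assumes all arrays share the same first-dimension length and applies the same index to each one.

```python
import numpy as np


class TupleOfArrays:
    """Hypothetical stand-in: applies one index to several aligned arrays."""

    def __init__(self, *arrays):
        lengths = {len(a) for a in arrays}
        if len(lengths) != 1:
            raise ValueError("all arrays must have the same first dimension")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # Integers, slices and index lists all select along axis 0.
        return tuple(a[index] for a in self.arrays)


features = np.arange(20).reshape(10, 2)
labels = np.arange(10)
pairs = TupleOfArrays(features, labels)

row, label = pairs[1]                # a single sample
batch_x, batch_y = pairs[[0, 3, 5]]  # a batch picked by an index list
```

Because the same index is forwarded to every array, features and labels stay aligned whether the index is an integer, a slice, or a list of positions.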

By default, ArrayDataset stores arrays as numpy.ndarray, even when JAX arrays are passed in.

x, _ = ds[:10]
assert isinstance(x, np.ndarray)
assert not isinstance(x, jnp.ndarray)

To keep the arrays in the type you passed in, set asnumpy=False.

ds = ArrayDataset(X, y, asnumpy=False)
x, _ = ds[:10]
assert isinstance(x, jnp.ndarray)