Dataset¤
Dataset¤
Dataset ()
A PyTorch-like Dataset class.
ArrayDataset¤
ArrayDataset (*arrays:jax.Array, asnumpy:bool=True)
Dataset wrapping NumPy arrays.
This is similar to torch.utils.data.TensorDataset, but it wraps NumPy arrays instead of tensors.
import jax.numpy as jnp
import numpy as np
X = jnp.arange(10000).reshape(1000, 10)
y = jnp.arange(1000)
ds = ArrayDataset(X, y)
assert len(ds) == 1000
We index the arrays along the first dimension. Dataset indexing is
done via ds[index].
x1, y1 = ds[1] # get the sample at index 1 (the second sample)
assert jnp.array_equal(x1, X[1])
assert jnp.array_equal(y1, y[1])
x10, y10 = ds[:10]
assert jnp.array_equal(x10, X[:10])
assert jnp.array_equal(y10, y[:10])
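The indexing behavior shown above can be illustrated with a minimal sketch. This is not the library's implementation: the class name SketchArrayDataset is hypothetical, and the check that all arrays share the same first dimension is an assumption about what a reasonable wrapper would enforce. The sketch only shows how one index can be applied to every wrapped array along axis 0.

```python
import numpy as np


class SketchArrayDataset:
    """Hypothetical stand-in for ArrayDataset; not the library's code."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        n = arrays[0].shape[0]
        # Assumption: every array must have the same number of samples.
        if any(a.shape[0] != n for a in arrays):
            raise ValueError("all arrays must share the same first dimension")
        self.arrays = tuple(np.asarray(a) for a in arrays)

    def __len__(self):
        # The dataset length is the size of the first dimension.
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        # Apply the same index (int, slice, or index array) to each array.
        return tuple(a[index] for a in self.arrays)


X = np.arange(10000).reshape(1000, 10)
y = np.arange(1000)
ds = SketchArrayDataset(X, y)

x1, y1 = ds[1]    # a single sample: shapes (10,) and ()
xs, ys = ds[:10]  # a slice of ten samples: shapes (10, 10) and (10,)
```

Because the index is forwarded unchanged to each array, integer indices, slices, and integer index arrays all work the same way they do on a plain NumPy array.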
By default, ArrayDataset stores arrays as numpy.ndarray.
x, _ = ds[:10]
assert isinstance(x, np.ndarray)
assert not isinstance(x, jnp.ndarray)
To keep the arrays in the type you passed in, set asnumpy=False.
ds = ArrayDataset(X, y, asnumpy=False)
x, _ = ds[:10]
assert isinstance(x, jnp.ndarray)
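One plausible way to think about the asnumpy flag is as a conversion step applied when the arrays are stored. The sketch below is an assumption about that behavior, not the library's code: prepare_arrays is a hypothetical helper that converts each input with np.asarray when asnumpy is True and otherwise stores the inputs unchanged.

```python
import jax
import jax.numpy as jnp
import numpy as np


def prepare_arrays(*arrays, asnumpy=True):
    """Hypothetical helper: convert inputs to NumPy or keep their type."""
    if asnumpy:
        # np.asarray copies a JAX array's data into a host NumPy array.
        return tuple(np.asarray(a) for a in arrays)
    # Keep whatever array type the caller passed in.
    return tuple(arrays)


X = jnp.arange(20).reshape(10, 2)

(as_np,) = prepare_arrays(X)                 # stored as numpy.ndarray
(as_is,) = prepare_arrays(X, asnumpy=False)  # stored as jax.Array
```

Under this reading, the default favors cheap host-side indexing with NumPy, while asnumpy=False is useful when the data should stay as JAX arrays.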