# Build a small toy dataset (10 samples x 10 features, plus one target per
# sample) and check that ArrayDataset(X, y) has length 10.
# NOTE(review): presumably ArrayDataset yields one item per row of X — confirm.
key = jax.random.PRNGKey(0)
# JAX PRNG keys must not be reused across draws; split so X and y are sampled
# from independent streams instead of the same key.
x_key, y_key = jax.random.split(key)
X = jax.random.normal(x_key, shape=(10, 10))
y = jax.random.normal(y_key, shape=(10,))
ds = ArrayDataset(X, y)
assert len(ds) == 10

# Captured notebook output from running this cell (log noise, not code):
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)