# Smoke test: wrap random features/targets in an ArrayDataset and check its length.
# Use jax.random consistently (the bare name `random` may be the stdlib module,
# which has no PRNGKey).
key = jax.random.PRNGKey(0)
# Split the key so X and y come from independent streams; reusing one key for
# both draws makes them dependent (y can reproduce part of X).
x_key, y_key = jax.random.split(key)
X = jax.random.normal(x_key, shape=(10, 10))
y = jax.random.normal(y_key, shape=(10,))
ds = ArrayDataset(X, y)
# Expect one item per row of X (10 here).
assert len(ds) == X.shape[0] == 10
# Captured cell output (not code):
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)