Utils
Global configuration, the seed generator, and installation checks.
Configs
Config
Config (rng_reserve_size:int, global_seed:int)
Global configuration for the library.
get_config
get_config ()
Return the library's global configuration.
manual_seed
manual_seed (seed:int)
Set the global seed for the library.
# Seeding the library globally is reflected in the shared config object.
manual_seed(seed=11)
assert 11 == get_config().global_seed
Check Installation
check_pytorch_installed
check_pytorch_installed ()
# Verify that PyTorch is available; presumably raises if it is not installed — confirm against the implementation.
check_pytorch_installed()
has_pytorch_tensor
has_pytorch_tensor (batch)
check_hf_installed
check_hf_installed ()
# Verify that the Hugging Face dependency is available; presumably raises if it is missing — confirm against the implementation.
check_hf_installed()
check_tf_installed
check_tf_installed ()
# Verify that TensorFlow is available; presumably raises if it is not installed — confirm against the implementation.
check_tf_installed()
Seed Generator
Generator
Generator (generator:jax.Array|torch._C.Generator=None)
A wrapper around JAX and PyTorch random generators, used to generate random numbers reproducibly.
# Example of using the generator
# With no argument, the generator takes its seed from the global config, and
# both the JAX key and the torch generator are derived from that seed.
g = Generator()
assert g.seed() == get_config().global_seed
assert jnp.array_equal(g.jax_generator(), jax.random.PRNGKey(get_config().global_seed))
assert g.torch_generator().initial_seed() == get_config().global_seed
# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`
# Wrapping a JAX key: the key is returned unchanged, and `seed()` reports None.
g_jax = Generator(generator=jax.random.PRNGKey(123))
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))
assert g_jax.seed() is None
# Wrapping a torch generator: its initial seed is reported by `seed()`, and
# the JAX key matches `PRNGKey` built from that same seed.
g_torch = Generator(generator=torch.Generator().manual_seed(123))
assert g_torch.torch_generator().initial_seed() == 123
assert g_torch.seed() == 123
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(123))
# Example of using `manual_seed` to set the seed
# After `manual_seed`, `seed()`, the JAX key and the torch generator all agree
# on the new seed — even for a generator originally built from a JAX key.
# NOTE: these asserts depend on the re-seeding order below; keep it intact.
g_jax.manual_seed(456)
assert g_jax.seed() == 456
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(456))
assert g_jax.torch_generator().initial_seed() == 456
g_torch.manual_seed(789)
assert g_torch.seed() == 789
assert g_torch.torch_generator().initial_seed() == 789
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(789))
Util Functions
asnumpy
asnumpy (x)
Convert a NumPy, JAX, PyTorch, or TensorFlow array/tensor to a NumPy array.
# The same values expressed in each supported array/tensor type.
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
torch_x = torch.tensor([1, 2, 3])
tf_x = tf.constant([1, 2, 3])
# Every input must come back as a plain `np.ndarray` holding the same values.
# Checking `isinstance(..., np.ndarray)` explicitly is stricter than only ruling
# out the source framework's type, which would also accept non-array returns.
assert np.array_equal(asnumpy(np_x), np_x) and isinstance(asnumpy(np_x), np.ndarray)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
assert isinstance(asnumpy(jnp_x), np.ndarray)
assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
assert isinstance(asnumpy(torch_x), np.ndarray)
assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)
assert isinstance(asnumpy(tf_x), np.ndarray)