Utils

Global configuration, the seed `Generator`, installation checks, and array utilities.

Configs

Config

Config (rng_reserve_size:int, global_seed:int)

Global configuration for the library.

get_config

get_config ()

Return the global `Config` instance.

manual_seed

manual_seed (seed:int)

Set the global seed for the library. The new value is visible through `get_config()`:

```python
manual_seed(11)
assert get_config().global_seed == 11
```
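The pattern behind `Config`, `get_config`, and `manual_seed` can be sketched as a single shared, mutable configuration object. This is a minimal sketch under stated assumptions, not the library's source. The default values are placeholders, and `_CONFIG` is a hypothetical module-level name.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Global configuration (field names mirror the documented signature)."""
    rng_reserve_size: int
    global_seed: int


# Hypothetical module-level instance; these defaults are placeholders.
_CONFIG = Config(rng_reserve_size=1, global_seed=42)


def get_config() -> Config:
    # Always return the same object so updates are seen by every caller.
    return _CONFIG


def manual_seed(seed: int) -> None:
    # Mutate the shared instance in place rather than rebinding _CONFIG.
    get_config().global_seed = seed


manual_seed(11)
assert get_config().global_seed == 11
```

Because the sketch mutates the shared instance instead of replacing it, any code that already holds a reference obtained from `get_config()` also sees the new seed.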

Check Installation

check_pytorch_installed

check_pytorch_installed ()

```python
check_pytorch_installed()
```

has_pytorch_tensor

has_pytorch_tensor (batch)
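This page does not say how `has_pytorch_tensor` inspects `batch`. One way to ask "does this batch contain a PyTorch tensor?" is to walk nested lists, tuples and dicts recursively. The sketch below assumes that batch layout. `contains_torch_tensor` and `contains_instance` are hypothetical names. The sketch checks `sys.modules` so it never imports PyTorch just to answer the question.

```python
import sys
from typing import Any


def contains_instance(obj: Any, cls: type) -> bool:
    """Hypothetical helper: True if `obj` or anything nested in it is a `cls`."""
    if isinstance(obj, cls):
        return True
    if isinstance(obj, dict):
        return any(contains_instance(v, cls) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(contains_instance(v, cls) for v in obj)
    # Any other leaf (numbers, strings, NumPy arrays, ...) is not a match.
    return False


def contains_torch_tensor(batch: Any) -> bool:
    """Hypothetical sketch of a 'has a PyTorch tensor' check."""
    torch = sys.modules.get("torch")
    if torch is None:
        # If torch was never imported, no torch.Tensor objects can be in the batch.
        return False
    return contains_instance(batch, torch.Tensor)
```

Looking torch up in `sys.modules` keeps the check cheap when PyTorch is installed but unused, and keeps it working when PyTorch is not installed at all.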


check_hf_installed

check_hf_installed ()

```python
check_hf_installed()
```

check_tf_installed

check_tf_installed ()

```python
check_tf_installed()
```
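The `check_*_installed` functions all appear to guard against a missing optional backend. Here is a minimal sketch of that idea. It uses the standard library's `importlib.util.find_spec`, which looks a module up without importing it. `require_module` is a hypothetical helper. Raising `ModuleNotFoundError` is an assumed behavior, not a description of the functions above.

```python
import importlib.util
from typing import Optional


def require_module(name: str, pip_name: Optional[str] = None) -> None:
    """Hypothetical helper: raise if top-level module `name` is not installed."""
    # find_spec returns None when no importable top-level module has this name.
    if importlib.util.find_spec(name) is None:
        raise ModuleNotFoundError(
            f"`{name}` is required but not installed. "
            f"Try `pip install {pip_name or name}`."
        )


require_module("json")  # a standard-library module, so this passes silently
```

Passing a top-level module name such as `"torch"` or `"tensorflow"` gives a check in the same spirit as `check_pytorch_installed` or `check_tf_installed`. The optional `pip_name` covers packages whose install name differs from their import name.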

Seed Generator

Generator

Generator (generator:jax.Array|torch._C.Generator=None)

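A wrapper around JAX and PyTorch generators, used to generate random numbers reproducibly. It can be created from the global seed, from a `jax.random.PRNGKey`, or from a `torch.Generator`, and it exposes both a JAX key (`jax_generator()`) and a PyTorch generator (`torch_generator()`).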

```python
import jax
import jax.numpy as jnp
import torch

# Example of using the generator (defaults to the global seed)
g = Generator()
assert g.seed() == get_config().global_seed
assert jnp.array_equal(g.jax_generator(), jax.random.PRNGKey(get_config().global_seed))
assert g.torch_generator().initial_seed() == get_config().global_seed

# Examples of using the generator when passing a `jax.random.PRNGKey` or `torch.Generator`
g_jax = Generator(generator=jax.random.PRNGKey(123))
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(123))
assert g_jax.seed() is None  # the seed cannot be recovered from a JAX key

g_torch = Generator(generator=torch.Generator().manual_seed(123))
assert g_torch.torch_generator().initial_seed() == 123
assert g_torch.seed() == 123
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(123))

# Example of using `manual_seed` to set the seed
g_jax.manual_seed(456)
assert g_jax.seed() == 456
assert jnp.array_equal(g_jax.jax_generator(), jax.random.PRNGKey(456))
assert g_jax.torch_generator().initial_seed() == 456

g_torch.manual_seed(789)
assert g_torch.seed() == 789
assert g_torch.torch_generator().initial_seed() == 789
assert jnp.array_equal(g_torch.jax_generator(), jax.random.PRNGKey(789))
```
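The examples above imply some bookkeeping. The wrapper keeps an integer seed when one is known. It stores a JAX key when built from a `jax.random.PRNGKey`, whose originating seed cannot be recovered, hence `seed()` returning `None`. After `manual_seed`, it derives fresh backend generators from the new seed. A minimal sketch of that bookkeeping follows. `SeedGenerator`, `from_torch` and `from_jax_key` are hypothetical names, not this library's API. JAX and PyTorch are imported lazily, so each is needed only when its generator is requested.

```python
from typing import Any, Optional


class SeedGenerator:
    """Hypothetical sketch: one integer seed feeding both JAX and PyTorch RNGs."""

    def __init__(self, seed: Optional[int] = 0, jax_key: Any = None):
        self._seed = seed
        self._jax_key = jax_key

    @classmethod
    def from_torch(cls, torch_gen: Any) -> "SeedGenerator":
        # torch.Generator.initial_seed() reports the seed the generator was set with.
        return cls(seed=torch_gen.initial_seed())

    @classmethod
    def from_jax_key(cls, key: Any) -> "SeedGenerator":
        # A PRNGKey does not expose the seed it came from, so the seed is unknown.
        return cls(seed=None, jax_key=key)

    def seed(self) -> Optional[int]:
        return self._seed

    def manual_seed(self, seed: int) -> "SeedGenerator":
        self._seed = seed
        self._jax_key = None  # drop any stored key so it is re-derived from the seed
        return self

    def jax_generator(self) -> Any:
        if self._jax_key is not None:
            return self._jax_key
        import jax  # lazy import: only needed when a JAX key is requested
        return jax.random.PRNGKey(self._seed)

    def torch_generator(self) -> Any:
        if self._seed is None:
            raise ValueError("no integer seed available; call manual_seed first")
        import torch  # lazy import: only needed when a torch generator is requested
        return torch.Generator().manual_seed(self._seed)
```

Keeping the integer seed as the single source of truth makes `manual_seed` reset both backends at once. A wrapper built from a bare JAX key can hand that key back unchanged, but it cannot produce a matching PyTorch generator until a seed is supplied.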

Util Functions

asnumpy

asnumpy (x)

Convert a NumPy array, JAX array, PyTorch tensor, or TensorFlow tensor to a NumPy array.

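```python
import numpy as np
import jax.numpy as jnp
import torch
import tensorflow as tf

# The same values as a NumPy, JAX, PyTorch, and TensorFlow array
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
torch_x = torch.tensor([1, 2, 3])
tf_x = tf.constant([1, 2, 3])

# Each result matches the NumPy values and is no longer a backend-specific type
assert np.array_equal(asnumpy(np_x), np_x)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)
```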
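As the example shows, `asnumpy` accepts NumPy, JAX, PyTorch and TensorFlow inputs. Here is a minimal sketch of one way to dispatch on those types. It assumes PyTorch tensors on any device and eager (not graph-mode) TensorFlow tensors. `to_numpy` is a hypothetical name. Each backend is looked up in `sys.modules`, so none is imported just to perform a type check.

```python
import sys
from typing import Any

import numpy as np


def to_numpy(x: Any) -> np.ndarray:
    """Hypothetical sketch of converting backend arrays to NumPy."""
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        # Detach from the autograd graph and copy to host memory before converting.
        return x.detach().cpu().numpy()
    tf = sys.modules.get("tensorflow")
    if tf is not None and isinstance(x, tf.Tensor):
        # Eager tensors expose .numpy(); symbolic graph tensors do not.
        return x.numpy()
    # NumPy arrays, JAX arrays, lists and scalars go through the __array__ protocol.
    return np.asarray(x)
```

The explicit PyTorch branch is needed because `np.asarray` fails on tensors that require gradients or live on a GPU. JAX arrays implement `__array__`, so `np.asarray` copies them to host memory without special handling.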