Utils

Global configuration, PRNGSequence, and installation checks.

Configs

Config

 Config (rng_reserve_size:int, global_seed:int)

Global configuration for the library.
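A configuration object with these two fields can be modeled as a plain dataclass. The sketch below is illustrative only. ConfigSketch is a hypothetical name, its default values are placeholders rather than the library's defaults, and it does not try to describe what rng_reserve_size controls.

```python
from dataclasses import dataclass, replace


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config; not the library's definition."""
    rng_reserve_size: int = 1  # placeholder default, not the library's
    global_seed: int = 42      # placeholder default, not the library's


cfg = ConfigSketch()
# replace() builds a new instance with one field changed and leaves cfg untouched
seeded = replace(cfg, global_seed=11)
```

Because dataclasses.replace returns a modified copy, the original configuration stays unchanged.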


get_config

 get_config ()

Return the current global configuration (a Config instance).

manual_seed

 manual_seed (seed:int)

Set the global seed for the library.

manual_seed(11)
assert get_config().global_seed == 11
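The assertion above implies that manual_seed changes shared state that get_config later reads back, which is a module-level singleton. Below is a minimal sketch of that pattern. It assumes a single shared instance. The names _config, get_config_sketch, and manual_seed_sketch are hypothetical, and the library may store or validate its configuration differently.

```python
from dataclasses import dataclass


@dataclass
class _ConfigSketch:
    rng_reserve_size: int = 1  # hypothetical default
    global_seed: int = 42      # hypothetical default


# One module-level instance shared by every caller (hypothetical design)
_config = _ConfigSketch()


def get_config_sketch() -> _ConfigSketch:
    """Return the shared configuration object itself, not a copy."""
    return _config


def manual_seed_sketch(seed: int) -> None:
    """Overwrite the global seed in place so later get_config_sketch() calls see it."""
    if not isinstance(seed, int):
        raise TypeError(f"seed must be an int, got {type(seed).__name__}")
    _config.global_seed = seed


manual_seed_sketch(11)
assert get_config_sketch().global_seed == 11
```

Returning the shared object instead of a copy is what lets a later read observe the new seed.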

Check Installation

check_pytorch_installed

 check_pytorch_installed ()
check_pytorch_installed()

has_pytorch_tensor

 has_pytorch_tensor (batch)
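This page does not say what has_pytorch_tensor checks. Going only by its name, it plausibly reports whether a batch contains a PyTorch tensor. The sketch below assumes batches are nested lists, tuples, and dicts. Both contains_leaf and has_pytorch_tensor_sketch are hypothetical helpers. PyTorch is imported lazily, so the check returns False when PyTorch is not installed.

```python
import importlib.util
from typing import Any, Callable


def contains_leaf(batch: Any, predicate: Callable[[Any], bool]) -> bool:
    """Hypothetical helper: True if any leaf of a nested list/tuple/dict satisfies predicate."""
    if isinstance(batch, dict):
        return any(contains_leaf(v, predicate) for v in batch.values())
    if isinstance(batch, (list, tuple)):
        return any(contains_leaf(item, predicate) for item in batch)
    return predicate(batch)


def has_pytorch_tensor_sketch(batch: Any) -> bool:
    """Hypothetical reading of has_pytorch_tensor: look for any torch.Tensor leaf."""
    if importlib.util.find_spec("torch") is None:
        return False  # PyTorch is not installed, so no leaf can be a torch.Tensor
    import torch
    return contains_leaf(batch, lambda x: isinstance(x, torch.Tensor))
```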

check_hf_installed

 check_hf_installed ()
check_hf_installed()

check_tf_installed

 check_tf_installed ()
check_tf_installed()
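The check_*_installed helpers probe for optional dependencies. A standard way to test whether a package is importable without actually importing it is importlib.util.find_spec, as sketched below. The names is_installed and require_module are hypothetical helpers. This page does not say whether the library's own checks raise an error or return a flag. Treating the Hugging Face dependency as the datasets package is also an assumption.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """True if module_name can be located on the import path (it is not imported)."""
    return importlib.util.find_spec(module_name) is not None


def require_module(module_name: str, pip_name: str) -> None:
    """Hypothetical helper: raise a helpful error when an optional dependency is missing."""
    if not is_installed(module_name):
        raise ModuleNotFoundError(
            f"`{module_name}` is required for this feature. Try `pip install {pip_name}`."
        )


# Import names for PyTorch, Hugging Face datasets (assumed), and TensorFlow
for name in ("torch", "datasets", "tensorflow"):
    print(f"{name}: {'installed' if is_installed(name) else 'missing'}")
```

Because find_spec only locates the module and does not execute it, the probe stays cheap even for large frameworks such as TensorFlow.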

Util Functions

asnumpy

 asnumpy (x)

Convert a NumPy, JAX, PyTorch, or TensorFlow array into a NumPy array.

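# Assumes NumPy, JAX, PyTorch, and TensorFlow are all installed
import numpy as np
import jax.numpy as jnp
import torch
import tensorflow as tf

# The same values as an array from each framework
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
torch_x = torch.tensor([1, 2, 3])
tf_x = tf.constant([1, 2, 3])
# Each conversion keeps the values and returns a plain NumPy array
assert np.array_equal(asnumpy(np_x), np_x)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)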
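One way to implement this kind of conversion is duck typing on the methods each framework exposes. The sketch below is not the library's code, and asnumpy_sketch is a hypothetical name. It calls .detach().cpu().numpy() on PyTorch-style tensors and .numpy() on TensorFlow eager tensors. Anything else falls back to np.asarray, which JAX arrays support through the __array__ protocol.

```python
import numpy as np


def asnumpy_sketch(x) -> np.ndarray:
    """Hypothetical converter from framework arrays to NumPy (not the library's code)."""
    if isinstance(x, np.ndarray):
        return x  # already NumPy: return as-is
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        # PyTorch-style tensor: drop the autograd graph and move to host memory first
        return x.detach().cpu().numpy()
    if hasattr(x, "numpy"):
        # TensorFlow eager tensors expose a .numpy() method
        return x.numpy()
    # JAX arrays, lists, and scalars go through np.asarray (the __array__ protocol)
    return np.asarray(x)
```

The detach check must come before the plain numpy check. Calling .numpy() directly on a PyTorch tensor that requires grad raises an error, and a GPU tensor has to be moved to the CPU first.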