Utils¤
Global configuration, PRNGSequence, and installation checks.
Configs¤
Config¤
Config (rng_reserve_size:int, global_seed:int)
Global configuration for the library
get_config¤
get_config ()
Return the current global Config for the library.
manual_seed¤
manual_seed (seed:int)
Set the seed for the library
manual_seed(11)
assert get_config().global_seed == 11
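The pattern above (one shared configuration object, read with a getter and updated by a seeding function) can be sketched as follows. This is a minimal illustration, not the library's implementation: the names `SketchConfig`, `sketch_get_config`, and `sketch_manual_seed` are hypothetical, and the default values are assumptions chosen only so the example runs.

```python
from dataclasses import dataclass


@dataclass
class SketchConfig:
    """Hypothetical stand-in mirroring the two fields shown in the Config signature."""
    rng_reserve_size: int
    global_seed: int


# One module-level instance shared by every caller.
# The default values here are assumptions, not the library's defaults.
_sketch_config = SketchConfig(rng_reserve_size=1, global_seed=42)


def sketch_get_config() -> SketchConfig:
    """Return the shared configuration object (always the same instance)."""
    return _sketch_config


def sketch_manual_seed(seed: int) -> None:
    """Update the seed on the shared configuration in place."""
    _sketch_config.global_seed = seed


sketch_manual_seed(11)
assert sketch_get_config().global_seed == 11
```

Because the getter always returns the same instance, a seed set in one place is visible to any code that later reads the configuration.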
Check Installation¤
check_pytorch_installed¤
check_pytorch_installed ()
check_pytorch_installed()
has_pytorch_tensor¤
has_pytorch_tensor (batch)
check_hf_installed¤
check_hf_installed ()
check_hf_installed()
check_tf_installed¤
check_tf_installed ()
check_tf_installed()
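The source does not show how these checks behave when a package is missing. One common way to test for an optional dependency without importing it is `importlib.util.find_spec`; the sketch below assumes that approach and that a missing package should raise `ModuleNotFoundError`. The helper names `is_installed` and `require_installed` are hypothetical and not part of the library.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def require_installed(module_name: str) -> None:
    """Raise ModuleNotFoundError if the module is missing (assumed behaviour)."""
    if not is_installed(module_name):
        raise ModuleNotFoundError(
            f"`{module_name}` is required for this feature but is not installed."
        )


# `json` ships with Python, so this check passes silently.
require_installed("json")
```

Pass only top-level package names such as `torch` or `tensorflow`: for dotted names, `find_spec` imports the parent package first.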
Util Functions¤
asnumpy¤
asnumpy (x)
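import numpy as np
import jax.numpy as jnp
import torch
import tensorflow as tf

# The same values represented in each supported framework
np_x = np.array([1, 2, 3])
jnp_x = jnp.array([1, 2, 3])
torch_x = torch.tensor([1, 2, 3])
tf_x = tf.constant([1, 2, 3])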
assert np.array_equal(asnumpy(np_x), np_x)
assert np.array_equal(asnumpy(jnp_x), np_x) and not isinstance(asnumpy(jnp_x), jnp.ndarray)
assert np.array_equal(asnumpy(torch_x), np_x) and not isinstance(asnumpy(torch_x), torch.Tensor)
assert np.array_equal(asnumpy(tf_x), np_x) and not isinstance(asnumpy(tf_x), tf.Tensor)
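The assertions above require that `asnumpy` returns a plain NumPy array whatever framework produced the input. A duck-typed conversion with that property might look like the sketch below. It is an assumption about one possible approach, not the library's code: `to_numpy_sketch` and `FakeTensor` are hypothetical names, and `FakeTensor` exists only so the example runs without PyTorch or TensorFlow installed.

```python
import numpy as np


def to_numpy_sketch(x):
    """Convert an array-like to np.ndarray using duck typing (illustrative only)."""
    if isinstance(x, np.ndarray):
        return x
    # PyTorch-style tensors: drop autograd tracking and move to host memory first.
    if hasattr(x, "detach") and hasattr(x, "cpu"):
        return x.detach().cpu().numpy()
    # Objects exposing a .numpy() method, such as TensorFlow eager tensors.
    if callable(getattr(x, "numpy", None)):
        return x.numpy()
    # Everything else (lists, scalars, objects implementing __array__).
    return np.asarray(x)


class FakeTensor:
    """Hypothetical tensor-like object exposing only a .numpy() method."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return np.array(self._data)


expected = np.array([1, 2, 3])
assert np.array_equal(to_numpy_sketch(expected), expected)
assert np.array_equal(to_numpy_sketch([1, 2, 3]), expected)
assert isinstance(to_numpy_sketch(FakeTensor([1, 2, 3])), np.ndarray)
```

Checking for methods rather than importing each framework keeps the helper usable when only some of the optional dependencies are installed.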