Utils

Configurations
relax.utils.validate_configs
relax.utils.validate_configs (configs, config_cls)
Return a valid configuration object.

Parameters:
- configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
- config_cls (pydantic.main.BaseModel) – The desired configuration class.

Returns:
(pydantic.main.BaseModel) – A valid configuration object.
We define a configuration object (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that the designated configuration object is returned.
For example, we define a configuration object LearningConfigs:

class LearningConfigs(BaseParser):
    lr: float

A configuration can be a LearningConfigs object, or the raw data as a dictionary:

configs_dict = dict(lr=0.01)

validate_configs will return the designated configuration object:

configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
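Since configs may also be a pydantic.main.BaseModel, an already-constructed LearningConfigs should pass through as well. A minimal sketch, assuming validate_configs returns instances of config_cls unchanged:

configs_obj = LearningConfigs(lr=0.01)
configs = validate_configs(configs_obj, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == 0.01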
Serialization

relax.utils.save_pytree
relax.utils.save_pytree (pytree, saved_dir)
Save a pytree to a directory.
The pytree will be stored under a directory with two files:
- {saved_dir}/data.npy: This file stores the flattened leaves.
- {saved_dir}/treedef.json: This file stores the pytree structure and whether each leaf is an array or not.
For example, a pytree

import numpy as np

pytree = {
    'a': np.random.randn(5, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}

will be stored as data.npy and treedef.json under the given directory.
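A quick check of the resulting layout (a sketch, assuming the directory is named tmp and contains nothing else):

import os
from relax.utils import save_pytree

os.makedirs('tmp', exist_ok=True)
save_pytree(pytree, 'tmp')
# The two files described above:
assert sorted(os.listdir('tmp')) == ['data.npy', 'treedef.json']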
relax.utils.load_pytree
relax.utils.load_pytree (saved_dir)
Load a pytree from a saved directory.
import os
import numpy as np
from relax.utils import save_pytree, load_pytree

# Store a dictionary to disk
pytree = {
'a': np.random.randn(100, 1),
'b': 1,
'c': {
'd': True,
'e': "Hello",
'f': np.array(["a", "b", "c"])
}
}
os.makedirs('tmp', exist_ok=True)
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')
assert np.allclose(pytree['a'], pytree_loaded['a'])
assert pytree['a'].dtype == pytree_loaded['a'].dtype
assert pytree['b'] == pytree_loaded['b']
assert pytree['c']['d'] == pytree_loaded['c']['d']
assert pytree['c']['e'] == pytree_loaded['c']['e']
assert np.all(pytree['c']['f'] == pytree_loaded['c']['f'])

# Store a list to disk
pytree = [
np.random.randn(100, 1),
{'a': 1, 'b': np.array([1, 2, 3])},
1,
[1, 2, 3],
"good"
]
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')
assert np.allclose(pytree[0], pytree_loaded[0])
assert pytree[0].dtype == pytree_loaded[0].dtype
assert pytree[1]['a'] == pytree_loaded[1]['a']
assert np.all(pytree[1]['b'] == pytree_loaded[1]['b'])
assert pytree[2] == pytree_loaded[2]
assert pytree[3] == pytree_loaded[3]
assert isinstance(pytree_loaded[3], list)
assert pytree[4] == pytree_loaded[4]

Vectorization Utils
relax.utils.auto_reshaping
relax.utils.auto_reshaping (reshape_argname, reshape_output=True)
Decorator that automatically reshapes the specified input argument into shape (1, k), and reshapes the output back to the input's original shape.
Parameters:
- reshape_argname (str) – The name of the argument to be reshaped.
- reshape_output (bool, default=True) – Whether to reshape the output. Useful to set False when returning multiple counterfactuals (cfs).
This decorator ensures that the specified input argument and the output of a function have the same shape. This is particularly useful when using jax.vmap.
import jax.numpy as jnp
from jax import vmap
from relax.utils import auto_reshaping

@auto_reshaping('x')
def f_vmap(x): return x * jnp.ones((10,))

assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)

@auto_reshaping('x', reshape_output=False)
def f_vmap(x): return x * jnp.ones((10,))

assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 1, 10)

Gradient Utils
relax.utils.grad_update
relax.utils.grad_update (grads, params, opt_state, opt)
Parameters:
- grads – A pytree of gradients.
- params – A pytree of parameters.
- opt_state (ArrayTree) – The optimizer state (a pytree of arrays).
- opt (optax.GradientTransformation) – The optax optimizer.
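grad_update has no inline example, so here is a minimal sketch of how it would slot into a training step. It assumes the usual optax convention, i.e. that grad_update applies the optimizer update and returns the updated params and the new opt_state; this return signature is an assumption, not confirmed by the docs above:

import jax.numpy as jnp
import optax
from relax.utils import grad_update

params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}        # toy parameters
grads = {'w': jnp.full((3,), 0.5), 'b': jnp.array(0.1)}   # toy gradients

opt = optax.adam(learning_rate=1e-2)
opt_state = opt.init(params)

# Assumed return convention: (updated_params, new_opt_state).
params, opt_state = grad_update(grads, params, opt_state, opt)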
Functional Utils
relax.utils.gumbel_softmax
relax.utils.gumbel_softmax (key, logits, tau, axis=-1)
The Gumbel softmax function.
Parameters:
- key (jax.random.PRNGKey) – Random key
- logits (jax.Array) – Logits for each class. Shape (batch_size, num_classes)
- tau (float) – Temperature for the Gumbel softmax
- axis (int | tuple[int, ...], default=-1) – The axis or axes along which the Gumbel softmax should be computed
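A short usage sketch, assuming the function returns relaxed one-hot samples with the same shape as logits (by construction, a Gumbel softmax output sums to 1 along the chosen axis, and approaches a one-hot vector as tau → 0):

import jax
import jax.numpy as jnp
from relax.utils import gumbel_softmax

key = jax.random.PRNGKey(0)
logits = jnp.array([[1.0, 2.0, 0.5]])  # (batch_size=1, num_classes=3)

samples = gumbel_softmax(key, logits, tau=0.5)
assert samples.shape == logits.shape
# Each row is a relaxed one-hot distribution over the classes.
assert jnp.allclose(samples.sum(axis=-1), 1.0)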
Helper functions
relax.utils.load_json
relax.utils.load_json (f_name)
Parameters:
- f_name (str) – File name

Returns:
(Dict[str, Any]) – The loaded JSON content.
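A minimal usage sketch, assuming load_json simply reads a JSON file into a dictionary (the file name here is made up for illustration):

import json
from relax.utils import load_json

# Write a hypothetical config file for illustration.
with open('tmp_config.json', 'w') as f:
    json.dump({'lr': 0.01}, f)

configs = load_json('tmp_config.json')
assert configs == {'lr': 0.01}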
Config
relax.utils.get_config
relax.utils.get_config ()

Gets the global configurations.
relax.utils.set_config
relax.utils.set_config (rng_reserve_size=None, global_seed=None, **kwargs)
Sets the global configurations.
Parameters:
- rng_reserve_size (int, default=None) – The number of random number generators to reserve.
- global_seed (int, default=None) – The global seed for random number generators.
- kwargs (VAR_KEYWORD)

Returns:
(None)
from relax.utils import get_config, set_config

# Generic test cases
set_config()
assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42

set_config(rng_reserve_size=100)
assert get_config().rng_reserve_size == 100

set_config(global_seed=1234)
assert get_config().global_seed == 1234

set_config(rng_reserve_size=2, global_seed=234)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234

# Calling set_config() with no arguments leaves the current values unchanged.
set_config()
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234

# Unknown keyword arguments are ignored.
set_config(invalid_key=80)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234