Utils
Configurations
relax.utils.validate_configs
relax.utils.validate_configs (configs, config_cls)
Return a valid configuration object.
Parameters:
- configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
- config_cls (pydantic.main.BaseModel) – The desired configuration class.
Returns:
(pydantic.main.BaseModel)
We define configuration objects (which inherit from BaseParser) to manage training, model, and data configurations. validate_configs ensures that the designated configuration object is returned.
For example, we define a configuration object LearningConfigs:
class LearningConfigs(BaseParser):
    lr: float
A configuration can be either a LearningConfigs object or the raw data as a dictionary.
configs_dict = dict(lr=0.01)
validate_configs will return the designated configuration object.
configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
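The dispatch that validate_configs performs (return the object if it already has the designated class, otherwise build one from a dictionary) can be sketched with the standard library alone. This is a hypothetical stand-in, not the library's code: validate_configs_sketch is an illustrative name, and LearningConfigs here is a dataclass rather than a pydantic-based BaseParser subclass.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Type, TypeVar, Union

T = TypeVar("T")


@dataclass
class LearningConfigs:
    # Hypothetical dataclass stand-in for the pydantic-based LearningConfigs.
    lr: float


def validate_configs_sketch(configs: Union[Mapping, Any], config_cls: Type[T]) -> T:
    """Return an instance of config_cls built from a dict or an existing object."""
    if isinstance(configs, config_cls):
        # Already the designated configuration object: return it unchanged.
        return configs
    if isinstance(configs, Mapping):
        # Raw dictionary: construct the configuration object from its keys.
        return config_cls(**configs)
    raise TypeError(
        f"Expected a dict or {config_cls.__name__}, got {type(configs).__name__}."
    )


configs = validate_configs_sketch({"lr": 0.01}, LearningConfigs)
assert type(configs) is LearningConfigs and configs.lr == 0.01
assert validate_configs_sketch(configs, LearningConfigs) is configs
```

One difference worth keeping in mind: pydantic models validate and coerce field types on construction, which this dataclass stand-in does not.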
Serialization
relax.utils.save_pytree
relax.utils.save_pytree (pytree, saved_dir)
Save a pytree to a directory.
The pytree will be stored under a directory with two files:
- {saved_dir}/data.npy: stores the flattened leaves.
- {saved_dir}/treedef.json: stores the pytree structure and whether each leaf is an array.
For example, the pytree
pytree = {
    'a': np.random.randn(5, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}
will be stored as
relax.utils.load_pytree
relax.utils.load_pytree (saved_dir)
Load a pytree from a saved directory.
import os
import numpy as np
from relax.utils import load_pytree, save_pytree

# Store a dictionary to disk
pytree = {
    'a': np.random.randn(100, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}
os.makedirs('tmp', exist_ok=True)
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')
assert np.allclose(pytree['a'], pytree_loaded['a'])
assert pytree['a'].dtype == pytree_loaded['a'].dtype
assert pytree['b'] == pytree_loaded['b']
assert pytree['c']['d'] == pytree_loaded['c']['d']
assert pytree['c']['e'] == pytree_loaded['c']['e']
assert np.all(pytree['c']['f'] == pytree_loaded['c']['f'])
# Store a list to disk
pytree = [
    np.random.randn(100, 1),
    {'a': 1, 'b': np.array([1, 2, 3])},
    1,
    [1, 2, 3],
    "good"
]
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')

assert np.allclose(pytree[0], pytree_loaded[0])
assert pytree[0].dtype == pytree_loaded[0].dtype
assert pytree[1]['a'] == pytree_loaded[1]['a']
assert np.all(pytree[1]['b'] == pytree_loaded[1]['b'])
assert pytree[2] == pytree_loaded[2]
assert pytree[3] == pytree_loaded[3]
assert isinstance(pytree_loaded[3], list)
assert pytree[4] == pytree_loaded[4]
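The two-file layout (flattened leaves in data.npy, structure plus an is-array flag in treedef.json) can be illustrated with a small save/load round trip. This is a minimal sketch, not the library's implementation: it handles only dicts, lists, and leaves, the helpers _flatten, _unflatten, save_tree, and load_tree are hypothetical, and storing the leaves as a pickled NumPy object array is an assumption about the on-disk encoding.

```python
import json
import os
import tempfile

import numpy as np


def _flatten(tree, leaves):
    """Append the leaves of `tree` to `leaves`; return a JSON-serializable structure."""
    if isinstance(tree, dict):
        return {"type": "dict", "keys": list(tree),
                "children": [_flatten(v, leaves) for v in tree.values()]}
    if isinstance(tree, list):
        return {"type": "list", "children": [_flatten(v, leaves) for v in tree]}
    leaves.append(tree)
    # Each leaf records its position in data.npy and whether it is an array
    # (this sketch records the flag but does not need it to reload).
    return {"type": "leaf", "index": len(leaves) - 1,
            "is_array": isinstance(tree, np.ndarray)}


def _unflatten(node, leaves):
    """Rebuild the nested dicts/lists described by `node` from the flat leaves."""
    if node["type"] == "dict":
        return {k: _unflatten(c, leaves) for k, c in zip(node["keys"], node["children"])}
    if node["type"] == "list":
        return [_unflatten(c, leaves) for c in node["children"]]
    return leaves[node["index"]]


def save_tree(tree, saved_dir):
    leaves = []
    treedef = _flatten(tree, leaves)
    data = np.empty(len(leaves), dtype=object)
    for i, leaf in enumerate(leaves):
        data[i] = leaf  # element-wise assignment stores each leaf as one object
    os.makedirs(saved_dir, exist_ok=True)
    np.save(os.path.join(saved_dir, "data.npy"), data, allow_pickle=True)
    with open(os.path.join(saved_dir, "treedef.json"), "w") as f:
        json.dump(treedef, f)


def load_tree(saved_dir):
    data = np.load(os.path.join(saved_dir, "data.npy"), allow_pickle=True)
    with open(os.path.join(saved_dir, "treedef.json")) as f:
        treedef = json.load(f)
    return _unflatten(treedef, list(data))


tree = {"a": np.random.randn(5, 1), "b": 1,
        "c": {"d": True, "e": "Hello", "f": np.array(["a", "b", "c"])}}
with tempfile.TemporaryDirectory() as saved_dir:
    save_tree(tree, saved_dir)
    files = sorted(os.listdir(saved_dir))  # ['data.npy', 'treedef.json']
    loaded = load_tree(saved_dir)
```

Because data.npy in this sketch holds a pickled object array, np.load needs allow_pickle=True, and unpickling can execute arbitrary code, so only load directories you trust.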
Vectorization Utils
relax.utils.auto_reshaping
relax.utils.auto_reshaping (reshape_argname, reshape_output=True)
Decorator that automatically reshapes a function's input into shape (1, k) and its output back to the input's shape.
Parameters:
- reshape_argname (str) – The name of the argument to be reshaped.
- reshape_output (bool, default=True) – Whether to reshape the output. Useful to set to False when returning multiple counterfactuals (cfs).
This decorator ensures that the specified input argument and the output of a function have the same shape. This is particularly useful when using jax.vmap.
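import jax.numpy as jnp
from jax import vmap
from relax.utils import auto_reshaping

@auto_reshaping('x')
def f_vmap(x):
    return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)

# With reshape_output=False, each mapped output keeps its (1, 10) shape
@auto_reshaping('x', reshape_output=False)
def f_vmap(x):
    return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 1, 10)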
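A minimal sketch of how such a decorator can be written, assuming only NumPy and the standard library: it locates the named argument with inspect.signature, hands it to the wrapped function as a (1, k) row, and reshapes the result back to the input's shape. The name auto_reshaping_sketch is hypothetical, and the library's version may differ in details.

```python
import functools
import inspect

import numpy as np


def auto_reshaping_sketch(reshape_argname, reshape_output=True):
    """Hypothetical re-implementation: reshape one argument to (1, k) and the output back."""

    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            x = np.asarray(bound.arguments[reshape_argname])
            input_shape = x.shape
            # Present the chosen argument to func as a single row of shape (1, k).
            bound.arguments[reshape_argname] = x.reshape(1, -1)
            out = func(*bound.args, **bound.kwargs)
            if reshape_output:
                # Restore the caller's original shape on the way out.
                out = np.asarray(out).reshape(input_shape)
            return out

        return wrapper

    return decorator


@auto_reshaping_sketch("x")
def scale(x):
    assert x.shape[0] == 1  # func always receives a (1, k) row
    return x * 2.0


@auto_reshaping_sketch("x", reshape_output=False)
def scale_raw(x):
    return x * 2.0


print(scale(np.ones(4)).shape)      # (4,)
print(scale_raw(np.ones(4)).shape)  # (1, 4)
```

Under jax.vmap each mapped x has shape (k,), so leaving the output unreshaped is what produces the extra axis in the (10, 1, 10) case. Reshaping the output back only works when it has as many elements as the input, which is why multiple counterfactuals call for reshape_output=False.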
Gradient Utils
relax.utils.grad_update
relax.utils.grad_update (grads, params, opt_state, opt)
Parameters:
- grads – A pytree of gradients.
- params – A pytree of parameters.
- opt_state (ArrayTree, i.e. typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, typing.Iterable[ArrayTree], typing.Mapping[typing.Any, ArrayTree]]) – The optimizer state.
- opt (optax.GradientTransformation) – The optax optimizer used to turn gradients into updates.
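A minimal sketch of what a helper with this signature typically does, assuming optax's GradientTransformation interface, where opt.update(grads, state, params) returns (updates, new_state) and optax.apply_updates adds the updates to the parameters. To stay dependency-free, the sketch defines a hypothetical plain-Python SGD transformation (sgd_sketch) over flat dicts of floats; that the real grad_update returns (new_params, new_opt_state) is an assumption, not something documented above.

```python
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple

Params = Dict[str, float]


class GradientTransformationSketch(NamedTuple):
    """Hypothetical stand-in mirroring optax's (init, update) pair."""
    init: Callable[[Params], Any]
    update: Callable[..., Tuple[Params, Any]]


def sgd_sketch(learning_rate: float) -> GradientTransformationSketch:
    def init(params: Params) -> Any:
        return ()  # plain SGD keeps no optimizer state

    def update(grads: Params, state: Any, params: Optional[Params] = None) -> Tuple[Params, Any]:
        # Updates are the negatively scaled gradients, as with optax.sgd.
        return {k: -learning_rate * g for k, g in grads.items()}, state

    return GradientTransformationSketch(init, update)


def grad_update_sketch(grads: Params, params: Params, opt_state: Any, opt) -> Tuple[Params, Any]:
    # 1) Let the optimizer turn raw gradients into updates (optax: opt.update).
    updates, opt_state = opt.update(grads, opt_state, params)
    # 2) Add the updates to the parameters (optax: optax.apply_updates).
    new_params = {k: params[k] + updates[k] for k in params}
    return new_params, opt_state


opt = sgd_sketch(learning_rate=0.5)
params = {"w": 1.0, "b": 0.0}
opt_state = opt.init(params)
params, opt_state = grad_update_sketch({"w": 2.0, "b": -1.0}, params, opt_state, opt)
print(params)  # {'w': 0.0, 'b': 0.5}
```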
Functional Utils
relax.utils.gumbel_softmax
relax.utils.gumbel_softmax (key, logits, tau, axis=-1)
The Gumbel softmax function.
Parameters:
- key (jax.random.PRNGKey) – Random key.
- logits (jax.Array) – Logits for each class, with shape (batch_size, num_classes).
- tau (float) – Temperature for the Gumbel softmax.
- axis (int | tuple[int, ...], default=-1) – The axis or axes along which the Gumbel softmax is computed.
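The standard Gumbel-softmax relaxation adds Gumbel(0, 1) noise g = -log(-log(u)), with u drawn from Uniform(0, 1), to the logits and applies a softmax at temperature tau: y = softmax((logits + g) / tau). Lower tau pushes y toward a one-hot vector. The NumPy sketch below follows that formulation, using a numpy Generator in place of a JAX PRNGKey; gumbel_softmax_np is a hypothetical name, and the library's version takes a JAX key and jax.Array logits (in JAX, jax.random.gumbel(key, logits.shape) draws the noise directly).

```python
import numpy as np


def gumbel_softmax_np(rng, logits, tau, axis=-1):
    """Hypothetical NumPy stand-in for a Gumbel softmax with a temperature tau."""
    # u ~ Uniform[tiny, 1): the lower bound keeps log(u) finite.
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))  # standard Gumbel(0, 1) noise
    y = (logits + gumbel) / tau
    # Numerically stable softmax along the requested axis or axes.
    y = y - np.max(y, axis=axis, keepdims=True)
    e = np.exp(y)
    return e / np.sum(e, axis=axis, keepdims=True)


rng = np.random.default_rng(0)
logits = np.log(np.array([[0.1, 0.6, 0.3], [0.5, 0.25, 0.25]]))
soft = gumbel_softmax_np(rng, logits, tau=0.5)
print(soft.shape)  # (2, 3)
```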
Helper functions
relax.utils.load_json
relax.utils.load_json (f_name)
Load a JSON file and return its contents as a dictionary.
Parameters:
- f_name (str) – File name.
Returns:
(typing.Dict[str, typing.Any])
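A helper with this signature takes only a few lines of standard-library Python. The sketch below is hypothetical (load_json_sketch is an illustrative name) and assumes the file contains a JSON object, so the parsed result is a dict.

```python
import json
import os
import tempfile
from typing import Any, Dict


def load_json_sketch(f_name: str) -> Dict[str, Any]:
    """Hypothetical JSON loader with the same signature as load_json."""
    with open(f_name, "r", encoding="utf-8") as f:
        # json.load parses the whole file; a top-level JSON object becomes a dict.
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "configs.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"lr": 0.01}, f)
    loaded = load_json_sketch(path)

print(loaded)  # {'lr': 0.01}
```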
Config
relax.utils.get_config
relax.utils.get_config ()
Returns the current global configurations.
relax.utils.set_config
relax.utils.set_config (rng_reserve_size=None, global_seed=None, **kwargs)
Sets the global configurations.
from relax.utils import get_config, set_config

# Generic Test cases
set_config()
assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42
set_config(rng_reserve_size=100)
assert get_config().rng_reserve_size == 100
set_config(global_seed=1234)
assert get_config().global_seed == 1234
set_config(rng_reserve_size=2, global_seed=234)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
# Calling with no arguments keeps the current values
set_config()
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
# An unknown key does not change the stored values
set_config(invalid_key=80)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
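The behavior these test cases check can be reproduced with a small module-level config object. In this model, None arguments leave the current value untouched, and unrecognized keyword arguments neither raise nor change anything. This is a sketch under stated assumptions: GlobalConfig, set_config_sketch, and get_config_sketch are hypothetical names, and the defaults (rng_reserve_size=1, global_seed=42) are taken from the first test case.

```python
from dataclasses import dataclass


@dataclass
class GlobalConfig:
    # Defaults inferred from the first test case above.
    rng_reserve_size: int = 1
    global_seed: int = 42


_CONFIG = GlobalConfig()


def get_config_sketch() -> GlobalConfig:
    return _CONFIG


def set_config_sketch(rng_reserve_size=None, global_seed=None, **kwargs):
    # None means "keep the current value"; only explicit values overwrite.
    if rng_reserve_size is not None:
        _CONFIG.rng_reserve_size = rng_reserve_size
    if global_seed is not None:
        _CONFIG.global_seed = global_seed
    # Extra keyword arguments (e.g. invalid_key=80) are accepted and ignored.


set_config_sketch(rng_reserve_size=2, global_seed=234)
set_config_sketch(invalid_key=80)
print(get_config_sketch())  # GlobalConfig(rng_reserve_size=2, global_seed=234)
```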