Utils

Utility functions for relax.

Configurations

relax.utils.validate_configs

[source]

relax.utils.validate_configs (configs, config_cls)

Return a valid configuration object.

Parameters:

  • configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
  • config_cls (<class 'pydantic.main.BaseModel'>) – The desired configuration class.

Returns:

    (<class 'pydantic.main.BaseModel'>)

We define a configuration object (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that the designated configuration object is returned, whether it is given that object or a raw dictionary.

For example, we define a configuration object LearningConfigs:

class LearningConfigs(BaseParser):
    lr: float

A configuration can be either a LearningConfigs object or the raw data in a dictionary.

configs_dict = dict(lr=0.01)

validate_configs will return a designated configuration object.

configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
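The behaviour described above (accept either a dictionary or an already-built object, and always return an instance of config_cls) can be sketched without relax or pydantic. This is only an illustration under stated assumptions: validate_configs_sketch and LearningConfigsSketch are hypothetical names, a dataclass stands in for a BaseParser subclass, and relax's actual implementation may differ (pydantic, for instance, also validates field types).

```python
from dataclasses import dataclass
from typing import Any, Type, TypeVar, Union

T = TypeVar("T")


@dataclass
class LearningConfigsSketch:
    # Hypothetical stand-in for a BaseParser subclass.
    lr: float


def validate_configs_sketch(configs: Union[dict, Any], config_cls: Type[T]) -> T:
    """Return `configs` as an instance of `config_cls` (illustrative only)."""
    if isinstance(configs, config_cls):
        # Already the designated configuration object: return it unchanged.
        return configs
    if isinstance(configs, dict):
        # Raw dictionary: build the designated configuration object from it.
        return config_cls(**configs)
    raise TypeError(f"Unsupported configuration type: {type(configs).__name__}")


configs = validate_configs_sketch({"lr": 0.01}, LearningConfigsSketch)
same_configs = validate_configs_sketch(configs, LearningConfigsSketch)
```

Passing an object that is already of the designated class returns it as-is, so callers can accept both forms without branching.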

Serialization

relax.utils.save_pytree

[source]

relax.utils.save_pytree (pytree, saved_dir)

Save a pytree to a directory.

The pytree will be stored under a directory with two files:

  • {saved_dir}/data.npy: This file stores the flattened leaves.
  • {saved_dir}/treedef.json: This file stores the pytree structure and records whether each leaf is an array or not.

For example, a pytree

pytree = {
    'a': np.random.randn(5, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}

will be stored as the two files listed above: {saved_dir}/data.npy and {saved_dir}/treedef.json.

relax.utils.load_pytree

[source]

relax.utils.load_pytree (saved_dir)

Load a pytree from a saved directory.

import os
import numpy as np
from relax.utils import save_pytree, load_pytree

# Store a dictionary to disk
pytree = {
    'a': np.random.randn(100, 1),
    'b': 1,
    'c': {
        'd': True,
        'e': "Hello",
        'f': np.array(["a", "b", "c"])
    }
}
os.makedirs('tmp', exist_ok=True)
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')
assert np.allclose(pytree['a'], pytree_loaded['a'])
assert pytree['a'].dtype == pytree_loaded['a'].dtype
assert pytree['b'] == pytree_loaded['b']
assert pytree['c']['d'] == pytree_loaded['c']['d']
assert pytree['c']['e'] == pytree_loaded['c']['e']
assert np.all(pytree['c']['f'] == pytree_loaded['c']['f'])
# Store a list to disk
pytree = [
    np.random.randn(100, 1),
    {'a': 1, 'b': np.array([1, 2, 3])},
    1,
    [1, 2, 3],
    "good"
]
save_pytree(pytree, 'tmp')
pytree_loaded = load_pytree('tmp')

assert np.allclose(pytree[0], pytree_loaded[0])
assert pytree[0].dtype == pytree_loaded[0].dtype
assert pytree[1]['a'] == pytree_loaded[1]['a']
assert np.all(pytree[1]['b'] == pytree_loaded[1]['b'])
assert pytree[2] == pytree_loaded[2]
assert pytree[3] == pytree_loaded[3]
assert isinstance(pytree_loaded[3], list)
assert pytree[4] == pytree_loaded[4]
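The idea behind the two-file layout (flattened leaves in one place, tree structure in another) can be illustrated in pure Python. This is a minimal sketch, not relax's implementation: flatten and unflatten are hypothetical helpers, leaves are kept in a plain list rather than written to data.npy, and the structure format shown here is an assumption chosen only so that it is JSON-serializable.

```python
import json


def flatten(tree):
    """Split a nested dict/list into (leaves, treedef); treedef is JSON-serializable."""
    leaves = []

    def _walk(node):
        if isinstance(node, dict):
            return {"type": "dict", "children": {k: _walk(v) for k, v in node.items()}}
        if isinstance(node, list):
            return {"type": "list", "children": [_walk(v) for v in node]}
        # Anything else is a leaf: store it and remember its position.
        leaves.append(node)
        return {"type": "leaf", "index": len(leaves) - 1}

    treedef = _walk(tree)
    return leaves, treedef


def unflatten(leaves, treedef):
    """Rebuild the nested structure from the leaves and the structure description."""
    kind = treedef["type"]
    if kind == "dict":
        return {k: unflatten(leaves, v) for k, v in treedef["children"].items()}
    if kind == "list":
        return [unflatten(leaves, v) for v in treedef["children"]]
    return leaves[treedef["index"]]


tree = {"a": [1.0, 2.0], "b": 1, "c": {"d": True, "e": "Hello"}}
leaves, treedef = flatten(tree)

# The structure survives a JSON round trip (dictionary keys must be strings).
restored = unflatten(leaves, json.loads(json.dumps(treedef)))
```

Keeping the structure separate from the leaves is what lets the leaves be stored in a compact array-friendly format while the structure stays human-readable.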

Vectorization Utils

relax.utils.auto_reshaping

[source]

relax.utils.auto_reshaping (reshape_argname, reshape_output=True)

Decorator to automatically reshape a function’s input into (1, k), and its output back to the input’s shape.

Parameters:

  • reshape_argname (<class 'str'>) – The name of the argument to be reshaped.
  • reshape_output (<class 'bool'>, default=True) – Whether to reshape the output. Useful to set False when returning multiple cfs.

This decorator ensures that the specified input argument and the output of a function have the same shape. This is particularly useful when using jax.vmap.

import jax.numpy as jnp
from jax import vmap
from relax.utils import auto_reshaping

@auto_reshaping('x')
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)

@auto_reshaping('x', reshape_output=False)
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 1, 10)
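The mechanics of such a decorator (look up one argument by name, add a leading axis, optionally strip it from the output) can be sketched with nested Python lists in place of JAX arrays. Everything here is an assumption made for illustration: auto_reshaping_sketch is a hypothetical name, lists stand in for arrays, and relax's real decorator operates on array shapes instead.

```python
import functools
import inspect


def auto_reshaping_sketch(reshape_argname, reshape_output=True):
    """Wrap a function written for (1, k) input so it also accepts (k,) input."""

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            x = bound.arguments[reshape_argname]
            # A flat list plays the role of shape (k,); a nested one of shape (1, k).
            was_flat = not isinstance(x[0], list)
            if was_flat:
                bound.arguments[reshape_argname] = [x]
            out = fn(*bound.args, **bound.kwargs)
            if reshape_output and was_flat:
                # Strip the leading axis so the output matches the original input.
                out = out[0]
            return out

        return wrapper

    return decorator


@auto_reshaping_sketch("x")
def double(x):
    # Written for 2-D input of shape (1, k).
    return [[2 * v for v in row] for row in x]


@auto_reshaping_sketch("x", reshape_output=False)
def double_raw(x):
    return [[2 * v for v in row] for row in x]


flat_result = double([1, 2, 3])
raw_result = double_raw([1, 2, 3])
```

With reshape_output=False the extra leading axis is kept, which mirrors the (10, 1, 10) shape in the second example above.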

Gradient Utils

relax.utils.grad_update

[source]

relax.utils.grad_update (grads, params, opt_state, opt)

Parameters:

  • grads – A pytree of gradients.
  • params – A pytree of parameters.
  • opt_state (typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, typing.Iterable[ForwardRef('ArrayTree')], typing.Mapping[typing.Any, ForwardRef('ArrayTree')]]) – The optimizer state.
  • opt (<class 'optax._src.base.GradientTransformation'>) – The optax optimizer.

Functional Utils

relax.utils.gumbel_softmax

[source]

relax.utils.gumbel_softmax (key, logits, tau, axis=-1)

The Gumbel softmax function.

Parameters:

  • key (jax.random.PRNGKey) – Random key
  • logits (<class 'jax.Array'>) – Logits for each class. Shape (batch_size, num_classes)
  • tau (<class 'float'>) – Temperature for the Gumbel softmax
  • axis (int | tuple[int, ...], default=-1) – The axis or axes along which the gumbel softmax should be computed
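For readers unfamiliar with the technique, the standard Gumbel-softmax computation is: add Gumbel(0, 1) noise to the logits, divide by the temperature tau, and apply a softmax. The sketch below shows this for a single vector of logits in pure Python. It is an illustration only: gumbel_softmax_sketch is a hypothetical name, Python's random module replaces JAX's key-based random numbers, and relax's implementation may differ in details such as axis handling.

```python
import math
import random


def gumbel_softmax_sketch(rng, logits, tau):
    """Return a relaxed one-hot sample for one vector of logits."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise by inverse transform: g = -log(-log(u)).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly in (0, 1)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / tau)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = gumbel_softmax_sketch(rng, [2.0, 0.5, -1.0], tau=0.5)
```

Lower values of tau push the output closer to a one-hot vector, while higher values make it closer to uniform.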

Helper functions

relax.utils.load_json

[source]

relax.utils.load_json (f_name)

Parameters:

  • f_name (<class 'str'>) – file name

Returns:

    (typing.Dict[str, typing.Any])
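Given the signature (a file name in, a dictionary out), a helper of this kind can plausibly be written with the standard json module. The following is a sketch under that assumption, with the hypothetical name load_json_sketch. It is not taken from relax's source.

```python
import json
import os
import tempfile
from typing import Any, Dict


def load_json_sketch(f_name: str) -> Dict[str, Any]:
    """Read a JSON file and return its content as a dictionary."""
    with open(f_name, "r", encoding="utf-8") as f:
        return json.load(f)


# Write a small JSON file to a temporary directory, then read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "configs.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"lr": 0.01, "name": "demo"}, f)
    loaded = load_json_sketch(path)
```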

Config

relax.utils.get_config

[source]

relax.utils.get_config ()

relax.utils.set_config

[source]

relax.utils.set_config (rng_reserve_size=None, global_seed=None, **kwargs)

Sets the global configurations.

from relax.utils import get_config, set_config

# Generic test cases
set_config()
assert get_config().rng_reserve_size == 1 and get_config().global_seed == 42
set_config(rng_reserve_size=100)
assert get_config().rng_reserve_size == 100
set_config(global_seed=1234)
assert get_config().global_seed == 1234
set_config(rng_reserve_size=2, global_seed=234)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
set_config()
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
set_config(invalid_key = 80)
assert get_config().rng_reserve_size == 2 and get_config().global_seed == 234
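The test cases above imply three rules: the defaults are rng_reserve_size=1 and global_seed=42, an argument left as None keeps the current value, and unknown keyword arguments do not change the tracked settings. A minimal sketch of a global configuration with these rules is shown below. ConfigSketch, get_config_sketch and set_config_sketch are hypothetical names, and relax's own implementation may be organized differently.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    # Defaults taken from the first test case above.
    rng_reserve_size: int = 1
    global_seed: int = 42


# Module-level singleton holding the current settings.
_config = ConfigSketch()


def get_config_sketch() -> ConfigSketch:
    return _config


def set_config_sketch(rng_reserve_size=None, global_seed=None, **kwargs):
    """Update the global settings; None means "leave unchanged"."""
    if rng_reserve_size is not None:
        _config.rng_reserve_size = rng_reserve_size
    if global_seed is not None:
        _config.global_seed = global_seed
    # Unknown keyword arguments in **kwargs are accepted but ignored here.


set_config_sketch(rng_reserve_size=100)
set_config_sketch(invalid_key=80)
```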