Utils

Define utility functions for relax.

Configurations


source

VALIDATE_CONFIGS

relax.utils.validate_configs (configs, config_cls)

Return a valid configuration object.

Parameters:
  • configs (dict | BaseParser) – A configuration of the model/dataset.
  • config_cls (BaseParser) – The desired configuration class.
Returns:

    (BaseParser)

We define a configuration object (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that the designated configuration object is returned.

For example, we define a configuration object LearningConfigs:

class LearningConfigs(BaseParser):
    lr: float

A configuration can be a LearningConfigs object, or the raw data as a dictionary.

configs_dict = dict(lr=0.01)

validate_configs will return the designated configuration object.

configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
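
Since a configuration can also already be an instance of the desired class, validate_configs presumably returns such an instance unchanged; a minimal check of this expected behavior:

configs = validate_configs(LearningConfigs(lr=0.01), LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == 0.01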

source

SHOW_DOC

relax.utils.show_doc (sym)

Same functionality as nbdev.show_doc, but provides additional support for BaseParser.

Parameters:
  • sym – Symbol to document
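
For example, it can presumably be called on the BaseParser subclass defined above just like any other symbol:

show_doc(LearningConfigs)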

Categorical normalization


source

CAT_NORMALIZE

relax.utils.cat_normalize (cf, cat_arrays, cat_idx, hard=False)

Ensure that generated counterfactual explanations respect one-hot encoding constraints.

Parameters:
  • cf (jnp.ndarray) – Unnormalized counterfactual explanations [n_samples, n_features]
  • cat_arrays (List[List[str]]) – A list of lists containing the names of each categorical feature
  • cat_idx (int) – Index at which the categorical features start
  • hard (bool, default=False) – If True, return one-hot vectors; if False, return probabilities normalized via softmax
Returns:

    (jnp.ndarray)

A tabular data point is encoded as x = [\underbrace{x_{0}, x_{1}, \ldots, x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1}, \ldots, x_{m+p}^{c=1}}_{\text{cat feature (1)}}, \ldots, \underbrace{x_{k-q}^{c=i}, \ldots, x_{k}^{c=i}}_{\text{cat feature (i)}}]

cat_normalize ensures that the generated cf satisfies the categorical constraints, i.e., \sum_j x^{c=l}_j = 1,\; x^{c=l}_j > 0,\; \forall l = 1, \ldots, i.

cat_idx is the index of the first categorical feature. In the above example, cat_idx is m+1.

For example, let’s define a valid input data point:

import numpy as np

x = np.array([
    [1., .9, 'dog', 'gray'],
    [.3, .3, 'cat', 'gray'],
    [.7, .1, 'fish', 'red'],
    [1., .6, 'dog', 'gray'],
    [.1, .2, 'fish', 'yellow']
])

We encode the categorical features via the OneHotEncoder in sklearn.

from sklearn.preprocessing import OneHotEncoder
cat_idx = 2
ohe = OneHotEncoder(sparse=False)
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
x_transformed = np.concatenate(
    (x_cont, x_cat), axis=1
)

If hard=True, the categorical features are in one-hot format.

cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=True)
cfs[:1]
Array([[-1.1198508 , -0.57364744,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ,  0.        ]], dtype=float32)

If hard=False, the categorical features are normalized via the softmax function.

cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=False)
n_cat_feats = len(ohe.categories_)

assert jnp.abs(cfs[:, cat_idx:].sum(axis=1) - n_cat_feats).sum() < 1e-6

Vectorization Utils


source

AUTO_RESHAPING

relax.utils.auto_reshaping (reshape_argname)

Decorator to automatically reshape a function’s input into (1, k), and its output back to the input’s shape.

This decorator ensures that the specified input argument and the output of a function have the same shape. This is particularly useful when using jax.vmap.

import jax.numpy as jnp
from jax import vmap

@auto_reshaping('x')
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)
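
Outside of vmap, the decorator applies the same round-trip on a single input: a (k,) input is reshaped to (1, k) before the call, and the output is reshaped back to the input’s shape.

assert f_vmap(jnp.ones((10,))).shape == (10,)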

Training Utils


source

MAKE_MODEL

relax.utils.make_model (m_configs, model)

Parameters:
  • m_configs (Dict[str, Any]) – model configs
  • model (hk.Module) – haiku module
Returns:

    (hk.Transformed)


source

MAKE_HK_MODULE

relax.utils.make_hk_module (module, *args, **kargs)

Parameters:
  • module (hk.Module) – haiku module
  • args – haiku module arguments
  • kargs – haiku module keyword arguments
Returns:

    (hk.Transformed)
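
A minimal usage sketch, assuming a hypothetical hk.Module whose constructor accepts the forwarded arguments (exactly how make_hk_module constructs and transforms the module is internal to the library):

import haiku as hk
import jax
import jax.numpy as jnp

class MLP(hk.Module):
    # Hypothetical module; hidden_size is forwarded through make_hk_module's kargs.
    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.hidden_size = hidden_size

    def __call__(self, x):
        return hk.Linear(self.hidden_size)(x)

net = make_hk_module(MLP, hidden_size=64)  # an hk.Transformed
params = net.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))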


source

INIT_NET_OPT

relax.utils.init_net_opt (net, opt, X, key)
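
init_net_opt has no generated description; judging from the signature, it plausibly initializes the network parameters from a sample batch X and a PRNG key, and the optimizer state from those parameters. A sketch of the presumed equivalent, reusing net from the sketch above and assuming an optax optimizer and a (params, opt_state) return value:

import optax

key = jax.random.PRNGKey(42)
X = jnp.ones((16, 10))  # hypothetical sample batch
opt = optax.adam(1e-3)

# Presumed equivalent of: params, opt_state = init_net_opt(net, opt, X, key)
params = net.init(key, X)
opt_state = opt.init(params)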


source

GRAD_UPDATE

relax.utils.grad_update (grads, params, opt_state, opt)
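
grad_update likewise has no generated description; the signature suggests a standard optax gradient step. A presumed equivalent (an assumption, not the library’s confirmed implementation), continuing from the sketch above:

# Example gradients; rng=None assumes the module uses no randomness.
grads = jax.grad(lambda p: jnp.sum(net.apply(p, None, X) ** 2))(params)

# Presumed equivalent of: params, opt_state = grad_update(grads, params, opt_state, opt)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)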


source

CHECK_CAT_INFO

relax.utils.check_cat_info (method)

Helper functions


source

LOAD_JSON

relax.utils.load_json (f_name)

Parameters:
  • f_name (str) – file name
Returns:

    (Dict[str, Any])
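
A minimal usage sketch (the file name here is hypothetical):

configs = load_json("configs.json")  # hypothetical file
assert isinstance(configs, dict)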

Loss Functions


source

BINARY_CROSS_ENTROPY

relax.utils.binary_cross_entropy (preds, labels)

Per-sample binary cross-entropy loss function.

Parameters:
  • preds (jnp.DeviceArray) – The predicted values
  • labels (jnp.DeviceArray) – The ground-truth labels
Returns:

    (jnp.DeviceArray) – Loss value
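
For reference, the standard per-sample binary cross-entropy is -(y \log \hat{y} + (1 - y) \log(1 - \hat{y})). A minimal JAX sketch of this formula (not necessarily the library’s exact implementation):

import jax.numpy as jnp

def bce_reference(preds, labels):
    # Standard per-sample binary cross-entropy.
    return -(labels * jnp.log(preds) + (1 - labels) * jnp.log(1 - preds))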


source

SIGMOID

relax.utils.sigmoid (x)
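
This is presumably the standard logistic function \sigma(x) = 1 / (1 + e^{-x}) (equivalent to jax.nn.sigmoid).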

Metrics


source

PROXIMITY

relax.utils.proximity (x, cf)
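
proximity has no generated description; in the counterfactual-explanation literature it commonly denotes the (mean) L1 distance between the input x and the counterfactual cf, e.g. jnp.abs(x - cf).sum(axis=1).mean(), but treat this reading as an assumption rather than the library’s confirmed definition.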


source

DIST

relax.utils.dist (x, cf, ord=2)
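
Given the ord parameter, dist plausibly computes a per-sample \ell_{ord} distance between x and cf; a sketch of that assumption:

import jax.numpy as jnp

x = jnp.zeros((5, 3))
cf = jnp.ones((5, 3))
# Assumed behavior: per-sample l_ord norm of the difference.
d = jnp.linalg.norm(x - cf, ord=2, axis=1)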


source

ACCURACY

relax.utils.accuracy (y_true, y_pred)
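
A plausible reference implementation (an assumption; the library may, e.g., round probabilistic predictions first):

def accuracy_reference(y_true, y_pred):
    # Fraction of matching labels.
    return (y_true == y_pred).mean()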

Config


source

GET_CONFIG

relax.utils.get_config ()