Utils

Define utility functions for relax.

Configurations

relax.utils.validate_configs

relax.utils.validate_configs (configs, config_cls)

Return a valid configuration object.

Parameters:

  • configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
  • config_cls (<class 'pydantic.main.BaseModel'>) – The desired configuration class.

Returns:

    (<class 'pydantic.main.BaseModel'>)

We define a configuration class (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that an instance of the designated configuration class is returned.

For example, we define a configuration object LearningConfigs:

class LearningConfigs(BaseParser):
    lr: float

A configuration can be either a LearningConfigs instance or the raw data as a dictionary.

configs_dict = dict(lr=0.01)

validate_configs returns an instance of the designated configuration class.

configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
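The dispatch performed by validate_configs can be sketched as below. This is a minimal, hypothetical sketch rather than relax's code: `validate_configs_sketch` and `LearningConfigsSketch` are illustrative names, and a standard-library dataclass stands in for a BaseParser/pydantic model so the example runs without extra dependencies.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type, TypeVar, Union

T = TypeVar("T")


@dataclass
class LearningConfigsSketch:
    # Hypothetical stand-in for a BaseParser-derived configuration class.
    lr: float


def validate_configs_sketch(configs: Union[Dict[str, Any], T], config_cls: Type[T]) -> T:
    """Return an instance of `config_cls` built from `configs` (hypothetical sketch)."""
    if isinstance(configs, dict):
        # Raw dictionary: build the configuration object from its key/value pairs.
        configs = config_cls(**configs)
    if not isinstance(configs, config_cls):
        # Anything else (e.g. an object of a different class) is rejected.
        raise TypeError(
            f"Expected dict or {config_cls.__name__}, got {type(configs).__name__}."
        )
    return configs


configs = validate_configs_sketch(dict(lr=0.01), LearningConfigsSketch)
assert isinstance(configs, LearningConfigsSketch) and configs.lr == 0.01
# An existing configuration object passes through unchanged.
assert validate_configs_sketch(configs, LearningConfigsSketch) is configs
```

With a pydantic model, the `config_cls(**configs)` call also validates and coerces field types, which a plain dataclass does not.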

Categorical normalization

relax.legacy.utils.cat_normalize

relax.legacy.utils.cat_normalize (cf, cat_arrays, cat_idx, hard=False)

Ensure that generated counterfactual explanations respect one-hot encoding constraints.

Parameters:

  • cf (<class 'jax.Array'>) – Unnormalized counterfactual explanations [n_samples, n_features]
  • cat_arrays (typing.List[typing.List[str]]) – A list of lists; each inner list holds the categories of one categorical feature
  • cat_idx (<class 'int'>) – Index that starts categorical features
  • hard (<class 'bool'>, default=False) – If True, return one-hot vectors; if False, return probabilities normalized via softmax

Returns:

    (<class 'jax.Array'>)

A tabular data point is encoded as x = [\underbrace{x_{0}, x_{1}, ..., x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1},..., x_{m+p}^{c=1}}_{\text{cat feature (1)}}, ..., \underbrace{x_{k-q}^{c=i},..., x_{k}^{c=i}}_{\text{cat feature (i)}}]

cat_normalize ensures that the generated cf satisfies the categorical constraints, i.e., \sum_j x^{c}_j = 1 and x^{c}_j \geq 0 for every categorical feature c \in \{1, ..., i\} (the entries are strictly positive when hard=False).

cat_idx is the index of the first categorical feature. In the above example, cat_idx is m+1.
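The normalization described above can be sketched in NumPy as follows. This is a hypothetical illustration, not relax's implementation: `cat_normalize_sketch` and `softmax` are names introduced here, NumPy stands in for JAX so the sketch runs without it, and the hard case is assumed to place a 1 at each feature's largest value.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable softmax along the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def cat_normalize_sketch(cf, cat_arrays, cat_idx, hard=False):
    """Make each categorical block of `cf` a valid encoding (hypothetical sketch)."""
    cf = np.asarray(cf, dtype=float)
    out = cf.copy()  # continuous columns [:cat_idx] are left untouched
    start = cat_idx
    for categories in cat_arrays:
        end = start + len(categories)
        block = cf[:, start:end]
        if hard:
            # One-hot vector at the largest value within this feature's block.
            out[:, start:end] = np.eye(len(categories))[block.argmax(axis=1)]
        else:
            # Positive probabilities that sum to 1 within this feature's block.
            out[:, start:end] = softmax(block)
        start = end
    return out


rng = np.random.default_rng(0)
cat_arrays = [["cat", "dog", "fish"], ["gray", "red", "yellow"]]
cf = rng.normal(size=(5, 2 + 6))  # 2 continuous columns + 3 + 3 one-hot columns
soft = cat_normalize_sketch(cf, cat_arrays, cat_idx=2)
hard = cat_normalize_sketch(cf, cat_arrays, cat_idx=2, hard=True)
```

Each categorical block is handled independently, so `cat_arrays` must list the categories in the same order as the encoded columns.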

For example, let’s define a small set of valid input data points:

import numpy as np
import jax.numpy as jnp
from relax.legacy.utils import cat_normalize

x = np.array([
    [1., .9, 'dog', 'gray'],
    [.3, .3, 'cat', 'gray'],
    [.7, .1, 'fish', 'red'],
    [1., .6, 'dog', 'gray'],
    [.1, .2, 'fish', 'yellow']
])

We encode the categorical features via the OneHotEncoder in sklearn.

from sklearn.preprocessing import OneHotEncoder
cat_idx = 2
ohe = OneHotEncoder(sparse_output=False)
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
x_transformed = np.concatenate(
    (x_cont, x_cat), axis=1
)

If hard=True, the categorical features are in one-hot format.

cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=True)
cfs[:1]
Array([[-0.47835127, -0.32345298,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ]], dtype=float32)

If hard=False, the categorical features are normalized via the softmax function.

cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_, 
    cat_idx=cat_idx, hard=False)
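n_cat_feats = len(ohe.categories_)

# Each categorical feature's block sums to 1, so every row's
# categorical columns sum to the number of categorical features.
row_sums = cfs[:, cat_idx:].sum(axis=1)
assert jnp.allclose(row_sums, n_cat_feats, atol=1e-5)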

Training Utils

relax.legacy.utils.make_model

relax.legacy.utils.make_model (m_configs, model)

Parameters:

  • m_configs (typing.Dict[str, typing.Any]) – model configs
  • model (<class 'haiku._src.module.Module'>)

Returns:

    (<class 'haiku._src.transform.Transformed'>)

relax.legacy.utils.make_hk_module

relax.legacy.utils.make_hk_module (module, *args, **kargs)

Parameters:

  • module (<class 'haiku._src.module.Module'>) – haiku module
  • args – haiku module arguments
  • kargs – haiku module keyword arguments

Returns:

    (<class 'haiku._src.transform.Transformed'>)

relax.legacy.utils.init_net_opt

relax.legacy.utils.init_net_opt (net, opt, X, key)

relax.utils.grad_update

relax.utils.grad_update (grads, params, opt_state, opt)

relax.legacy.utils.check_cat_info

relax.legacy.utils.check_cat_info (method)

Helper functions

relax.utils.load_json

relax.utils.load_json (f_name)

Parameters:

  • f_name (<class 'str'>) – file name

Returns:

    (typing.Dict[str, typing.Any])
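Assuming load_json simply parses a JSON file with Python's standard json module (the reference above gives only its signature), it behaves like the hypothetical `load_json_sketch` below; the file name `configs.json` is illustrative.

```python
import json
import os
import tempfile
from typing import Any, Dict


def load_json_sketch(f_name: str) -> Dict[str, Any]:
    """Read a JSON file and return its top-level object (hypothetical sketch)."""
    with open(f_name) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configs.json")
    # Write a small configuration file, then read it back.
    with open(path, "w") as f:
        json.dump({"lr": 0.01}, f)
    loaded = load_json_sketch(path)
```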

Loss Functions

relax.legacy.utils.binary_cross_entropy

relax.legacy.utils.binary_cross_entropy (preds, labels)

Per-sample binary cross-entropy loss function.

Parameters:

  • preds (<class 'jax.Array'>) – The predicted values
  • labels (<class 'jax.Array'>) – The ground-truth labels

Returns:

    (<class 'jax.Array'>) – Loss value
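For a prediction p_n \in (0, 1) and a label y_n \in \{0, 1\}, the per-sample binary cross-entropy is \ell_n = -[y_n \log p_n + (1 - y_n) \log(1 - p_n)]. The NumPy sketch below computes one loss value per sample. It assumes `preds` are probabilities rather than logits, and the clipping constant `eps` is a hypothetical safeguard against \log 0 that may differ from relax's handling.

```python
import numpy as np


def binary_cross_entropy_sketch(preds, labels, eps=1e-7):
    """Per-sample BCE: -(y*log(p) + (1-y)*log(1-p)), with p clipped away from 0 and 1."""
    p = np.clip(np.asarray(preds, dtype=float), eps, 1.0 - eps)
    y = np.asarray(labels, dtype=float)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))


# One loss per sample: [-log 0.9, -log 0.8, -log 0.5]
loss = binary_cross_entropy_sketch(np.array([0.9, 0.2, 0.5]), np.array([1.0, 0.0, 1.0]))
```

Averaging `loss` gives the usual batch-level objective; keeping it per-sample leaves room for weighting or masking.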

relax.legacy.utils.sigmoid

relax.legacy.utils.sigmoid (x)
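Assuming sigmoid is the standard logistic function \sigma(x) = 1 / (1 + e^{-x}), a numerically stable NumPy sketch looks like this; `sigmoid_sketch` is a hypothetical name.

```python
import numpy as np


def sigmoid_sketch(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated without overflow."""
    x = np.asarray(x, dtype=float)
    # exp(-|x|) lies in (0, 1], so neither branch can overflow.
    e = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^{-x});  x < 0: e^{x} / (1 + e^{x}).
    return np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))


z = np.array([-800.0, -2.0, 0.0, 3.0, 800.0])
s = sigmoid_sketch(z)
```

The naive form `1 / (1 + np.exp(-x))` overflows for large negative `x`; splitting on the sign avoids that.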

Metrics

relax.legacy.utils.proximity

relax.legacy.utils.proximity (x, cf)

relax.legacy.utils.dist

relax.legacy.utils.dist (x, cf, ord=2)
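The reference above lists only the signatures of proximity and dist. The NumPy sketch below assumes that dist returns the row-wise vector norm \|x_n - cf_n\|_{ord} and that proximity is the mean per-sample L1 distance, a common definition in the counterfactual literature; both definitions and the `_sketch` names are assumptions, not relax's confirmed behavior.

```python
import numpy as np


def dist_sketch(x, cf, ord=2):
    """Per-sample distance ||x_n - cf_n||_ord between matching rows (assumed)."""
    return np.linalg.norm(np.asarray(x, dtype=float) - np.asarray(cf, dtype=float), ord=ord, axis=1)


def proximity_sketch(x, cf):
    """Mean per-sample L1 distance between inputs and counterfactuals (assumed)."""
    return dist_sketch(x, cf, ord=1).mean()


x = np.array([[0.0, 0.0], [1.0, 1.0]])
cf = np.array([[3.0, 4.0], [1.0, 2.0]])
l2 = dist_sketch(x, cf)         # row distances: [5.0, 1.0]
prox = proximity_sketch(x, cf)  # mean of L1 distances [7.0, 1.0] = 4.0
```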

relax.legacy.utils.accuracy

relax.legacy.utils.accuracy (y_true, y_pred)
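As a hedged sketch of what accuracy computes for binary labels, the hypothetical `accuracy_sketch` below assumes `y_pred` holds positive-class probabilities and thresholds them at 0.5; the threshold is an assumption, not a documented parameter.

```python
import numpy as np


def accuracy_sketch(y_true, y_pred):
    """Fraction of samples whose thresholded prediction equals the label (assumed)."""
    y_true = np.asarray(y_true, dtype=float).reshape(-1)
    # Assumption: y_pred holds positive-class probabilities; 0.5 is the decision threshold.
    y_hat = (np.asarray(y_pred, dtype=float).reshape(-1) >= 0.5).astype(float)
    return float((y_hat == y_true).mean())


# Thresholded predictions are [1, 0, 0, 0]; 3 of 4 match the labels.
acc = accuracy_sketch([1, 0, 1, 0], [0.9, 0.4, 0.3, 0.1])
```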

Config

relax.utils.get_config

relax.utils.get_config ()