Utils
Configurations
relax.utils.validate_configs
relax.utils.validate_configs (configs, config_cls)
Return a valid configuration object.
Parameters:
- configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
- config_cls (<class 'pydantic.main.BaseModel'>) – The desired configuration class.
Returns:
(<class 'pydantic.main.BaseModel'>)
We define a configuration object (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that the designated configuration object is returned.
For example, we define a configuration object LearningConfigs:
class LearningConfigs(BaseParser):
    lr: float
A configuration can be either a LearningConfigs instance or the raw data in a dictionary.
configs_dict = dict(lr=0.01)
validate_configs will return the designated configuration object.
from relax.utils import validate_configs
configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
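The behavior described above can be approximated in a few lines of pydantic. This is a minimal sketch, not the implementation used by relax: `validate_configs_sketch` and `LearningConfigsSketch` are hypothetical names, the sketch subclasses pydantic's `BaseModel` directly instead of `BaseParser`, and it only accepts a dictionary or an instance of the target class.

```python
from typing import Any, Dict, Type, Union

from pydantic import BaseModel


class LearningConfigsSketch(BaseModel):
    # Hypothetical stand-in for LearningConfigs
    lr: float


def validate_configs_sketch(
    configs: Union[Dict[str, Any], BaseModel],
    config_cls: Type[BaseModel],
) -> BaseModel:
    """Return `configs` as an instance of `config_cls` (assumed behavior)."""
    if isinstance(configs, config_cls):
        # Already the designated configuration object: return it unchanged
        return configs
    if isinstance(configs, dict):
        # pydantic validates field names and types while constructing the object
        return config_cls(**configs)
    raise TypeError(f"Unsupported configuration type: {type(configs).__name__}")


configs = validate_configs_sketch(dict(lr=0.01), LearningConfigsSketch)
```

Because the dictionary is passed through the pydantic constructor, a missing or mistyped field surfaces as a validation error at this point rather than later during training.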
Categorical normalization
relax.legacy.utils.cat_normalize
relax.legacy.utils.cat_normalize (cf, cat_arrays, cat_idx, hard=False)
Ensure that generated counterfactual explanations respect one-hot encoding constraints.
Parameters:
- cf (<class 'jax.Array'>) – Unnormalized counterfactual explanations [n_samples, n_features]
- cat_arrays (typing.List[typing.List[str]]) – A list of lists holding the category names of each categorical feature
- cat_idx (<class 'int'>) – Index at which the categorical features start
- hard (<class 'bool'>, default=False) – If True, return one-hot vectors; if False, return probabilities normalized via softmax
Returns:
(<class 'jax.Array'>)
A tabular data point is encoded as
$$x = [\underbrace{x_{0}, x_{1}, ..., x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1}, ..., x_{m+p}^{c=1}}_{\text{cat feature (1)}}, ..., \underbrace{x_{k-q}^{c=i}, ..., x_{k}^{c=i}}_{\text{cat feature (i)}}]$$
cat_normalize ensures that the generated cf satisfies the categorical constraints, i.e., $\sum_j x^{c=i}_j = 1, \; x^{c=i}_j \geq 0, \; \forall c \in \{1, ..., i\}$.
cat_idx is the index of the first categorical feature. In the above example, cat_idx is m+1.
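To make the constraint concrete, the sketch below normalizes each categorical block independently. It is written in plain NumPy for illustration and is an assumption about the logic, not the relax implementation: `cat_normalize_sketch` is a hypothetical name, the soft case applies a softmax per block, and the hard case places a one at the largest entry of each block.

```python
import numpy as np


def cat_normalize_sketch(cf, cat_arrays, cat_idx, hard=False):
    """Normalize each categorical block of `cf` (illustrative, NumPy only)."""
    cf = np.asarray(cf, dtype=float)
    blocks = [cf[:, :cat_idx]]  # continuous features pass through unchanged
    start = cat_idx
    for categories in cat_arrays:
        end = start + len(categories)
        block = cf[:, start:end]
        if hard:
            # One-hot vector at the position of the largest entry
            normalized = np.eye(len(categories))[block.argmax(axis=1)]
        else:
            # Numerically stable softmax over the block
            shifted = np.exp(block - block.max(axis=1, keepdims=True))
            normalized = shifted / shifted.sum(axis=1, keepdims=True)
        blocks.append(normalized)
        start = end
    return np.concatenate(blocks, axis=1)


# Two continuous features followed by two categorical features of three categories each
cat_arrays = [["cat", "dog", "fish"], ["gray", "red", "yellow"]]
cf = np.array([[0.5, -0.2, 2.0, 0.1, -1.0, 0.3, 0.2, 1.5]])
soft = cat_normalize_sketch(cf, cat_arrays, cat_idx=2)
hard = cat_normalize_sketch(cf, cat_arrays, cat_idx=2, hard=True)
```

In both modes every categorical block sums to one, and the first `cat_idx` columns are returned untouched.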
For example, let’s define a set of valid input data points:
import numpy as np
x = np.array([
    [1., .9, 'dog', 'gray'],
    [.3, .3, 'cat', 'gray'],
    [.7, .1, 'fish', 'red'],
    [1., .6, 'dog', 'gray'],
    [.1, .2, 'fish', 'yellow']
])
We encode the categorical features via the OneHotEncoder in sklearn.
from sklearn.preprocessing import OneHotEncoder
cat_idx = 2
ohe = OneHotEncoder(sparse_output=False)
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
x_transformed = np.concatenate(
    (x_cont, x_cat), axis=1
)
If hard=True, the categorical features are in one-hot format.
from relax.legacy.utils import cat_normalize
cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
                    cat_idx=cat_idx, hard=True)
cfs[:1]
Array([[-0.47835127, -0.32345298, 1., 0., 0., 0., 0., 1.]], dtype=float32)
If hard=False, the categorical features are normalized via the softmax function.
import jax.numpy as jnp
cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
                    cat_idx=cat_idx, hard=False)
n_cat_feats = len(ohe.categories_)
assert (cfs[:, cat_idx:].sum(axis=1) - n_cat_feats * jnp.ones(len(cfs))).sum() < 1e-6
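The assertion above checks only the total over all categorical columns, so a row whose blocks sum to 1.5 and 0.5 would still pass. A stricter per-feature check could look like the following sketch. `check_cat_constraints` is a hypothetical helper that is not part of relax, and the tolerance `atol` is an assumed default.

```python
import numpy as np


def check_cat_constraints(cfs, cat_arrays, cat_idx, atol=1e-6):
    """Return True if every categorical block is non-negative and sums to 1."""
    cfs = np.asarray(cfs, dtype=float)
    start = cat_idx
    for categories in cat_arrays:
        end = start + len(categories)
        block = cfs[:, start:end]
        if (block < 0).any():
            return False
        if not np.allclose(block.sum(axis=1), 1.0, rtol=0.0, atol=atol):
            return False
        start = end
    return True


cat_arrays = [["cat", "dog", "fish"], ["gray", "red", "yellow"]]
# Each block sums to 1
good = np.array([[0.1, 0.2, 0.7, 0.2, 0.1, 0.0, 0.0, 1.0]])
# Blocks sum to 1.5 and 0.5: the total is 2, but neither block is valid
bad = np.array([[0.1, 0.2, 0.9, 0.4, 0.2, 0.3, 0.1, 0.1]])

good_ok = check_cat_constraints(good, cat_arrays, cat_idx=2)
bad_ok = check_cat_constraints(bad, cat_arrays, cat_idx=2)
```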
Training Utils
relax.legacy.utils.make_model
relax.legacy.utils.make_model (m_configs, model)
Parameters:
- m_configs (typing.Dict[str, typing.Any]) – model configs
- model (<class 'haiku._src.module.Module'>)
Returns:
(<class 'haiku._src.transform.Transformed'>)
relax.legacy.utils.make_hk_module
relax.legacy.utils.make_hk_module (module, *args, **kargs)
Parameters:
- module (<class 'haiku._src.module.Module'>) – haiku module
- args – haiku module arguments
- kargs – haiku module arguments
Returns:
(<class 'haiku._src.transform.Transformed'>)
relax.legacy.utils.init_net_opt
relax.legacy.utils.init_net_opt (net, opt, X, key)
relax.utils.grad_update
relax.utils.grad_update (grads, params, opt_state, opt)
relax.legacy.utils.check_cat_info
relax.legacy.utils.check_cat_info (method)
Helper functions
relax.utils.load_json
relax.utils.load_json (f_name)
Parameters:
- f_name (<class 'str'>) – file name
Returns:
(typing.Dict[str, typing.Any])
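Given the signature, a helper of this kind is typically a thin wrapper around the standard `json` module. The sketch below is an assumption about that behavior, not the relax source: `load_json_sketch` is a hypothetical name, and the temporary file exists only to make the example runnable.

```python
import json
import os
import tempfile
from typing import Any, Dict


def load_json_sketch(f_name: str) -> Dict[str, Any]:
    """Read a JSON file and return its content as a dictionary."""
    with open(f_name, "r", encoding="utf-8") as f:
        return json.load(f)


# Write a small configuration file, then read it back
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "configs.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"lr": 0.01}, f)
    loaded = load_json_sketch(path)
```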
Loss Functions
relax.legacy.utils.binary_cross_entropy
relax.legacy.utils.binary_cross_entropy (preds, labels)
Per-sample binary cross-entropy loss function.
Parameters:
- preds (<class 'jax.Array'>) – The predicted values
- labels (<class 'jax.Array'>) – The ground-truth labels
Returns:
(<class 'jax.Array'>) – Loss value
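For a predicted probability p and a label y, the per-sample binary cross-entropy is -[y log(p) + (1 - y) log(1 - p)], with no averaging over the batch. The NumPy sketch below illustrates that formula. It is not the relax implementation: `binary_cross_entropy_sketch` is a hypothetical name, it assumes `preds` are probabilities rather than logits, and the clipping constant `eps` is an added assumption to avoid log(0).

```python
import numpy as np


def binary_cross_entropy_sketch(preds, labels, eps=1e-7):
    """Per-sample binary cross-entropy; returns one loss value per input."""
    # Clip probabilities away from 0 and 1 so the logarithms stay finite
    preds = np.clip(np.asarray(preds, dtype=float), eps, 1.0 - eps)
    labels = np.asarray(labels, dtype=float)
    return -(labels * np.log(preds) + (1.0 - labels) * np.log(1.0 - preds))


# One confident correct prediction and one uninformative prediction
losses = binary_cross_entropy_sketch([0.9, 0.5], [1.0, 0.0])
```

The result has the same shape as the inputs, so a scalar training loss would be obtained by taking the mean afterwards.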
relax.legacy.utils.sigmoid
relax.legacy.utils.sigmoid (x)
Metrics
relax.legacy.utils.proximity
relax.legacy.utils.proximity (x, cf)
relax.legacy.utils.dist
relax.legacy.utils.dist (x, cf, ord=2)
relax.legacy.utils.accuracy
relax.legacy.utils.accuracy (y_true, y_pred)
Config
relax.utils.get_config
relax.utils.get_config ()