Utils
Configurations
relax.utils.validate_configs
relax.utils.validate_configs (configs, config_cls)
Return a valid configuration object.
Parameters:
- configs (dict | pydantic.main.BaseModel) – A configuration of the model/dataset.
- config_cls (<class 'pydantic.main.BaseModel'>) – The desired configuration class.
Returns:
(<class 'pydantic.main.BaseModel'>)
We define a configuration object (which inherits from BaseParser) to manage training/model/data configurations. validate_configs ensures that the designated configuration object is returned.
For example, we define a configuration object LearningConfigs:
class LearningConfigs(BaseParser):
    lr: float
A configuration can be a LearningConfigs object, or the raw data in a dictionary.
configs_dict = dict(lr=0.01)
validate_configs will return a designated configuration object.
configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
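If the input is already an instance of the target class, the signature suggests it is passed through unchanged; a minimal sketch of that path (assumed pass-through behavior, not guaranteed by the docs above):
configs_obj = LearningConfigs(lr=0.01)
configs = validate_configs(configs_obj, LearningConfigs)  # assumed: returned as-is
assert type(configs) == LearningConfigs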
Categorical normalization
relax.legacy.utils.cat_normalize
relax.legacy.utils.cat_normalize (cf, cat_arrays, cat_idx, hard=False)
Ensure that generated counterfactual explanations respect one-hot encoding constraints.
Parameters:
- cf (<class 'jax.Array'>) – Unnormalized counterfactual explanations [n_samples, n_features]
- cat_arrays (typing.List[typing.List[str]]) – A list of lists of categorical feature names
- cat_idx (<class 'int'>) – Index at which the categorical features start
- hard (<class 'bool'>, default=False) – If True, return one-hot vectors; if False, return probabilities normalized via softmax
Returns:
(<class 'jax.Array'>)
A tabular data point is encoded as x = [\underbrace{x_0, x_1, \dots, x_m}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1}, \dots, x_{m+p}^{c=1}}_{\text{cat feature (1)}}, \dots, \underbrace{x_{k-q}^{c=i}, \dots, x_{k}^{c=i}}_{\text{cat feature (i)}}]
cat_normalize ensures that the generated cf satisfies the categorical constraints, i.e., \sum_j x^{c}_j = 1,\ x^{c}_j \geq 0,\ \forall c \in \{1, \dots, i\}.
cat_idx is the index of the first categorical feature. In the above example, cat_idx is m+1.
For example, let’s define a valid input data point:
import numpy as np

x = np.array([
[1., .9, 'dog', 'gray'],
[.3, .3, 'cat', 'gray'],
[.7, .1, 'fish', 'red'],
[1., .6, 'dog', 'gray'],
[.1, .2, 'fish', 'yellow']
])
We encode the categorical features via the OneHotEncoder in sklearn.
from sklearn.preprocessing import OneHotEncoder

cat_idx = 2
ohe = OneHotEncoder(sparse_output=False)
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
x_transformed = np.concatenate(
(x_cont, x_cat), axis=1
)
If hard=True, the categorical features are in one-hot format.
cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
cat_idx=cat_idx, hard=True)
cfs[:1]
Array([[-0.47835127, -0.32345298, 1. , 0. , 0. ,
0. , 0. , 1. ]], dtype=float32)
If hard=False, the categorical features are normalized via the softmax function.
import jax.numpy as jnp

cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
cat_idx=cat_idx, hard=False)
n_cat_feats = len(ohe.categories_)
assert (cfs[:, cat_idx:].sum(axis=1) - n_cat_feats * jnp.ones(len(cfs))).sum() < 1e-6
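For intuition, the normalization can be thought of as operating block-by-block over the one-hot segments. Below is a minimal sketch of that logic; it is an illustrative re-implementation inferred from the behavior shown above, not the library's actual code:
import jax
import jax.numpy as jnp

def cat_normalize_sketch(cf, cat_arrays, cat_idx, hard=False):
    # Continuous features are left untouched.
    blocks = [cf[:, :cat_idx]]
    start = cat_idx
    for cats in cat_arrays:  # one block per categorical feature
        block = cf[:, start:start + len(cats)]
        if hard:
            # One-hot: place a 1 at the argmax of each block.
            block = jax.nn.one_hot(block.argmax(axis=-1), len(cats))
        else:
            # Probabilities that sum to 1 within each block.
            block = jax.nn.softmax(block, axis=-1)
        blocks.append(block)
        start += len(cats)
    return jnp.concatenate(blocks, axis=-1)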
Training Utils
relax.legacy.utils.make_model
relax.legacy.utils.make_model (m_configs, model)
Parameters:
- m_configs (typing.Dict[str, typing.Any]) – Model configs
- model (<class 'haiku._src.module.Module'>) – A haiku module
Returns:
(<class 'haiku._src.transform.Transformed'>)
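A plausible sketch of what such a factory does, inferred from the signature alone (the actual implementation may differ):
import haiku as hk

def make_model_sketch(m_configs, model_cls):
    # Wrap module construction in a function and transform it:
    # the standard haiku pattern for obtaining a Transformed object.
    def forward(x):
        return model_cls(m_configs)(x)
    return hk.transform(forward)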
relax.legacy.utils.make_hk_module
relax.legacy.utils.make_hk_module (module, *args, **kargs)
Parameters:
- module (<class 'haiku._src.module.Module'>) – A haiku module
- args (VAR_POSITIONAL) – Positional arguments for the haiku module
- kargs (VAR_KEYWORD) – Keyword arguments for the haiku module
Returns:
(<class 'haiku._src.transform.Transformed'>)
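This appears to generalize the pattern above by forwarding arbitrary constructor arguments; a hedged sketch under the same assumption:
import haiku as hk

def make_hk_module_sketch(module, *args, **kargs):
    def forward(x):
        # Instantiate the module with the forwarded arguments and call it.
        return module(*args, **kargs)(x)
    return hk.transform(forward)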
relax.legacy.utils.init_net_opt
relax.legacy.utils.init_net_opt (net, opt, X, key)
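No description is given; judging by the signature, it likely initializes the network parameters on a sample batch X and creates the optimizer state. An assumed sketch using the standard haiku/optax idiom:
def init_net_opt_sketch(net, opt, X, key):
    # Assumed behavior: initialize haiku params from a sample batch,
    # then initialize the optax optimizer state from those params.
    params = net.init(key, X)
    opt_state = opt.init(params)
    return params, opt_state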
relax.utils.grad_update
relax.utils.grad_update (grads, params, opt_state, opt)
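Again judging by the signature, this is likely the canonical optax update step; a sketch under that assumption:
import optax

def grad_update_sketch(grads, params, opt_state, opt):
    # Canonical optax step: compute updates, then apply them to params.
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state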
relax.legacy.utils.check_cat_info
relax.legacy.utils.check_cat_info (method)
Helper functions
relax.utils.load_json
relax.utils.load_json (f_name)
Parameters:
- f_name (<class 'str'>) – File name
Returns:
(typing.Dict[str, typing.Any]) – The loaded content
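The name and types suggest the standard json idiom; a minimal sketch of the assumed behavior:
import json

def load_json_sketch(f_name):
    # Read a JSON file into a dictionary.
    with open(f_name) as f:
        return json.load(f)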
Loss Functions
relax.legacy.utils.binary_cross_entropy
relax.legacy.utils.binary_cross_entropy (preds, labels)
Per-sample binary cross-entropy loss function.
Parameters:
- preds (<class 'jax.Array'>) – The predicted values
- labels (<class 'jax.Array'>) – The ground-truth labels
Returns:
(<class 'jax.Array'>) – Loss value
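For reference, the standard per-sample BCE formula the function presumably computes (the library may differ in numerical details, e.g. clipping for stability):
import jax.numpy as jnp

def binary_cross_entropy_ref(preds, labels, eps=1e-7):
    # Clip to avoid log(0); otherwise the textbook definition.
    preds = jnp.clip(preds, eps, 1 - eps)
    return -(labels * jnp.log(preds) + (1 - labels) * jnp.log(1 - preds))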
relax.legacy.utils.sigmoid
relax.legacy.utils.sigmoid (x)
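The standard definition, for reference:
import jax.numpy as jnp

def sigmoid_ref(x):
    # Note: jax.nn.sigmoid is the numerically stable choice in practice.
    return 1 / (1 + jnp.exp(-x))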
Metrics
relax.legacy.utils.proximity
relax.legacy.utils.proximity (x, cf)
relax.legacy.utils.dist
relax.legacy.utils.dist (x, cf, ord=2)
relax.legacy.utils.accuracy
relax.legacy.utils.accuracy (y_true, y_pred)
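These metrics are undocumented here; the following are plausible reference implementations inferred from the signatures alone (assumptions, not the library's definitions):
import jax.numpy as jnp

def dist_ref(x, cf, ord=2):
    # Norm of the difference per sample.
    return jnp.linalg.norm(x - cf, ord=ord, axis=-1)

def proximity_ref(x, cf):
    # Mean L1 distance between inputs and counterfactuals.
    return jnp.abs(x - cf).sum(axis=-1).mean()

def accuracy_ref(y_true, y_pred):
    # Fraction of matching (rounded) labels.
    return (y_true.round() == y_pred.round()).mean()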
Config
relax.utils.get_config
relax.utils.get_config ()