Utils
Configurations
VALIDATE_CONFIGS
relax.utils.validate_configs (configs, config_cls)
Return a valid configuration object.
We define a configuration object (which inherits from BaseParser) to manage training, model, and data configurations. validate_configs ensures that the designated configuration object is returned.
For example, we define a configuration object LearningConfigs:
class LearningConfigs(BaseParser):
    lr: float
A configuration can be a LearningConfigs instance or the raw data as a dictionary.
configs_dict = dict(lr=0.01)
validate_configs will return the designated configuration object.
configs = validate_configs(configs_dict, LearningConfigs)
assert type(configs) == LearningConfigs
assert configs.lr == configs_dict['lr']
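To make the dictionary-or-object behaviour concrete, here is a minimal sketch of how such a validation step might work. It is not the relax implementation: a standard-library dataclass stands in for BaseParser, the names SketchLearningConfigs and validate_configs_sketch are hypothetical, and rejecting unknown keys is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Type, TypeVar, Union

T = TypeVar("T")


@dataclass
class SketchLearningConfigs:
    # Hypothetical stand-in for LearningConfigs(BaseParser).
    lr: float


def validate_configs_sketch(configs: Union[Dict[str, Any], T], config_cls: Type[T]) -> T:
    """Return `configs` as an instance of `config_cls`."""
    if isinstance(configs, config_cls):
        # Already the designated configuration object: pass it through.
        return configs
    if isinstance(configs, dict):
        # Assumption: keys that are not fields of config_cls are an error.
        allowed = {f.name for f in fields(config_cls)}
        unknown = set(configs) - allowed
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        return config_cls(**configs)
    raise TypeError(f"Expected dict or {config_cls.__name__}, got {type(configs).__name__}")


sketch_configs = validate_configs_sketch({"lr": 0.01}, SketchLearningConfigs)
```

Passing an existing SketchLearningConfigs instance returns it unchanged, so callers can accept either form without branching.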
SHOW_DOC
relax.utils.show_doc (sym)
Same functionality as nbdev.show_doc, but provides additional support for BaseParser.
Categorical normalization
CAT_NORMALIZE
relax.utils.cat_normalize (cf, cat_arrays, cat_idx, hard=False)
Ensure that generated counterfactual explanations respect one-hot encoding constraints.
A tabular data point is encoded as $x = [\underbrace{x_{0}, x_{1}, ..., x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1},..., x_{m+p}^{c=1}}_{\text{cat feature (1)}}, ..., \underbrace{x_{k-q}^{c=i},..., x_{k}^{c=i}}_{\text{cat feature (i)}}]$.
cat_normalize ensures that the generated cf satisfies the categorical constraints, i.e., $\sum_j x^{c=i}_j=1, x^{c=i}_j \geq 0, \forall c \in \{1, ..., i\}$.
cat_idx is the index of the first categorical feature. In the above example, cat_idx is m+1.
For example, let’s define a valid input data point:
x = np.array([
    [1., .9, 'dog', 'gray'],
    [.3, .3, 'cat', 'gray'],
    [.7, .1, 'fish', 'red'],
    [1., .6, 'dog', 'gray'],
    [.1, .2, 'fish', 'yellow']
])
We encode the categorical features via the OneHotEncoder in sklearn.
from sklearn.preprocessing import OneHotEncoder
cat_idx = 2
# Note: scikit-learn 1.2+ renames this argument to sparse_output=False.
ohe = OneHotEncoder(sparse=False)
x_cat = ohe.fit_transform(x[:, cat_idx:])
x_cont = x[:, :cat_idx].astype(float)
x_transformed = np.concatenate(
    (x_cont, x_cat), axis=1)
If hard=True, the categorical features are in one-hot format.
cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
    cat_idx=cat_idx, hard=True)
cfs[:1]
Array([[-1.1198508 , -0.57364744, 0. , 1. , 0. ,
1. , 0. , 0. ]], dtype=float32)
If hard=False, the categorical features are normalized via the softmax function.
cfs = np.random.randn(*x_transformed.shape)
cfs = cat_normalize(cfs, ohe.categories_,
    cat_idx=cat_idx, hard=False)
n_cat_feats = len(ohe.categories_)

assert (cfs[:, cat_idx:].sum(axis=1) - n_cat_feats * jnp.ones(len(cfs))).sum() < 1e-6
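The normalization described in this section can be sketched in plain NumPy. This is an illustration under stated assumptions, not the relax source: cat_normalize_sketch is a hypothetical name, each categorical block is assumed to be contiguous and ordered as in cat_arrays, and hard=True is assumed to pick the arg-max entry of each block.

```python
import numpy as np


def cat_normalize_sketch(cf, cat_arrays, cat_idx, hard=False):
    """Normalize each categorical block of `cf`, an array of shape (n, k)."""
    cf = np.asarray(cf, dtype=float)
    blocks = [cf[:, :cat_idx]]  # continuous features pass through unchanged
    start = cat_idx
    for categories in cat_arrays:
        end = start + len(categories)
        logits = cf[:, start:end]
        if hard:
            # One-hot: 1 at the arg-max position of the block, 0 elsewhere.
            block = np.eye(len(categories))[logits.argmax(axis=1)]
        else:
            # Softmax, shifted by the row maximum for numerical stability.
            z = np.exp(logits - logits.max(axis=1, keepdims=True))
            block = z / z.sum(axis=1, keepdims=True)
        blocks.append(block)
        start = end
    return np.concatenate(blocks, axis=1)


# Two continuous features, then two categorical features with three levels each.
cat_arrays = [["cat", "dog", "fish"], ["gray", "red", "yellow"]]
cf = np.array([[0.5, -1.0, 0.2, 2.0, -0.3, 1.5, 0.1, 0.4]])
hard_cf = cat_normalize_sketch(cf, cat_arrays, cat_idx=2, hard=True)
soft_cf = cat_normalize_sketch(cf, cat_arrays, cat_idx=2, hard=False)
```

In both modes every categorical block sums to 1, so the sum over all categorical columns equals the number of categorical features, which is the property the assertion above checks.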
Vectorization Utils
AUTO_RESHAPING
relax.utils.auto_reshaping (reshape_argname)
Decorator that automatically reshapes a function's input into shape (1, k) and reshapes the output back to the input's original shape.
This decorator ensures that the specified input argument and the output of a function have the same shape. This is particularly useful when using jax.vmap, which passes each sample to the function without its batch dimension.
@auto_reshaping('x')
def f_vmap(x): return x * jnp.ones((10,))
assert vmap(f_vmap)(jnp.ones((10, 10))).shape == (10, 10)
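The reshaping idea behind such a decorator can be sketched as follows. This is a NumPy-only illustration, not the relax implementation: auto_reshaping_sketch and scale_rows are hypothetical names, and the sketch assumes the function's output has the same number of elements as the named input.

```python
import functools
import inspect

import numpy as np


def auto_reshaping_sketch(reshape_argname):
    """Reshape the named argument to (1, k); reshape the output back."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Locate the named argument whether it was passed by position or keyword.
            bound = sig.bind(*args, **kwargs)
            x = np.asarray(bound.arguments[reshape_argname])
            original_shape = x.shape
            bound.arguments[reshape_argname] = x.reshape(1, -1)
            out = func(*bound.args, **bound.kwargs)
            # Assumption: the output has as many elements as the input.
            return np.asarray(out).reshape(original_shape)
        return wrapper
    return decorator


@auto_reshaping_sketch("x")
def scale_rows(x):
    # Written for batched input of shape (n, k).
    assert x.ndim == 2
    return 2.0 * x


single = np.ones(10)         # shape (10,), as a per-sample call would see it
result = scale_rows(single)  # returned with the original shape (10,)
```

The wrapped function only ever sees two-dimensional input, so code written for batches can be reused for single samples.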
Training Utils
MAKE_MODEL
relax.utils.make_model (m_configs, model)
MAKE_HK_MODULE
relax.utils.make_hk_module (module, *args, **kargs)
INIT_NET_OPT
relax.utils.init_net_opt (net, opt, X, key)
GRAD_UPDATE
relax.utils.grad_update (grads, params, opt_state, opt)
CHECK_CAT_INFO
relax.utils.check_cat_info (method)
Helper functions
LOAD_JSON
relax.utils.load_json (f_name)
Loss Functions
BINARY_CROSS_ENTROPY
relax.utils.binary_cross_entropy (preds, labels)
Per-sample binary cross-entropy loss function.
SIGMOID
relax.utils.sigmoid (x)
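As a rough illustration of the two loss helpers listed above, the per-sample binary cross-entropy is $-[y \log p + (1 - y) \log(1 - p)]$ with $p$ the predicted probability. The NumPy sketch below assumes that preds are probabilities in [0, 1] (not logits); the names sigmoid_sketch and binary_cross_entropy_sketch and the clipping constant eps are hypothetical and not taken from relax.

```python
import numpy as np


def sigmoid_sketch(x):
    """Logistic function 1 / (1 + exp(-x)), applied elementwise."""
    return 1.0 / (1.0 + np.exp(-np.asarray(x, dtype=float)))


def binary_cross_entropy_sketch(preds, labels, eps=1e-7):
    """Per-sample loss; returns one value per element, with no averaging."""
    # Clip to avoid log(0); eps is an illustrative choice.
    preds = np.clip(np.asarray(preds, dtype=float), eps, 1.0 - eps)
    labels = np.asarray(labels, dtype=float)
    return -(labels * np.log(preds) + (1.0 - labels) * np.log(1.0 - preds))


logits = np.array([0.0, 2.0, -2.0])
labels = np.array([1.0, 1.0, 0.0])
losses = binary_cross_entropy_sketch(sigmoid_sketch(logits), labels)
```

Because the loss is per-sample, losses has the same shape as labels; take the mean explicitly when a scalar training loss is needed.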
Metrics
PROXIMITY
relax.utils.proximity (x, cf)
DIST
relax.utils.dist (x, cf, ord=2)
ACCURACY
relax.utils.accuracy (y_true, y_pred)
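The page lists only the signatures of these metrics, so the following is a guess at plausible semantics rather than a description of relax. It assumes that dist is the vector norm of x - cf of order ord, taken per row, and that accuracy is the fraction of labels matching the rounded predictions. The names dist_sketch and accuracy_sketch are hypothetical.

```python
import numpy as np


def dist_sketch(x, cf, ord=2):
    """Assumed semantics: per-row vector norm of the difference x - cf."""
    diff = np.asarray(x, dtype=float) - np.asarray(cf, dtype=float)
    return np.linalg.norm(diff, ord=ord, axis=-1)


def accuracy_sketch(y_true, y_pred):
    """Assumed semantics: share of labels equal to the rounded predictions."""
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.round(np.asarray(y_pred, dtype=float)).reshape(-1)
    return float(np.mean(y_true == y_pred))


x = np.array([[0.0, 0.0], [1.0, 1.0]])
cf = np.array([[3.0, 4.0], [1.0, 1.0]])
l2 = dist_sketch(x, cf)          # Euclidean distance per row
l1 = dist_sketch(x, cf, ord=1)   # Manhattan distance per row
acc = accuracy_sketch([1, 0, 1, 1], [0.9, 0.2, 0.4, 0.8])
```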
Config
GET_CONFIG
relax.utils.get_config ()