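```python
from relax.data import load_data
from relax.methods.cchvae import CCHVAE
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from jax import random
```
CCHVAE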
CCHVAEConfigs
class relax.methods.cchvae.CCHVAEConfigs(enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], encoded_size=5, lr=0.001, max_steps=100, n_search_samples=300, step_size=0.1, seed=0)
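Configuration for CCHVAE. Judging by their names, enc_sizes and dec_sizes set the hidden-layer sizes of the VAE encoder and decoder, encoded_size the latent dimension and lr the training learning rate, while step_size, n_search_samples and max_steps control the latent-space search for counterfactuals.
As a pydantic model, it parses and validates its keyword arguments and raises ValidationError if they cannot form a valid configuration.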
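The following plain-Python stand-in illustrates the documented defaults and the parse-and-validate behavior. CCHVAEConfigsSketch is a hypothetical name and is not part of ReLax. Its range checks (positive sizes, rates and counts) are assumptions made for illustration, and the real pydantic class may only validate types.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class CCHVAEConfigsSketch:
    """Hypothetical stand-in mirroring the documented CCHVAEConfigs defaults."""
    enc_sizes: List[int] = field(default_factory=lambda: [20, 16, 14, 12])
    dec_sizes: List[int] = field(default_factory=lambda: [12, 14, 16, 20])
    encoded_size: int = 5
    lr: float = 0.001
    max_steps: int = 100
    n_search_samples: int = 300
    step_size: float = 0.1
    seed: int = 0

    def __post_init__(self) -> None:
        # Assumed checks: reject values that cannot form a sensible configuration,
        # loosely analogous to pydantic raising ValidationError.
        if any(s <= 0 for s in self.enc_sizes + self.dec_sizes):
            raise ValueError("layer sizes must be positive")
        if min(self.encoded_size, self.max_steps, self.n_search_samples) <= 0:
            raise ValueError("encoded_size, max_steps and n_search_samples must be positive")
        if self.lr <= 0 or self.step_size <= 0:
            raise ValueError("lr and step_size must be positive")


# Override one field; every other field keeps its default.
custom = CCHVAEConfigsSketch(n_search_samples=500)
```

When you override one field, the others keep their defaults. Note that the default decoder sizes are the encoder sizes in reverse order. With the real library, the analogous call would presumably be CCHVAEConfigs(n_search_samples=500), passed to CCHVAE through its configs argument.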
CCHVAE
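class relax.methods.cchvae.CCHVAE(configs=None)
Counterfactual explanation module implementing C-CHVAE (Pawelczyk et al., 2020). The module first trains a variational autoencoder on the data with train(dm). It then generates counterfactuals by searching the autoencoder's latent space around each input for a point whose decoding changes the model's prediction. configs takes a CCHVAEConfigs; when it is None, the default configuration is presumably used.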
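The latent-space search can be sketched with the growing-spheres idea from the C-CHVAE paper. The search samples candidates in shells of increasing radius around the input's latent code, decodes them, and keeps the closest one whose prediction differs. This is not ReLax's implementation. latent_sphere_search is a hypothetical function, and encode, decode and predict are plain callables standing in for the trained VAE and the classifier. step_size, n_search_samples and max_steps reuse the names from CCHVAEConfigs, on the assumption that they play these roles.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def _sample_in_shell(center: Vector, low: float, high: float, rng: random.Random) -> Vector:
    """Draw one point whose L2 distance from center lies in [low, high]."""
    # Random direction: a Gaussian vector normalized to unit length.
    direction = [rng.gauss(0.0, 1.0) for _ in center]
    norm = math.sqrt(sum(d * d for d in direction)) or 1.0
    # Radius drawn uniformly between the shell bounds (simple, not volume-uniform).
    radius = rng.uniform(low, high)
    return [c + radius * d / norm for c, d in zip(center, direction)]


def latent_sphere_search(
    x: Vector,
    encode: Callable[[Vector], Vector],
    decode: Callable[[Vector], Vector],
    predict: Callable[[Vector], int],
    step_size: float = 0.1,
    n_search_samples: int = 300,
    max_steps: int = 100,
    seed: int = 0,
) -> Optional[Vector]:
    """Hypothetical growing-spheres search in latent space; None if nothing flips."""
    rng = random.Random(seed)
    z = encode(x)
    y = predict(x)
    for step in range(max_steps):
        # Shell number `step` covers radii [step * step_size, (step + 1) * step_size].
        low, high = step * step_size, (step + 1) * step_size
        candidates = [
            decode(_sample_in_shell(z, low, high, rng)) for _ in range(n_search_samples)
        ]
        # Keep only candidates whose predicted label differs from the original.
        flipped = [c for c in candidates if predict(c) != y]
        if flipped:
            # Return the flipped candidate closest to x in L1 distance.
            return min(flipped, key=lambda c: sum(abs(a - b) for a, b in zip(c, x)))
    return None


# Toy usage: identity "autoencoder" and a linear decision boundary x0 + x1 = 1.
identity = lambda v: list(v)
predict = lambda v: 1 if v[0] + v[1] > 1.0 else 0
cf = latent_sphere_search([0.4, 0.4], identity, identity, predict)
```

Searching shell by shell means the first shell that produces any valid candidate bounds how far the counterfactual can lie from the input. That is why a smaller step_size tends to give closer counterfactuals at the cost of more steps.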
Test
```python
# load data (a 10% sample of the adult dataset)
dm = load_data('adult', data_configs=dict(sample_frac=0.1))

# load model
params, training_module = load_pred_model('adult')

# predict function
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)

cchvae_test = CCHVAE()
cchvae_test.train(dm)
```
```
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 493.63batch/s, train/loss=0.552]
```
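Besides x, pred_fn takes the model parameters and a PRNG key, and generate_cf_explanations receives these separately through pred_fn_args. One plausible way to bind such arguments is functools.partial, which the sketch below assumes. bind_pred_fn and toy_pred_fn are hypothetical names and are not part of ReLax.

```python
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple


def bind_pred_fn(
    pred_fn: Callable[..., Any], pred_fn_args: Optional[Dict[str, Any]] = None
) -> Callable[[Any], Any]:
    """Hypothetical helper: fix the extra keyword arguments so only x remains."""
    return partial(pred_fn, **(pred_fn_args or {}))


def toy_pred_fn(x: List[float], params: Tuple[List[float], float], key: int) -> int:
    """Toy linear classifier with the same (x, params, key) shape as pred_fn above."""
    # `key` is accepted but unused here; a stochastic forward pass could consume it.
    weights, bias = params
    score = sum(w * xi for w, xi in zip(weights, x)) + bias
    return 1 if score > 0 else 0


# Bind params and key once, mirroring pred_fn_args=dict(params=..., key=...).
predict = bind_pred_fn(toy_pred_fn, dict(params=([1.0, -1.0], 0.0), key=0))
```

Keeping params and key out of the closure lets the same pred_fn be reused with different parameters or keys without redefining it.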
```python
cf_exp = generate_cf_explanations(
    cchvae_test, dm, pred_fn, pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
benchmark_cfs([cf_exp])
```

| dataset | method | acc | validity | proximity |
|---|---|---|---|---|
| adult | C-CHVAE | 0.8241 | 1.0 | 4.124321 |
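The columns can be read as follows, with some hedging. acc is presumably the predictive model's accuracy and does not depend on the CF method. validity is the share of counterfactuals that change the model's prediction. proximity is an average distance between inputs and their counterfactuals. The sketch below computes the last two, assuming proximity is the mean L1 distance. ReLax's exact definitions may differ, and these functions are illustrative rather than the library's.

```python
from typing import Sequence


def validity(y_orig: Sequence[int], y_cf: Sequence[int]) -> float:
    """Fraction of counterfactuals whose predicted label differs from the original."""
    flips = [int(a != b) for a, b in zip(y_orig, y_cf)]
    return sum(flips) / len(flips)


def proximity(xs: Sequence[Sequence[float]], cfs: Sequence[Sequence[float]]) -> float:
    """Mean L1 distance between each input and its counterfactual (assumed definition)."""
    dists = [sum(abs(a - b) for a, b in zip(x, cf)) for x, cf in zip(xs, cfs)]
    return sum(dists) / len(dists)
```

Under these definitions, the validity of 1.0 above means every generated counterfactual flipped the model's prediction.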