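from jax import random
from relax.data import load_data  # import path assumed; not shown on this page
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs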
CCHVAE
CCHVAEConfigs
CLASS relax.methods.cchvae.CCHVAEConfigs (enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], encoded_size=5, lr=0.001, max_steps=100, n_search_samples=300, step_size=0.1, seed=0)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
CCHVAE
CLASS relax.methods.cchvae.CCHVAE (configs=None)
Base CF Explanation Module.
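The `max_steps`, `n_search_samples`, and `step_size` settings above suggest a latent-space search that widens step by step. The following is a minimal, simplified sketch of that idea in plain Python, not ReLax's implementation: `decode`, `predict`, and `latent_search` are hypothetical stand-ins (a toy linear decoder and a toy threshold classifier) for the trained VAE decoder and the predictive model. The sketch samples latent codes in a shell around the encoded input, keeps the closest decoded sample whose prediction flips, and widens the shell when none is found.

```python
import math
import random


def decode(z):
    # Hypothetical toy "decoder": maps a 2-d latent code to a 2-d input.
    return [z[0] + 0.5 * z[1], z[1] - 0.5 * z[0]]


def predict(x):
    # Hypothetical toy classifier: label 1 when the features sum to more than 1.
    return 1 if x[0] + x[1] > 1.0 else 0


def latent_search(z0, max_steps=100, n_search_samples=300, step_size=0.1, seed=0):
    """Search around latent code z0 for a decoded point with a flipped label."""
    rng = random.Random(seed)
    x0 = decode(z0)
    y0 = predict(x0)
    low, high = 0.0, step_size
    for _ in range(max_steps):
        best, best_dist = None, math.inf
        for _ in range(n_search_samples):
            # Sample a latent code in the shell low <= radius < high around z0.
            angle = rng.uniform(0.0, 2.0 * math.pi)
            radius = rng.uniform(low, high)
            z = [z0[0] + radius * math.cos(angle), z0[1] + radius * math.sin(angle)]
            x = decode(z)
            if predict(x) != y0:
                # Keep the flipped candidate closest to the original input (L1).
                dist = abs(x[0] - x0[0]) + abs(x[1] - x0[1])
                if dist < best_dist:
                    best, best_dist = x, dist
        if best is not None:
            return best
        # No candidate flipped the label: move the shell outward.
        low, high = high, high + step_size
    return None


cf = latent_search([0.0, 0.0])
print(predict(decode([0.0, 0.0])), predict(cf))
```

Searching in latent space rather than input space keeps candidates close to what the decoder can produce, which is the motivation usually given for this family of methods.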
Test
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
# load model
params, training_module = load_pred_model('adult')
# predict function
# pred_fn = lambda x: training_module.forward(params, x, is_training=False)
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)
cchvae_test = CCHVAE()
cchvae_test.train(dm)
/Users/chuck/opt/anaconda3/envs/relax/lib/python3.8/site-packages/haiku/_src/base.py:515: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
param = init(shape, dtype)
/Users/chuck/opt/anaconda3/envs/relax/lib/python3.8/site-packages/relax/_ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 493.63batch/s, train/loss=0.552]
cf_exp = generate_cf_explanations(
    cchvae_test, dm, pred_fn, pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
benchmark_cfs([cf_exp])
| | | acc | validity | proximity |
|---|---|---|---|---|
| adult | C-CHVAE | 0.8241 | 1.0 | 4.124321 |
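To help read the validity and proximity columns, here is a small sketch of how such metrics are commonly defined. It assumes validity is the fraction of counterfactuals whose predicted label differs from the prediction on the original input, and proximity is the mean L1 distance between each input and its counterfactual. The exact formulas used by `benchmark_cfs` are not shown on this page, so the helper functions `validity` and `proximity` below are illustrative only.

```python
def validity(orig_labels, cf_labels):
    # Fraction of counterfactuals whose predicted label differs from the
    # label predicted for the original input (assumed definition).
    flipped = sum(1 for y, y_cf in zip(orig_labels, cf_labels) if y != y_cf)
    return flipped / len(orig_labels)


def proximity(xs, cfs):
    # Mean L1 distance between each input and its counterfactual
    # (assumed definition).
    total = sum(
        sum(abs(a - b) for a, b in zip(x, cf)) for x, cf in zip(xs, cfs)
    )
    return total / len(xs)


# Two toy inputs and their counterfactuals.
xs = [[0.0, 1.0], [1.0, 0.0]]
cfs = [[0.5, 1.0], [1.0, 1.0]]

print(validity([0, 1], [1, 1]))  # 0.5: only the first prediction flipped
print(proximity(xs, cfs))        # 0.75: (0.5 + 1.0) / 2
```

Under these definitions, a validity of 1.0 would mean every counterfactual changed the model's prediction, and a lower proximity would mean the counterfactuals stay closer to the original inputs.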