CCHVAE



CCHVAEConfigs

CLASS relax.methods.cchvae.CCHVAEConfigs (enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], encoded_size=5, lr=0.001, max_steps=100, n_search_samples=300, step_size=0.1, seed=0)

Configurations for CCHVAE, created by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

Parameters:
  • enc_sizes (List[int], default=[20, 16, 14, 12]) – Hidden-layer sizes of the encoder
  • dec_sizes (List[int], default=[12, 14, 16, 20]) – Hidden-layer sizes of the decoder
  • encoded_size (int, default=5) – Dimension of the latent (encoded) space
  • lr (float, default=0.001) – Learning rate used to train the VAE
  • max_steps (int, default=100) – Maximum number of steps of the latent-space counterfactual search
  • n_search_samples (int, default=300) – Number of candidate counterfactuals generated during the search
  • step_size (float, default=0.1) – Step size of the latent-space search
  • seed (int, default=0) – Seed for the random number generator
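The three size parameters describe a standard variational autoencoder. The following is a minimal NumPy sketch, not the relax (Haiku) implementation, of how enc_sizes, dec_sizes and encoded_size could shape such a model. The ReLU activations, the linear mean/log-variance and output heads, the feature count of 13, and the helper names (dense_stack, relu_stack, build_vae, vae_forward) are all assumptions made for illustration.

```python
import numpy as np

def dense_stack(rng, sizes):
    # One (weight, bias) pair per consecutive pair of layer widths.
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def relu_stack(layers, h):
    # Assumption: ReLU after every hidden layer.
    for w, b in layers:
        h = np.maximum(h @ w + b, 0.0)
    return h

def build_vae(rng, n_features, enc_sizes=(20, 16, 14, 12),
              dec_sizes=(12, 14, 16, 20), encoded_size=5):
    """Hypothetical VAE parameters sized by the CCHVAEConfigs defaults."""
    enc = dense_stack(rng, [n_features, *enc_sizes])
    # Linear heads producing the latent mean and log-variance.
    mu_head = dense_stack(rng, [enc_sizes[-1], encoded_size])[0]
    logvar_head = dense_stack(rng, [enc_sizes[-1], encoded_size])[0]
    dec = dense_stack(rng, [encoded_size, *dec_sizes])
    # Linear head mapping the last decoder layer back to feature space.
    out_head = dense_stack(rng, [dec_sizes[-1], n_features])[0]
    return enc, mu_head, logvar_head, dec, out_head

def vae_forward(params, x, rng):
    enc, (wm, bm), (wv, bv), dec, (wo, bo) = params
    h = relu_stack(enc, x)
    mu, logvar = h @ wm + bm, h @ wv + bv
    # Reparameterization trick: z = mu + sigma * eps.
    z = mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)
    x_hat = relu_stack(dec, z) @ wo + bo
    return x_hat, mu, logvar, z

rng = np.random.default_rng(0)
params = build_vae(rng, n_features=13)
x = rng.standard_normal((4, 13))
x_hat, mu, logvar, z = vae_forward(params, x, rng)
print(x_hat.shape, z.shape)  # (4, 13) (4, 5)
```

The reconstruction x_hat keeps the input's feature dimension, while z has encoded_size columns.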


CCHVAE

CLASS relax.methods.cchvae.CCHVAE (configs=None)

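Counterfactual explanation module implementing C-CHVAE; it follows the interface of the base CF explanation module. configs takes a CCHVAEConfigs instance; the defaults are used when it is None.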
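C-CHVAE looks for counterfactuals in the VAE's latent space. The sketch below shows one growing-spheres style search under stated assumptions: at step t it samples n_search_samples points whose distance from the encoded input z lies in [t * step_size, (t + 1) * step_size), decodes them, and stops at the first step where some decoded candidate changes the prediction, giving up after max_steps. The names sample_shell and latent_search, the uniform radius distribution, and keeping the flipped candidate nearest to z are hypothetical choices, not relax's code.

```python
import numpy as np

def sample_shell(rng, n, center, low, high):
    """Sample n points whose L2 distance from center lies in [low, high)."""
    directions = rng.standard_normal((n, center.shape[-1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Assumption: radius drawn uniformly in [low, high).
    radii = rng.uniform(low, high, size=(n, 1))
    return center + radii * directions

def latent_search(z, decode, predict, rng, max_steps=100,
                  n_search_samples=300, step_size=0.1):
    """Hypothetical growing-spheres search; returns a decoded CF or None."""
    y = predict(decode(z[None, :]))[0]
    for step in range(max_steps):
        # The sampling shell moves outward by step_size at every step.
        low, high = step_size * step, step_size * (step + 1)
        candidates = sample_shell(rng, n_search_samples, z, low, high)
        decoded = decode(candidates)
        flipped = predict(decoded) != y
        if flipped.any():
            # Assumption: keep the flipped candidate nearest to z in latent space.
            dists = np.linalg.norm(candidates[flipped] - z, axis=1)
            return decoded[flipped][np.argmin(dists)]
    return None

# Toy usage: identity "decoder" and a linear threshold classifier.
rng = np.random.default_rng(0)
decode = lambda zs: zs
predict = lambda xs: (xs.sum(axis=1) > 1.0).astype(int)
z = np.zeros(5)
cf = latent_search(z, decode, predict, rng)
print(cf, predict(cf[None, :]))
```

Growing the radius step by step biases the search toward counterfactuals that stay close to the input in latent space.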

Test

from jax import random
from relax.data import load_data
from relax.methods.cchvae import CCHVAE
from relax.module import load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs

# load a 10% sample of the adult dataset
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
# load the pretrained predictive model
params, training_module = load_pred_model('adult')

# predict function: run the model in inference mode
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)
# train the C-CHVAE generative model on the data module
cchvae_test = CCHVAE()
cchvae_test.train(dm)
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 493.63batch/s, train/loss=0.552]
cf_exp = generate_cf_explanations(
    cchvae_test, dm, pred_fn, pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
benchmark_cfs([cf_exp])
dataset  CF method  acc     validity  proximity
adult    C-CHVAE    0.8241  1.0       4.124321
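The table reports validity and proximity. Here is a minimal sketch of these two metrics, assuming validity is the share of counterfactuals whose predicted class differs from the input's (using a 0.5 threshold on predicted probabilities) and proximity is the mean L1 distance between inputs and their counterfactuals. relax's exact definitions may differ, and the function names and the toy logistic model are hypothetical.

```python
import numpy as np

def validity(pred_fn, x, cfs, threshold=0.5):
    """Share of counterfactuals whose predicted class differs from the input's."""
    y_x = pred_fn(x) >= threshold
    y_cf = pred_fn(cfs) >= threshold
    return float(np.mean(y_x != y_cf))

def proximity(x, cfs):
    """Mean L1 distance between each input row and its counterfactual row."""
    return float(np.abs(x - cfs).sum(axis=1).mean())

# Toy usage with a hypothetical logistic model over two features.
pred_fn = lambda a: 1.0 / (1.0 + np.exp(-a.sum(axis=1)))
x = np.array([[-1.0, 0.0], [-2.0, 0.0]])
cfs = np.array([[1.0, 0.0], [-1.0, 0.0]])
print(validity(pred_fn, x, cfs))  # 0.5
print(proximity(x, cfs))  # 1.5
```

Under these definitions, a validity of 1.0 means every generated counterfactual flips the prediction, and lower proximity means the counterfactuals stay closer to the original inputs.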