# Sanity-check two formulations of the VAECF validity loss on a toy batch.
# Both call hindge_embedding_loss (defined elsewhere in this notebook) with
# target -1 and 0.165 as the third argument (presumably the hinge margin --
# consistent with the 1.165 output of the next cell; TODO confirm).

# pred_prob = jrand.uniform(jrand.PRNGKey(0), (6, ))
# pred_prob = jnp.array([0.3, 0.1, 0.8, 0.8, .99, .99])
pred_prob = jnp.array([.99, .99, 0.3, 0.1, 0.1, 0.1])
y = jnp.array([1, 1, 0, 0, 0, 0])
target = jnp.array([-1])

# Formulation 1: slice predictions by label first, then apply the loss to
# each group with the helper's default reduction.
tempt_1, tempt_0 = pred_prob[y == 1], pred_prob[y == 0]
validity_loss_1 = hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, 0.165) + \
    hindge_embedding_loss(1. - 2 * tempt_0, target, 0.165)

# tempt_1 = hindge_embedding_loss(pred_prob - (1. - pred_prob), target, 0.165)
# tempt_0 = hindge_embedding_loss(1. - 2 * pred_prob, target, 0.165)
# validity_loss = jnp.where(
#     y == 1, tempt_1, tempt_0
# )

# Formulation 2: compute losses over the whole batch with reduction=None
# (presumably element-wise -- verify against hindge_embedding_loss), zero out
# the other label via jnp.where, and divide each group's sum by that group's
# count (y.sum() ones, y.shape[0] - y.sum() zeros).
# NOTE: this rebinds tempt_1 / tempt_0 from the sliced arrays above to
# scalar per-group losses; later cells evaluate against these scalars.
tempt_1 = jnp.where(
    y == 1,
    hindge_embedding_loss(pred_prob - (1. - pred_prob), target, 0.165, reduction=None),
    0
).sum() / y.sum()
tempt_0 = jnp.where(
    y == 0,
    hindge_embedding_loss(1. - 2 * pred_prob, target, 0.165, reduction=None),
    0
).sum() / (y.shape[0] - y.sum())
# validity_loss = jnp.where(
#     y == 1,
#     hindge_embedding_loss(pred_prob - (1. - pred_prob), target, 0.165, reduction=None),
#     hindge_embedding_loss(1. - 2 * pred_prob, target, 0.165, reduction=None)
# )
# validity_loss_2 = jnp.sum(validity_loss)
validity_loss_2 = tempt_1 + tempt_0
VAECF
# Spot-check the first validity term. By this point the previous cell has
# rebound tempt_1 to a scalar per-group loss, so this evaluates the term on
# that scalar rather than on the sliced label-1 predictions.
hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, 0.165)
DeviceArray(1.165, dtype=float32)
# Show both validity-loss formulations side by side for comparison.
(validity_loss_1, validity_loss_2)
(DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32))
VAECFConfigs
class relax.methods.vaecf.VAECFConfigs (enc_sizes=[20, 16, 14, 12, 5], dec_sizes=[12, 14, 16, 20], dropout_rate=0.1, lr=0.001, mu_samples=50, validity_reg=42.0)
Configurator of VAECFModule.
VAECF
class relax.methods.vaecf.VAECF (m_config=None)
Base CF Explanation Module.
Test
from relax.trainer import train_model
from relax.data import load_data
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import _AuxPredFn, generate_cf_explanations, benchmark_cfs
# End-to-end smoke test: generate VAECF counterfactuals on the adult dataset.
# sample_frac=0.1 presumably loads a 10% subsample -- TODO confirm.
dm = load_data('adult', data_configs=dict(sample_frac=0.1))

# load model
params, training_module = load_pred_model('adult')

# predict function
pred_fn = training_module.pred_fn

vaecf = VAECF()
# No explicit training call is needed: per the log below, VAECF is trained
# inside generate_cf_explanations before explanations are produced.
# vaecf.train(dm, t_config, pred_fn)

# Assumes `random` is `jax.random`, imported in an earlier cell -- TODO confirm.
cf_exp = generate_cf_explanations(
    vaecf, dm, pred_fn,
    pred_fn_args=dict(params=params),
    rng_key=random.PRNGKey(0),
)
VAECF contains parametric models. Starts training before generating explanations...
/Users/chuck/opt/anaconda3/envs/relax/lib/python3.8/site-packages/relax/_ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 81.43batch/s, train/loss=nan]
# Tabulate the evaluation metrics (acc, validity, proximity) for the generated CFs.
benchmark_cfs([cf_exp])
|  |  | acc | validity | proximity |
|---|---|---|---|---|
| adult | C-CHVAE | 0.8241 | 0.18241 | 7.563876 |