# Sanity check: compare two ways of computing the class-balanced hinge
# validity loss on a tiny hand-made batch.
# NOTE(review): `hindge_embedding_loss` (sic) is defined elsewhere; with
# target == -1 it is presumably max(0, margin - x), as in PyTorch's
# HingeEmbeddingLoss, and its default reduction is presumably the mean —
# TODO confirm.
# pred_prob = jrand.uniform(jrand.PRNGKey(0), (6, ))
# pred_prob = jnp.array([0.3, 0.1, 0.8, 0.8, .99, .99])
pred_prob = jnp.array([.99, .99, 0.3, 0.1, 0.1, 0.1])
y = jnp.array([1, 1, 0, 0, 0, 0])
target = jnp.array([-1])
hinge_margin = 0.165

# Variant 1: select each class with boolean indexing. The selected arrays
# have data-dependent shapes, so this form cannot be traced under `jax.jit`.
tempt_1, tempt_0 = pred_prob[y == 1], pred_prob[y == 0]
validity_loss_1 = hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, hinge_margin) + \
    hindge_embedding_loss(1. - 2 * tempt_0, target, hinge_margin)

# Variant 2: jit-friendly masked per-class mean over the full batch.
# The per-class losses get their own names (`loss_1`, `loss_0`) so they no
# longer overwrite the class probabilities `tempt_1` / `tempt_0` above.
pos_mask, neg_mask = y == 1, y == 0
# Clamp the class counts to at least 1. A batch with no positives (or no
# negatives) would otherwise compute 0 / 0 = nan. The masked sum is already
# 0 in that case, so the clamped result is a loss of 0 for the empty class.
n_pos = jnp.maximum(pos_mask.sum(), 1)
n_neg = jnp.maximum(neg_mask.sum(), 1)
loss_1 = jnp.where(
    pos_mask,
    hindge_embedding_loss(pred_prob - (1. - pred_prob), target, hinge_margin, reduction=None),
    0
).sum() / n_pos
loss_0 = jnp.where(
    neg_mask,
    hindge_embedding_loss(1. - 2 * pred_prob, target, hinge_margin, reduction=None),
    0
).sum() / n_neg
validity_loss_2 = loss_1 + loss_0

# VAECF

# Hinge loss of the positive-class term, mirroring the first term of
# `validity_loss_1`. Previously `tempt_1` had been overwritten by the scalar
# positive-class loss at this point. That is why the recorded output was
# DeviceArray(1.165, dtype=float32), i.e. 0.165 - (0 - 1), rather than a loss
# on the probabilities.
hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, hinge_margin)

# Both variants should agree. Recorded output:
# (DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32))
validity_loss_1, validity_loss_2
VAECFConfigs
CLASS relax.methods.vaecf.VAECFConfigs (enc_sizes=[20, 16, 14, 12, 5], dec_sizes=[12, 14, 16, 20], dropout_rate=0.1, lr=0.001, mu_samples=50, validity_reg=42.0)
Configurator of VAECFModule.
VAECF
CLASS relax.methods.vaecf.VAECF (m_config=None)
Base CF Explanation Module.
Test
from relax.trainer import train_model
from relax.data import load_data
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import _AuxPredFn, generate_cf_explanations, benchmark_cfs

# Load a 10% sample of the 'adult' dataset.
dm = load_data('adult', data_configs=dict(sample_frac=0.1))

# load model
params, training_module = load_pred_model('adult')

# predict function
pred_fn = training_module.pred_fn

vaecf = VAECF()
# vaecf.train(dm, t_config, pred_fn)

# NOTE(review): `random` is not imported in this cell; presumably
# `from jax import random` appears earlier in the notebook — confirm.
cf_exp = generate_cf_explanations(
    vaecf, dm, pred_fn, pred_fn_args=dict(
        params=params, rng_key=random.PRNGKey(0)
    )
)
# Output: VAECF contains parametric models. Starts training before generating explanations...
/Users/chuck/opt/anaconda3/envs/relax/lib/python3.8/site-packages/relax/_ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 81.43batch/s, train/loss=nan]
benchmark_cfs([cf_exp])
# Output:
# |       |         | acc    | validity | proximity |
# |-------|---------|--------|----------|-----------|
# | adult | C-CHVAE | 0.8241 | 0.18241  | 7.563876  |
# NOTE(review): the method column reads C-CHVAE although `cf_exp` was
# produced by VAECF; this output may be stale or from another run — confirm.