VAECF

# Toy comparison of two ways to compute the validity loss:
#   1. select each class with boolean indexing, then apply the loss;
#   2. apply the loss to every sample, mask per class with jnp.where,
#      and average over the number of samples in that class.
# NOTE(review): hindge_embedding_loss is defined elsewhere; presumably
# reduction=None returns the unreduced per-sample losses -- confirm.

# Alternative inputs tried earlier:
# pred_prob = jrand.uniform(jrand.PRNGKey(0), (6, ))
# pred_prob = jnp.array([0.3, 0.1, 0.8, 0.8, .99, .99])
y = jnp.array([1, 1, 0, 0, 0, 0])
pred_prob = jnp.array([.99, .99, 0.3, 0.1, 0.1, 0.1])
target = jnp.array([-1])

# --- Formulation 1: boolean indexing per class -------------------------
tempt_1 = pred_prob[y == 1]
tempt_0 = pred_prob[y == 0]
loss_on_pos = hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, 0.165)
loss_on_neg = hindge_embedding_loss(1. - 2 * tempt_0, target, 0.165)
validity_loss_1 = loss_on_pos + loss_on_neg

# Earlier attempt (kept for reference): reduce first, then select per class.
# tempt_1 = hindge_embedding_loss(pred_prob - (1. - pred_prob), target, 0.165)
# tempt_0 = hindge_embedding_loss(1. - 2 * pred_prob, target, 0.165)
# validity_loss = jnp.where(
#     y == 1, tempt_1, tempt_0
# )

# --- Formulation 2: mask with jnp.where, then per-class mean -----------
n_pos = y.sum()
n_neg = y.shape[0] - y.sum()

per_sample_pos = hindge_embedding_loss(
    pred_prob - (1. - pred_prob), target, 0.165, reduction=None
)
# From here on tempt_1 / tempt_0 hold the per-class averaged losses,
# no longer the selected probabilities.
tempt_1 = jnp.where(y == 1, per_sample_pos, 0).sum() / n_pos

per_sample_neg = hindge_embedding_loss(
    1. - 2 * pred_prob, target, 0.165, reduction=None
)
tempt_0 = jnp.where(y == 0, per_sample_neg, 0).sum() / n_neg

# Earlier attempt (kept for reference): one jnp.where over both branches,
# summed over all samples.
# validity_loss = jnp.where(
#     y == 1,
#     hindge_embedding_loss(pred_prob - (1. - pred_prob), target, 0.165, reduction=None),
#     hindge_embedding_loss(1. - 2 * pred_prob, target, 0.165, reduction=None)
# )
# validity_loss_2 = jnp.sum(validity_loss)
validity_loss_2 = tempt_1 + tempt_0

# Displayed value: the loss applied to the *averaged* tempt_1 computed just
# above (not to the class-1 probabilities).
hindge_embedding_loss(tempt_1 - (1. - tempt_1), target, 0.165)
DeviceArray(1.165, dtype=float32)
# Display both formulations of the validity loss side by side for comparison.
validity_loss_1, validity_loss_2
(DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32))

source

VAECFCONFIGS

CLASS relax.methods.vaecf.VAECFConfigs (enc_sizes=[20, 16, 14, 12, 5], dec_sizes=[12, 14, 16, 20], dropout_rate=0.1, lr=0.001, mu_samples=50, validity_reg=42.0)

Configurator of VAECFModule.

Parameters:
  • enc_sizes (List[int], default=[20, 16, 14, 12, 5]) – Sequence of Encoder layer sizes.
  • dec_sizes (List[int], default=[12, 14, 16, 20]) – Sequence of Decoder layer sizes.
  • dropout_rate (float, default=0.1) – Dropout rate.
  • lr (float, default=0.001) – Learning rate.
  • mu_samples (int, default=50) – Number of samples for mu.
  • validity_reg (float, default=42.0) – Regularization for validity.

source

VAECF

CLASS relax.methods.vaecf.VAECF (m_config=None)

Base CF Explanation Module.

Test

from relax.trainer import train_model
from relax.data import load_data
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import _AuxPredFn, generate_cf_explanations, benchmark_cfs
# Data module for the 'adult' dataset, loaded with sample_frac=0.1
# (presumably a 10% subsample to keep the test fast -- confirm).
dm = load_data('adult', data_configs={'sample_frac': 0.1})

# Predictive model for 'adult': its parameters plus the training module
# that exposes the prediction function.
params, training_module = load_pred_model('adult')
pred_fn = training_module.pred_fn

# Default-configured VAECF explainer. The explicit training call stays
# disabled; generate_cf_explanations is invoked on the untrained module.
vaecf = VAECF()
# vaecf.train(dm, t_config, pred_fn)

# Generate counterfactual explanations, forwarding the model parameters and
# a fixed PRNG key to the prediction function.
cf_exp = generate_cf_explanations(
    vaecf,
    dm,
    pred_fn,
    pred_fn_args={'params': params, 'rng_key': random.PRNGKey(0)},
)
VAECF contains parametric models. Starts training before generating explanations...
/Users/chuck/opt/anaconda3/envs/relax/lib/python3.8/site-packages/relax/_ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
  warnings.warn(
Epoch 9: 100%|██████████| 20/20 [00:00<00:00, 81.43batch/s, train/loss=nan]    
# Benchmark the generated counterfactual explanations; takes a list of results.
benchmark_cfs([cf_exp])
acc validity proximity
adult C-CHVAE 0.8241 0.18241 7.563876