Vanilla CF

Vanilla counterfactual explanation.


VanillaCFConfig

class relax.methods.vanilla.VanillaCFConfig(n_steps=1000, lr=0.001, lambda_=0.01)

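Configuration for VanillaCF. n_steps is the number of optimization steps run for each counterfactual, lr is the learning rate of those steps, and lambda_ weights the proximity term of the loss against the validity term.

Like any pydantic model, the configuration is built from keyword arguments and raises ValidationError if they cannot be parsed into a valid configuration.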
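As a hedged sketch of what these three fields control, assume VanillaCF follows the gradient-based formulation of Wachter et al. (2017) for a binary classifier f that outputs a probability f(x). This page does not spell out ReLax's exact loss, so the binary cross-entropy validity term, the squared L2 proximity term, the target y', and plain gradient descent below are all assumptions. Each counterfactual x' would then be found by minimizing:

```latex
\mathcal{L}(x') = \mathrm{BCE}\big(f(x'),\, y'\big) + \lambda \,\lVert x' - x \rVert_2^2,
\qquad y' = 1 - \operatorname{round}\big(f(x)\big)

x'_{t+1} = x'_t - \eta \,\nabla_{x'} \mathcal{L}(x'_t), \qquad t = 0, 1, \dots, T - 1
```

Here λ corresponds to lambda_, η to lr, and T to n_steps. A larger λ keeps x' closer to x; a larger η or T lets the search move further toward the opposite class.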



VanillaCF

class relax.methods.vanilla.VanillaCF(configs=None)

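Vanilla CF explanation module, built on the base CF explanation module. Pass a VanillaCFConfig as configs to override n_steps, lr, or lambda_; with configs=None, the default configuration is used.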
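For intuition, here is a minimal, dependency-free sketch of the per-instance gradient search a vanilla CF method performs. It assumes a logistic-regression classifier so the gradient can be written by hand, a binary cross-entropy validity term toward the flipped label, a squared L2 proximity term weighted by lambda_, and plain gradient descent. The function vanilla_cf and the toy weights are hypothetical; this is not ReLax's own implementation, which works on JAX arrays.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def vanilla_cf(
    x: Sequence[float],
    w: Sequence[float],
    b: float,
    n_steps: int = 1000,
    lr: float = 0.001,
    lambda_: float = 0.01,
) -> List[float]:
    """Hypothetical gradient-descent counterfactual for f(v) = sigmoid(w . v + b).

    Minimizes BCE(f(cf), y_target) + lambda_ * ||cf - x||^2, where y_target
    is the opposite of the class currently predicted for x.
    """
    p_x = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    y_target = 0.0 if p_x >= 0.5 else 1.0  # flip the predicted class
    cf = list(x)  # start the search at the original input
    for _ in range(n_steps):
        p_cf = sigmoid(sum(wi * ci for wi, ci in zip(w, cf)) + b)
        # d BCE / d cf_i = (p_cf - y_target) * w_i for a logistic model
        # d (lambda_ * ||cf - x||^2) / d cf_i = 2 * lambda_ * (cf_i - x_i)
        cf = [
            ci - lr * ((p_cf - y_target) * wi + 2.0 * lambda_ * (ci - xi))
            for ci, wi, xi in zip(cf, w, x)
        ]
    return cf


if __name__ == "__main__":
    x = [1.0, 2.0]
    w, b = [1.5, -0.5], -0.2  # toy model: predicts class 1 for x
    cf = vanilla_cf(x, w, b)
    print(cf, sigmoid(sum(wi * ci for wi, ci in zip(w, cf)) + b))
```

With the default settings the toy input is pushed just across the decision boundary. Raising lambda_ trades validity for proximity, since the penalty pulls the counterfactual back toward x.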

from jax import random
from relax.data import load_data
from relax.methods.vanilla import VanillaCF
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.trainer import train_model

Load data:

dm = load_data('adult', data_configs=dict(sample_frac=0.1))

Load predictive model:

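# load the pretrained predictive model for the adult dataset
params, module = load_pred_model('adult')

# predict function: wrap module.forward as pred_fn(x, params, key), in inference mode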
pred_fn = lambda x, params, key: module.forward(
    params, key, x, is_training=False
)
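generate_cf_explanations receives pred_fn together with pred_fn_args. One plausible reading, which is an assumption not confirmed by this page, is that those extra keyword arguments are bound ahead of time so a CF method only has to supply x. The sketch below shows that binding with functools.partial; toy_pred_fn is a hypothetical stand-in for the module.forward wrapper above.

```python
from functools import partial


def toy_pred_fn(x, params, key):
    """Hypothetical stand-in for pred_fn: a linear score on x.

    key is accepted only to match the (x, params, key) signature used above.
    """
    return sum(p * xi for p, xi in zip(params, x))


pred_fn_args = dict(params=[0.5, -1.0], key=0)
# Bind every argument except x, leaving a function of the input alone
bound_pred_fn = partial(toy_pred_fn, **pred_fn_args)
print(bound_pred_fn([2.0, 1.0]))  # 0.5 * 2.0 - 1.0 * 1.0 = 0.0
```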

Define VanillaCF:

vanillacf = VanillaCF()

Generate explanations:

cf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn=pred_fn, 
    t_configs=dict(
        n_epochs=5, batch_size=128
    ), 
    pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)

Evaluate explanations:

benchmark_cfs([cf_exp])
                acc     validity  proximity
adult VanillaCF 0.8241  0.891414  6.703655
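In this table, acc is the accuracy of the predictive model. The sketch below shows how the other two columns are typically defined: validity as the fraction of counterfactuals whose predicted label differs from that of the original input, and proximity as the mean distance between each input and its counterfactual (lower is better). Treating proximity as the L1 distance is an assumption, as are the toy threshold classifier and data.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]


def validity(xs: List[Vector], cfs: List[Vector], predict: Callable[[Vector], int]) -> float:
    """Fraction of counterfactuals whose predicted label differs from the input's."""
    flipped = sum(predict(cf) != predict(x) for x, cf in zip(xs, cfs))
    return flipped / len(xs)


def proximity(xs: List[Vector], cfs: List[Vector]) -> float:
    """Mean L1 distance between each input and its counterfactual (assumed metric)."""
    dists = [sum(abs(c - v) for c, v in zip(cf, x)) for x, cf in zip(xs, cfs)]
    return sum(dists) / len(dists)


def predict(v: Vector) -> int:
    """Hypothetical threshold classifier on the first feature."""
    return int(v[0] > 0.5)


xs = [[0.2, 1.0], [0.9, 0.0]]
cfs = [[0.7, 1.0], [0.8, 0.5]]  # the first CF flips the label, the second does not
print(validity(xs, cfs, predict))  # 0.5
print(proximity(xs, cfs))  # approximately 0.55
```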