from relax.data import load_data
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.trainer import train_model
from relax.methods.vanilla import VanillaCF
from jax import random
Vanilla CF
Vanilla counterfactual explanation.
VanillaCFConfig
class relax.methods.vanilla.VanillaCFConfig(n_steps=1000, lr=0.001, lambda_=0.01)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
VanillaCF
class relax.methods.vanilla.VanillaCF(configs=None)
Base CF Explanation Module.
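To make the role of n_steps, lr, and lambda_ concrete, here is a minimal sketch of a gradient-based counterfactual search written directly in JAX. It is not the ReLax implementation: the toy logistic classifier (W, B), the squared-error validity term, the squared L2 proximity term, and the helper names predict and vanilla_cf are all assumptions made for illustration, and the way VanillaCF actually combines its loss terms may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy classifier (assumption): logistic regression with fixed weights.
W = jnp.array([1.5, -2.0])
B = 0.25


def predict(x):
    """Return P(y=1 | x) for a single input vector."""
    return jax.nn.sigmoid(jnp.dot(x, W) + B)


def vanilla_cf(x, n_steps=1000, lr=0.001, lambda_=0.01):
    """Search for a counterfactual of x by plain gradient descent (illustrative only)."""
    # Aim for the opposite of the class currently predicted for x.
    y_target = 1.0 - jnp.round(predict(x))

    def loss_fn(cf):
        # Assumed validity term: push the prediction toward the target class.
        validity_loss = (predict(cf) - y_target) ** 2
        # Assumed proximity term: keep the counterfactual close to the input.
        proximity_loss = jnp.sum((cf - x) ** 2)
        return validity_loss + lambda_ * proximity_loss

    grad_fn = jax.jit(jax.grad(loss_fn))
    cf = x
    for _ in range(n_steps):
        cf = cf - lr * grad_fn(cf)
    return cf


x = jnp.array([1.0, 0.2])
cf = vanilla_cf(x, n_steps=500, lr=0.1)
print("original prediction:", float(predict(x)))
print("counterfactual prediction:", float(predict(cf)))
```

In this sketch a larger lambda_ keeps the counterfactual closer to the input at the cost of a weaker class flip, while n_steps and lr control how far the search can move.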
Load data:
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
Load predictive model:
# load model
params, module = load_pred_model('adult')
# predict function
pred_fn = lambda x, params, key: module.forward(
    params, key, x, is_training=False
)
Define VanillaCF:
vanillacf = VanillaCF()
Generate explanations:
cf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn=pred_fn,
    t_configs=dict(
        n_epochs=5, batch_size=128
    ),
    pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
Evaluate explanations:
benchmark_cfs([cf_exp])
dataset | method | acc | validity | proximity |
---|---|---|---|---|
adult | VanillaCF | 0.8241 | 0.891414 | 6.703655 |
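As a rough guide to reading the validity and proximity columns, the sketch below computes two commonly used definitions on a tiny hand-made batch. It assumes that validity is the fraction of counterfactuals whose predicted class differs from the prediction on the original input, and that proximity is the mean L1 distance between each input and its counterfactual. The function names and the toy arrays are hypothetical, and the exact definitions used by benchmark_cfs may differ.

```python
import jax.numpy as jnp


def validity(pred_x, pred_cf):
    """Fraction of counterfactuals whose rounded prediction differs from the original's."""
    return float(jnp.mean(jnp.round(pred_x) != jnp.round(pred_cf)))


def proximity(x, cf):
    """Mean L1 distance between each input row and its counterfactual row."""
    return float(jnp.mean(jnp.sum(jnp.abs(x - cf), axis=1)))


# Toy batch of three inputs and their counterfactuals (made-up values).
x_batch = jnp.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
cf_batch = jnp.array([[0.5, 1.0], [1.0, 0.0], [2.0, 0.0]])

# Made-up predicted probabilities for the inputs and the counterfactuals.
pred_x = jnp.array([0.9, 0.2, 0.7])
pred_cf = jnp.array([0.1, 0.8, 0.6])

v = validity(pred_x, pred_cf)
p = proximity(x_batch, cf_batch)
print("validity:", v)
print("proximity:", p)
```

Under these definitions, higher validity and lower proximity are both desirable, and the two usually trade off against each other.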