from jax import random

from relax.data import TabularDataModule
from relax.evaluate import generate_cf_explanations
from relax.methods import VanillaCF
from relax.module import PredictiveTrainingModule
from relax.plots import individual_plot, summary_plot
from relax.trainer import train_model
from relax.utils import load_json
Plotting
Plot the results of counterfactual (CF) explanations.
SUMMARY_PLOT
relax.plots.summary_plot (exp, sample_frac=1.0, only_valid=False, figsize=(15, 7))
Globally visualize generated explanations.
INDIVIDUAL_PLOT
relax.plots.individual_plot (exp, idx, figsize=(15, 7))
Locally visualize individual explanations.
Example
We first use `VanillaCF` to generate an `Explanation`.
# load configs
configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']      # model configs
data_configs = configs['data_configs']  # data module configs
t_configs = dict(n_epochs=10, batch_size=256)  # training configs

# load data and model
dm = TabularDataModule(data_configs)
model = PredictiveTrainingModule(m_configs)

# train predictive models
params, opt_state = train_model(model, dm, t_configs)


def pred_fn(x, params, prng_key):
    """Predict on `x` with the trained `model`, calling `forward` with `is_training=False`."""
    return model.forward(params, prng_key, x, is_training=False)


# generate explanations
exp = generate_cf_explanations(
    VanillaCF(),
    dm,
    pred_fn,
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0)),
)
To visualize an individual explanation:
# This visualizes the differences between `exp.X[0]` and `exp.cfs[0]`.
fig = individual_plot(exp, idx=0)
To analyze the entire explanation distribution:
# Global view of the explanations; `sample_frac=0.01` presumably plots ~1% of them
# (the documented default is 1.0) — confirm against `summary_plot`.
fig = summary_plot(exp, sample_frac=0.01)