Plotting

Plot the results of counterfactual (CF) explanations.


SUMMARY_PLOT

relax.plots.summary_plot (exp, sample_frac=1.0, only_valid=False, figsize=(15, 7))

Globally visualize generated explanations.

Parameters:
  • exp (Explanation) – Explanations to visualize
  • sample_frac (float, default=1.0) – Fraction of the data to sample for visualization
  • only_valid (bool, default=False) – If True, use only valid counterfactual explanations
  • figsize (tuple, default=(15, 7)) – Figure size
Returns:

    (plt.Figure)
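How `summary_plot` selects rows is not spelled out above. As a rough illustration only, the following sketch shows one plausible way `only_valid` and `sample_frac` could interact: drop invalid counterfactuals first, then draw the requested fraction of the remaining rows without replacement. The helper `select_rows`, the `is_valid` flags, and the filter-then-sample order are all assumptions, not ReLax's implementation.

```python
import random
from typing import List, Sequence, Tuple


def select_rows(
    X: Sequence[Sequence[float]],
    cfs: Sequence[Sequence[float]],
    is_valid: Sequence[bool],
    sample_frac: float = 1.0,
    only_valid: bool = False,
    seed: int = 0,
) -> Tuple[List[Sequence[float]], List[Sequence[float]]]:
    """Hypothetical helper: pick the (input, CF) pairs a summary plot might draw."""
    if not 0.0 < sample_frac <= 1.0:
        raise ValueError("sample_frac must be in (0, 1]")
    # Assumption: `only_valid` keeps only rows whose CF is flagged as valid.
    idx = [i for i in range(len(X)) if is_valid[i] or not only_valid]
    # Sample a fraction of the remaining rows without replacement,
    # keeping at least one row when any rows remain.
    n = max(1, int(len(idx) * sample_frac)) if idx else 0
    rng = random.Random(seed)
    chosen = sorted(rng.sample(idx, n))
    return [X[i] for i in chosen], [cfs[i] for i in chosen]
```

For example, with `sample_frac=0.01` on a 10-row dataset, the `max(1, ...)` guard in this sketch still returns one row; whether ReLax applies a similar guard is not stated.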



INDIVIDUAL_PLOT

relax.plots.individual_plot (exp, idx, figsize=(15, 7))

Locally visualize individual explanations.

Parameters:
  • exp (Explanation) – Explanations to visualize
  • idx (int) – Index of the explanation to visualize
  • figsize (tuple, default=(15, 7)) – Figure size
Returns:

    (plt.Figure)
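According to the example on this page, `individual_plot` visualizes the differences between `exp.X[idx]` and `exp.cfs[idx]`. The hypothetical helper `feature_changes` sketches that comparison on plain lists: it reports each feature the counterfactual changes, together with the change `cf - x`. The feature names, the tolerance `tol`, and the returned format are illustrative assumptions; the real function draws a figure rather than returning values.

```python
from typing import List, Optional, Sequence, Tuple


def feature_changes(
    x: Sequence[float],
    cf: Sequence[float],
    feature_names: Optional[Sequence[str]] = None,
    tol: float = 1e-6,
) -> List[Tuple[str, float]]:
    """Hypothetical helper: list the features a CF changes and by how much (cf - x)."""
    if len(x) != len(cf):
        raise ValueError("x and cf must have the same length")
    # Fall back to generic names when none are given.
    names = feature_names or [f"feature_{i}" for i in range(len(x))]
    # Keep only features whose value moved by more than `tol`.
    return [(n, c - v) for n, v, c in zip(names, x, cf) if abs(c - v) > tol]
```

For instance, `feature_changes([0.5, 1.0, 0.0], [0.5, 0.25, 1.0], ["age", "hours", "married"])` returns `[("hours", -0.75), ("married", 1.0)]`, since the first feature is unchanged.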

Example

We first use VanillaCF to generate an Explanation.

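from jax import random
from relax.data import TabularDataModule
from relax.utils import load_json
from relax.module import PredictiveTrainingModule
from relax.trainer import train_model
from relax.evaluate import generate_cf_explanations
from relax.methods import VanillaCF
from relax.plots import summary_plot, individual_plot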
# load configs
configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']
t_configs = dict(n_epochs=10, batch_size=256)

# load data and model
dm = TabularDataModule(data_configs)
model = PredictiveTrainingModule(m_configs)

# train predictive models
params, opt_state = train_model(model, dm, t_configs)
pred_fn = lambda x, params, prng_key: model.forward(params, prng_key, x, is_training=False)

# generate explanations
exp = generate_cf_explanations(
    VanillaCF(), dm, pred_fn, 
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0))
)
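`generate_cf_explanations` receives `pred_fn` together with `pred_fn_args`. One reasonable reading, assumed here rather than taken from the ReLax source, is that the entries of `pred_fn_args` are bound as keyword arguments, so a CF method only has to call the predictor with `x`. The sketch below shows that binding with `functools.partial`; `toy_forward` and `bind_pred_fn` are hypothetical stand-ins for `model.forward` and the library internals.

```python
import functools
from typing import Any, Callable, Dict, Sequence


def toy_forward(
    params: Dict[str, Any], prng_key: int, x: Sequence[float], is_training: bool = False
) -> float:
    """Hypothetical stand-in for model.forward: a linear score.

    prng_key and is_training are accepted but unused by this toy model.
    """
    w, b = params["w"], params["b"]
    return sum(wi * xi for wi, xi in zip(w, x)) + b


# Same shape as pred_fn in the example: data first, extra arguments after.
pred_fn = lambda x, params, prng_key: toy_forward(params, prng_key, x, is_training=False)


def bind_pred_fn(pred_fn: Callable, pred_fn_args: Dict[str, Any]) -> Callable:
    """Assumed behaviour: fix the extra arguments so callers only pass x."""
    return functools.partial(pred_fn, **pred_fn_args)


predict = bind_pred_fn(pred_fn, dict(params={"w": [1.0, -2.0], "b": 0.5}, prng_key=0))
```

Here `predict([3.0, 1.0])` evaluates `1.0 * 3.0 - 2.0 * 1.0 + 0.5 = 1.5`. Binding by keyword keeps `x` as the only positional argument, which matches the `lambda x, params, prng_key` signature used above.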

To visualize an individual explanation:

# visualize the differences between `exp.X[0]` and `exp.cfs[0]`
fig = individual_plot(exp, idx=0)

To analyze the distribution of all explanations (here using a 1% sample of the data):

fig = summary_plot(exp, sample_frac=0.01)