```python
from jax import random
from relax.data import load_data
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.methods.proto import ProtoCF
from relax.evaluate import generate_cf_explanations, benchmark_cfs
```

Proto CF
ProtoCFConfig
class relax.methods.proto.ProtoCFConfig(n_steps=1000, lr=0.01, lambda_=0.01, ae_configs={'enc_sizes': [50, 10], 'dec_sizes': [10, 50], 'dropout_rate': 0.3, 'lr': 0.03})
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
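The `n_steps`, `lr`, and `lambda_` fields suggest a gradient-based search in which `lambda_` weights a prototype term. The sketch below illustrates that general idea on a toy problem; it is not relax's implementation, and several pieces are assumptions. A logistic model stands in for `pred_fn`. A fixed linear map `A` stands in for the encoder of the autoencoder configured by `ae_configs`. The prototype is simply the mean latent code of target-class examples. The L1 distance term and its weight `beta` are hypothetical. The counterfactual minimizes the prediction loss toward the target class, plus its distance to the original input, plus `lambda_` times the squared distance between its latent code and the prototype.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Toy binary classifier f(x) = sigmoid(w @ x + b), a stand-in for pred_fn.
w = np.array([2.0, 2.0])
b = -2.0

# Hypothetical linear encoder E(x) = A @ x, a stand-in for the trained
# autoencoder's encoder; it maps 2 features to a 1-d latent code.
A = np.array([[0.5, 0.5]])


def predict(x):
    return sigmoid(w @ x + b)


def encode(x):
    return A @ x


def class_prototype(X_target):
    # Simplification: prototype = mean latent code of target-class examples.
    return np.mean([encode(x) for x in X_target], axis=0)


def proto_cf(x, proto, y_target=1.0, n_steps=1000, lr=0.01, lambda_=0.01, beta=0.1):
    """Gradient descent on
    BCE(f(x'), y_target) + beta * ||x' - x||_1 + lambda_ * ||E(x') - proto||^2.
    `beta` is a hypothetical weight, not a ProtoCFConfig field."""
    x_cf = x.copy()
    for _ in range(n_steps):
        p = predict(x_cf)
        grad_pred = (p - y_target) * w                           # gradient of BCE w.r.t. x'
        grad_dist = beta * np.sign(x_cf - x)                     # subgradient of the L1 term
        grad_proto = 2 * lambda_ * A.T @ (encode(x_cf) - proto)  # gradient of the prototype term
        x_cf = x_cf - lr * (grad_pred + grad_dist + grad_proto)
    return x_cf


# Usage: x is predicted as class 0; search for a class-1 counterfactual.
x = np.array([0.0, 0.0])
X_pos = np.array([[1.5, 1.0], [1.0, 1.5], [2.0, 2.0]])  # toy class-1 examples
proto = class_prototype(X_pos)
x_cf = proto_cf(x, proto)
print(predict(x), predict(x_cf), x_cf)
```

Pulling the counterfactual's latent code toward a class prototype is meant to keep it close to regions where real target-class data lives, rather than only just across the decision boundary; a small `lambda_` keeps that pull secondary to flipping the prediction.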
ProtoCF
class relax.methods.proto.ProtoCF(configs=None)
Base CF Explanation Module.
Load data:
```python
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
```
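Load the pretrained predictive model:

```python
# load the pretrained model parameters and its training module
params, training_module = load_pred_model('adult')
# prediction function with the signature (x, params, key)
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)
```

Define ProtoCF:

```python
protocf = ProtoCF()
```

Generate explanations: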
```python
cf_exp = generate_cf_explanations(
    protocf, dm, pred_fn=pred_fn,
    t_configs=dict(
        n_epochs=5, batch_size=128
    ),
    pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
```

Evaluate explanations:
```python
benchmark_cfs([cf_exp])
```

| dataset | method | acc | validity | proximity |
|---|---|---|---|---|
| adult | ProtoCF | 0.8241 | 0.812308 | 6.427959 |
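In the table, `acc` is presumably the predictive model's accuracy, `validity` the share of counterfactuals that change the model's predicted label, and `proximity` an average distance between each input and its counterfactual. The helpers below are a minimal sketch of such metrics under assumed definitions (a 0.5 decision threshold and L1 distance); relax's `benchmark_cfs` may compute them differently, and the toy arrays are illustrative only.

```python
import numpy as np


def accuracy(y_true, y_pred):
    # Share of inputs whose thresholded prediction matches the label.
    return float(np.mean(np.round(y_pred) == y_true))


def validity(y_pred_x, y_pred_cf):
    # Share of counterfactuals whose predicted label differs from the original's.
    return float(np.mean(np.round(y_pred_x) != np.round(y_pred_cf)))


def proximity(X, X_cf):
    # Mean L1 distance between each input and its counterfactual (assumed norm).
    return float(np.mean(np.abs(X - X_cf).sum(axis=1)))


# Usage on toy data: two inputs, their counterfactuals, and model outputs.
X = np.array([[0.0, 0.0], [1.0, 1.0]])
X_cf = np.array([[1.0, 0.5], [1.0, 0.0]])
y_true = np.array([0, 1])
y_pred_x = np.array([0.2, 0.9])   # model probabilities on X
y_pred_cf = np.array([0.7, 0.8])  # model probabilities on X_cf

print(accuracy(y_true, y_pred_x), validity(y_pred_x, y_pred_cf), proximity(X, X_cf))
```

Here only the first counterfactual flips the prediction, so validity is 0.5, and the L1 distances are 1.5 and 1.0, giving a proximity of 1.25.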