exp = fake_explanation(n_cfs=1)
xs_shape = exp.xs.shape
assert exp.cfs.shape == (xs_shape[0], 1, xs_shape[-1])
train_exp = exp['train']
val_exp = exp['val']
test_exp = exp['test']
assert jnp.concatenate(
    [train_exp['cfs'], val_exp['cfs']], axis=0
).shape == exp.cfs.shape
assert test_exp['cfs'].shape == val_exp['cfs'].shape
exp = fake_explanation(n_cfs=5)
assert exp.cfs.shape == (xs_shape[0], 5, xs_shape[-1])
Explanation
relax.explain.Explanation
class relax.explain.Explanation (cfs, pred_fn, data_module=None, xs=None, ys=None, total_time=None, cf_name='CFModule', data=None)
Class holding the generated CF explanations. It inherits from DataModule.
Parameters:
- cfs (<class 'jax.Array'>) – Generated cf explanation of xs in data
- pred_fn (typing.Callable[[jax.Array], jax.Array]) – Predict function
- data_module (<class 'relax.data_module.DataModule'>, default=None) – Data module
- xs (<class 'jax.Array'>, default=None) – Input data
- ys (<class 'jax.Array'>, default=None) – Target data
- total_time (<class 'float'>, default=None) – Total runtime
- cf_name (<class 'str'>, default=CFModule) – CF method’s name
- data (<class 'NoneType'>, default=None) – Deprecated argument
Methods
copy ()
Return a deep copy of the explanation.
Warning: this method will not create a deepcopy of pred_fn.
save (path)
Save the explanation to a directory.
load_from_path (path, ml_module_path=None)
Load an explanation from a directory.
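The warning on copy () is easier to see with a small example. The sketch below is not the relax implementation: `TinyExplanation` is a hypothetical stand-in, and it assumes the copy is built on Python's `copy.deepcopy`, which treats functions as atomic and returns the same object.

```python
import copy


class TinyExplanation:
    """Hypothetical stand-in for an explanation object (not the relax class)."""

    def __init__(self, cfs, pred_fn):
        self.cfs = cfs          # mutable payload (nested lists here)
        self.pred_fn = pred_fn  # plain Python callable


def pred_fn(x):
    return [v > 0 for v in x]


exp = TinyExplanation(cfs=[[0.1, 0.2], [0.3, 0.4]], pred_fn=pred_fn)
exp_copy = copy.deepcopy(exp)

# The payload is duplicated: mutating the copy leaves the original untouched.
exp_copy.cfs[0][0] = 9.9
original_unchanged = exp.cfs[0][0] == 0.1

# Functions are not duplicated by deepcopy: both objects share one callable.
shares_pred_fn = exp_copy.pred_fn is exp.pred_fn
```

Under that assumption, any state captured by the predict function is shared between the original and the copy.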
relax.explain.fake_explanation
relax.explain.fake_explanation (n_cfs=1)
exp.save('tmp/exp/')
exp = Explanation.load_from_path('tmp/exp/',
    ml_module_path='relax-assets/dummy/model/')
exp_1 = exp.copy()
assert exp_1 is not exp
assert np.array_equal(exp_1.cfs, exp.cfs)
Generate Explanations
relax.explain.prepare_rng_keys
relax.explain.prepare_rng_keys (rng_key, n_instances)
Prepare random number generator keys.
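As an illustration of what per-instance keys look like, here is a minimal sketch using plain JAX. It assumes that preparing keys amounts to splitting one key into `n_instances` subkeys with `jax.random.split`. The source does not confirm that this is how prepare_rng_keys works internally.

```python
import jax

n_instances = 4
rng_key = jax.random.PRNGKey(0)

# One subkey per instance; the leading axis has length n_instances.
keys = jax.random.split(rng_key, n_instances)

# Mapping over the keys lets each instance draw its own independent noise.
noise = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
```

Because the subkeys form a batched array, they can be passed through `jax.vmap` alongside the inputs.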
relax.explain.prepare_cf_module
relax.explain.prepare_cf_module (cf_module, data_module, pred_fn=None, train_config=None)
Prepare the CF module. This hooks up the data module and its apply functions (e.g., apply_constraints_fn and compute_reg_loss_fn) via the init_apply_fns method. Next, it trains the model if cf_module is a ParametricCFModule. Finally, it calls the before_generate_cf method.
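The order of these steps can be sketched with stub classes that only record which hooks were called. Everything below is hypothetical scaffolding rather than relax code: `set_data_module` is an assumed method name, and skipping training when `is_trained` is already set is an assumption based on the "trained twice" test further down this page.

```python
class CFModuleStub:
    """Hypothetical non-parametric CF module that records hook calls."""

    def __init__(self):
        self.calls = []

    def set_data_module(self, data_module):  # assumed method name
        self.calls.append("set_data_module")

    def init_apply_fns(self, **apply_fns):
        self.calls.append("init_apply_fns")

    def before_generate_cf(self):
        self.calls.append("before_generate_cf")


class ParametricStub(CFModuleStub):
    """Hypothetical parametric CF module that needs training."""

    def __init__(self, is_trained=False):
        super().__init__()
        self.is_trained = is_trained

    def train(self, data_module, pred_fn=None, **train_config):
        self.calls.append("train")
        self.is_trained = True


def prepare_sketch(cf_module, data_module, pred_fn=None, train_config=None):
    # 1. Hook up the data module and its apply functions.
    cf_module.set_data_module(data_module)
    cf_module.init_apply_fns(apply_constraints_fn=None, compute_reg_loss_fn=None)
    # 2. Train parametric modules only (assumed: skip if already trained).
    if isinstance(cf_module, ParametricStub) and not cf_module.is_trained:
        cf_module.train(data_module, pred_fn=pred_fn, **(train_config or {}))
    # 3. Final hook before CF generation starts.
    cf_module.before_generate_cf()
    return cf_module


fresh = prepare_sketch(ParametricStub(), data_module=object())
pretrained = prepare_sketch(ParametricStub(is_trained=True), data_module=object())
plain = prepare_sketch(CFModuleStub(), data_module=object())
```

In this sketch only `fresh` records a `train` call. The other two go straight from `init_apply_fns` to `before_generate_cf`.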
relax.explain.prepare_pred_fn
relax.explain.prepare_pred_fn (cf_module, data, pred_fn, pred_fn_args=None)
Prepare the predictive function for the CF module. If cf_module already has a pred_fn, it is used regardless of the pred_fn argument. Otherwise, the pred_fn argument is used if it is provided. If neither is available, a model is trained.
Parameters:
- cf_module (<class 'relax.methods.base.CFModule'>)
- data (<class 'relax.data_module.DataModule'>)
- pred_fn (typing.Callable[[jax.Array, ...], jax.Array]) – Predictive function.
- pred_fn_args (typing.Dict, default=None)
Returns:
(typing.Callable[[jax.Array], jax.Array]) – Return predictive function with signature (x: Array) -> Array.
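The precedence described for prepare_pred_fn can be sketched in plain Python. `resolve_pred_fn` and `train_fn` are hypothetical names introduced for illustration. Binding `pred_fn_args` as keyword arguments with `functools.partial` is an assumption about how the single-argument signature is obtained.

```python
from functools import partial
from types import SimpleNamespace


def resolve_pred_fn(cf_module, pred_fn=None, pred_fn_args=None, train_fn=None):
    """Hypothetical helper mirroring the precedence described above."""
    # 1. A pred_fn already attached to the CF module always wins.
    module_fn = getattr(cf_module, "pred_fn", None)
    if module_fn is not None:
        return module_fn
    # 2. Otherwise use the argument, binding auxiliary keyword arguments so
    #    the result can be called as fn(x). Keyword binding is an assumption.
    if pred_fn is not None:
        return partial(pred_fn, **(pred_fn_args or {}))
    # 3. Nothing available: fall back to training a model (stubbed here).
    if train_fn is None:
        raise ValueError("no pred_fn available and no training fallback given")
    return train_fn()


def scaled(x, factor=1.0):
    return [factor * v for v in x]


# Module without its own pred_fn: the argument is used, with factor bound.
bare_module = SimpleNamespace(pred_fn=None)
fn = resolve_pred_fn(bare_module, pred_fn=scaled, pred_fn_args={"factor": 2.0})
doubled = fn([1.0, 2.0])

# Module with its own pred_fn: the argument is ignored.
module_with_fn = SimpleNamespace(pred_fn=lambda x: x)
own = resolve_pred_fn(module_with_fn, pred_fn=scaled)
```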
relax.explain.generate_cf_explanations
relax.explain.generate_cf_explanations (cf_module, data, pred_fn=None, strategy=None, train_config=None, pred_fn_args=None, rng_key=None)
Generate CF explanations.
Parameters:
- cf_module (<class 'relax.methods.base.CFModule'>) – CF Explanation Module
- data (<class 'relax.data_module.DataModule'>) – Data Module
- pred_fn (typing.Callable[[jax.Array, ...], jax.Array], default=None) – Predictive function
- strategy (str | relax.strategy.BaseStrategy, default=None) – Parallelism strategy for generating CFs. Defaults to vmap.
- train_config (typing.Dict[str, typing.Any], default=None)
- pred_fn_args (<class 'dict'>, default=None) – Auxiliary arguments for pred_fn
- rng_key (jax.random.PRNGKey, default=None) – Random number generator key
Returns:
(<class '__main__.Explanation'>) – Return counterfactual explanations.
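To make the default vmap strategy and the resulting cfs layout concrete, the sketch below maps a toy per-instance generator over a batch with `jax.vmap` and then adds the n_cfs axis. `toy_generate_cf` is a hypothetical stand-in for a CF module's generate_cf. The (N, n_cfs, K) layout follows the shape assertions elsewhere on this page.

```python
import jax
import jax.numpy as jnp


def toy_generate_cf(x):
    # Hypothetical per-instance generator: maps one input of shape (K,)
    # to one counterfactual of shape (K,).
    return -x


xs = jnp.arange(12.0).reshape(4, 3)        # N=4 instances, K=3 features

# vmap applies the per-instance function across the leading (batch) axis.
cfs_flat = jax.vmap(toy_generate_cf)(xs)   # shape (N, K)

# With a single counterfactual per instance, insert the n_cfs axis.
cfs = cfs_flat[:, None, :]                 # shape (N, 1, K)
```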
dm = load_data("adult")
ml_model = load_ml_module("adult")
exps = generate_cf_explanations(
    VanillaCF(),
    dm, ml_model.pred_fn,
)
/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.
warnings.warn(
cfnet = CounterNet()
cfnet.train(dm, epochs=1)
# Test cases for checking if ParametricCFModule is trained twice.
# If it is trained twice, cfs will be different.
cfs = jax.vmap(cfnet.generate_cf)(dm.xs)
assert cfnet.is_trained == True
exp = generate_cf_explanations(cfnet, dm)
assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)
/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 0: 100%|██████████| 191/191 [00:08<00:00, 22.21batch/s, train/train_loss_1=0.06329722, train/train_loss_2=0.07011371, train/train_loss_3=0.101814255]
/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.
warnings.warn(
# hide
# dm = load_data("dummy")
# ml_model = load_ml_module("dummy")
# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:
#     m = cf_module()
#     assert m.is_trained == False
#     m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)
#     assert m.is_trained == True
#     exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)