# `fake_explanation(n_cfs)` yields an Explanation whose `cfs` array has shape
# (xs.shape[0], n_cfs, xs.shape[-1]): one row per input, `n_cfs` CFs each.
exp = fake_explanation(n_cfs=1)
xs_shape = exp.xs.shape
assert exp.cfs.shape == (xs_shape[0], 1, xs_shape[-1])

# Indexing an Explanation by split name gives per-split data with a 'cfs' entry.
train_exp = exp['train']
val_exp = exp['val']
test_exp = exp['test']
# The train and val splits together account for every CF row...
assert jnp.concatenate(
    [train_exp['cfs'], val_exp['cfs']], axis=0
).shape == exp.cfs.shape
# ...and the test split has the same shape as the val split.
assert test_exp['cfs'].shape == val_exp['cfs'].shape

# More CFs per instance only widens the middle axis. This deliberately reuses
# `xs_shape` from the first fake, so it also checks that the fake data matches.
exp = fake_explanation(n_cfs=5)
assert exp.cfs.shape == (xs_shape[0], 5, xs_shape[-1])
Explanation
relax.explain.Explanation
class relax.explain.Explanation (cfs, pred_fn, data_module=None, xs=None, ys=None, total_time=None, cf_name='CFModule', data=None)
Generated CF explanations class. It inherits from `DataModule`.
Parameters:
- cfs (`jax.Array`) – Generated CF explanations of `xs` in `data`
- pred_fn (`typing.Callable[[jax.Array], jax.Array]`) – Predict function
- data_module (`relax.data_module.DataModule`, default=None) – Data module
- xs (`jax.Array`, default=None) – Input data
- ys (`jax.Array`, default=None) – Target data
- total_time (`float`, default=None) – Total runtime
- cf_name (`str`, default='CFModule') – CF method's name
- data (`NoneType`, default=None) – Deprecated argument; use `data_module` instead
Methods
copy ()
Return a deep copy of the explanation.
Warning: `pred_fn` is not deep-copied; the copy shares the same predictive function.
save (path)
Save the explanation to a directory.
load_from_path (path, ml_module_path=None)
Load an `Explanation` from a directory (e.g., one written by `save`).
relax.explain.fake_explanation
relax.explain.fake_explanation (n_cfs=1)
# Round-trip the explanation through disk. `ml_module_path` points at the
# saved dummy model directory used when reloading.
exp.save('tmp/exp/')
exp = Explanation.load_from_path(
    'tmp/exp/', ml_module_path='relax-assets/dummy/model/'
)

# `copy` must return a distinct object whose CFs are equal to the original's.
exp_1 = exp.copy()
assert exp_1 is not exp
assert np.array_equal(exp_1.cfs, exp.cfs)
Generate Explanations
relax.explain.prepare_rng_keys
relax.explain.prepare_rng_keys (rng_key, n_instances)
Prepare random number generator keys.
relax.explain.prepare_cf_module
relax.explain.prepare_cf_module (cf_module, data_module, pred_fn=None, train_config=None)
Prepare the CF module. It hooks up the data module and its apply functions (e.g., `apply_constraints_fn` and `compute_reg_loss_fn`) via the `init_apply_fns` method. Next, it trains the model if `cf_module` is a `ParametricCFModule`. Finally, it calls the `before_generate_cf` method.
relax.explain.prepare_pred_fn
relax.explain.prepare_pred_fn (cf_module, data, pred_fn, pred_fn_args=None)
Prepare the predictive function for the CF module. If `pred_fn` is found in `cf_module`, it is used irrespective of the `pred_fn` argument. Otherwise, if the `pred_fn` argument is provided, it is used. If `pred_fn` is not provided and `cf_module` does not have a `pred_fn`, the model is trained.
Parameters:
- cf_module (`relax.methods.base.CFModule`)
- data (`relax.data_module.DataModule`)
- pred_fn (`typing.Callable[[jax.Array, ...], jax.Array]`) – Predictive function.
- pred_fn_args (`typing.Dict`, default=None)
Returns:
(`typing.Callable[[jax.Array], jax.Array]`) – The predictive function, with signature `(x: Array) -> Array`.
relax.explain.generate_cf_explanations
relax.explain.generate_cf_explanations (cf_module, data, pred_fn=None, strategy=None, train_config=None, pred_fn_args=None, rng_key=None)
Generate CF explanations.
Parameters:
- cf_module (`relax.methods.base.CFModule`) – CF explanation module
- data (`relax.data_module.DataModule`) – Data module
- pred_fn (`typing.Callable[[jax.Array, ...], jax.Array]`, default=None) – Predictive function
- strategy (`str | relax.strategy.BaseStrategy`, default=None) – Parallelism strategy for generating CFs. Defaults to `vmap`.
- train_config (`typing.Dict[str, typing.Any]`, default=None)
- pred_fn_args (`dict`, default=None) – Auxiliary arguments for `pred_fn`
- rng_key (`jax.random.PRNGKey`, default=None) – Random number generator key
Returns:
(`relax.explain.Explanation`) – The generated counterfactual explanations.
# Generate CF explanations for the "adult" dataset with VanillaCF, using the
# predictive function of the ML module loaded for the same dataset.
dm = load_data("adult")
ml_model = load_ml_module("adult")
exps = generate_cf_explanations(
    VanillaCF(),
    dm,
    ml_model.pred_fn,
)
/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.
warnings.warn(
# Test case checking that a ParametricCFModule is not trained twice: the model
# is trained here first, so if `generate_cf_explanations` trained it again the
# generated CFs would differ from `cfs` computed below.
cfnet = CounterNet()
cfnet.train(dm, epochs=1)
cfs = jax.vmap(cfnet.generate_cf)(dm.xs)
assert cfnet.is_trained
exp = generate_cf_explanations(cfnet, dm)
# exp.cfs carries a singleton CF axis; drop it ('N 1 K -> N K') before comparing.
assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)
/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 0: 100%|██████████| 191/191 [00:08<00:00, 22.21batch/s, train/train_loss_1=0.06329722, train/train_loss_2=0.07011371, train/train_loss_3=0.101814255]
/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.
warnings.warn(
# hide
# dm = load_data("dummy")
# ml_model = load_ml_module("dummy")
# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:
# m = cf_module()
# assert m.is_trained == False
# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)
# assert m.is_trained == True
# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)