Explanation

relax.explain.Explanation


class relax.explain.Explanation (cfs, pred_fn, data_module=None, xs=None, ys=None, total_time=None, cf_name='CFModule', data=None)

Class holding generated counterfactual (CF) explanations together with the data they explain (a DataModule, or xs and ys).

Parameters:

  • cfs (<class 'jax.Array'>) – Generated CF explanations of xs
  • pred_fn (typing.Callable[[jax.Array], jax.Array]) – Predictive function
  • data_module (<class 'relax.data_module.DataModule'>, default=None) – Data module
  • xs (<class 'jax.Array'>, default=None) – Input data
  • ys (<class 'jax.Array'>, default=None) – Target data
  • total_time (<class 'float'>, default=None) – Total runtime
  • cf_name (<class 'str'>, default=CFModule) – Name of the CF method
  • data (<class 'NoneType'>, default=None) – Deprecated; use data_module instead

Methods


copy ()

Return a deep copy of the explanation.

Warning: this method does not deep-copy pred_fn.
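The warning above can be illustrated with a minimal sketch, assuming a copy that deep-copies the array state but keeps the very same pred_fn object. ToyExplanation and Model are hypothetical stand-ins, not relax classes, and relax's own copy may be implemented differently.

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class ToyExplanation:
    # Hypothetical stand-in for relax.explain.Explanation: array state plus a predictive function.
    cfs: np.ndarray
    pred_fn: Callable

    def copy(self) -> "ToyExplanation":
        # Pre-seed deepcopy's memo so pred_fn maps to itself: it is shared, not duplicated.
        memo = {id(self.pred_fn): self.pred_fn}
        return deepcopy(self, memo)


class Model:
    # Hypothetical model whose parameters live on the instance.
    def __init__(self):
        self.weights = np.ones(3)

    def __call__(self, x):
        return x @ self.weights


exp = ToyExplanation(cfs=np.zeros((2, 1, 3)), pred_fn=Model())
exp_copy = exp.copy()
assert exp_copy is not exp
assert exp_copy.cfs is not exp.cfs        # arrays are duplicated
assert exp_copy.pred_fn is exp.pred_fn    # pred_fn is the same object
```

Because pred_fn is shared, any mutable state it carries (such as model parameters) is shared between the original and the copy.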


save (path)

Save the explanation to a directory.


load_from_path (path, ml_module_path=None)

Load an Explanation from a directory.

relax.explain.fake_explanation


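relax.explain.fake_explanation (n_cfs=1)

Create a fake Explanation for testing. Its cfs array has shape (N, n_cfs, K): n_cfs counterfactuals for each of the N instances, each with K features.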

import jax.numpy as jnp
import numpy as np
from relax.explain import Explanation, fake_explanation
exp = fake_explanation(n_cfs=1)
xs_shape = exp.xs.shape
assert exp.cfs.shape == (xs_shape[0], 1, xs_shape[-1])
train_exp = exp['train']
val_exp = exp['val']
test_exp = exp['test']
assert jnp.concatenate(
    [train_exp['cfs'], val_exp['cfs']], axis=0
).shape == exp.cfs.shape
assert test_exp['cfs'].shape == val_exp['cfs'].shape

exp = fake_explanation(n_cfs=5)
assert exp.cfs.shape == (xs_shape[0], 5, xs_shape[-1])
exp.save('tmp/exp/')
exp = Explanation.load_from_path('tmp/exp/', 
    ml_module_path='relax-assets/dummy/model/')
exp_1 = exp.copy()
assert exp_1 is not exp
assert np.array_equal(exp_1.cfs, exp.cfs)

Generate Explanations

relax.explain.prepare_rng_keys


relax.explain.prepare_rng_keys (rng_key, n_instances)

Prepare random number generator (RNG) keys for n_instances instances.
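A minimal sketch of per-instance key preparation, assuming it amounts to deriving one key per instance with jax.random.split. split_rng_keys and its fallback seed of 0 are hypothetical, not relax's documented behaviour.

```python
import jax
import jax.numpy as jnp


def split_rng_keys(rng_key, n_instances: int):
    """Hypothetical helper: derive one independent PRNG key per instance."""
    if rng_key is None:
        # Assumption: fall back to a fixed seed when no key is supplied.
        rng_key = jax.random.PRNGKey(0)
    return jax.random.split(rng_key, n_instances)


keys = split_rng_keys(jax.random.PRNGKey(42), n_instances=4)

# Usage: each row gets its own key, e.g. to draw independent noise under vmap.
xs = jnp.zeros((4, 3))
noise = jax.vmap(lambda k, x: x + jax.random.normal(k, x.shape))(keys, xs)
```

Splitting once up front keeps each instance's randomness independent while the whole batch stays reproducible from a single key.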

relax.explain.prepare_cf_module


relax.explain.prepare_cf_module (cf_module, data_module, pred_fn=None, train_config=None)

Prepare the CF module. First, it hooks up the data module and its apply functions (e.g., apply_constraints_fn and compute_reg_loss_fn) via the init_apply_fns method. Next, it trains the model if cf_module is a ParametricCFModule that has not been trained yet. Finally, it calls the before_generate_cf method.
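The three steps can be sketched with stub classes. StubCFModule, StubParametricCFModule and prepare_cf_module_sketch are hypothetical, and the hook signatures are assumptions; the sketch only shows the order of the calls and that an already-trained parametric module is not retrained.

```python
class StubCFModule:
    """Hypothetical stand-in for a CF module; records which hooks were called."""

    def __init__(self):
        self.calls = []

    def init_apply_fns(self, data_module):
        # In relax this step wires up functions such as apply_constraints_fn and compute_reg_loss_fn.
        self.calls.append("init_apply_fns")

    def before_generate_cf(self):
        self.calls.append("before_generate_cf")


class StubParametricCFModule(StubCFModule):
    """Hypothetical stand-in for a ParametricCFModule that must be trained before use."""

    def __init__(self):
        super().__init__()
        self.is_trained = False

    def train(self, data_module, pred_fn=None, **train_config):
        self.calls.append("train")
        self.is_trained = True


def prepare_cf_module_sketch(cf_module, data_module, pred_fn=None, train_config=None):
    # Step 1: hook up the data module and its apply functions.
    cf_module.init_apply_fns(data_module)
    # Step 2: train parametric modules, skipping ones that are already trained.
    if isinstance(cf_module, StubParametricCFModule) and not cf_module.is_trained:
        cf_module.train(data_module, pred_fn=pred_fn, **(train_config or {}))
    # Step 3: final hook before CF generation.
    cf_module.before_generate_cf()
    return cf_module


m = prepare_cf_module_sketch(StubParametricCFModule(), data_module=None)
print(m.calls)  # ['init_apply_fns', 'train', 'before_generate_cf']
```

Calling the sketch a second time on the same module appends the two hooks again but no further "train" entry.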

relax.explain.prepare_pred_fn


relax.explain.prepare_pred_fn (cf_module, data, pred_fn, pred_fn_args=None)

Prepare the predictive function for the CF module. If cf_module already has a pred_fn, it is used regardless of the pred_fn argument. Otherwise, the pred_fn argument is used if it is provided. If neither is available, a model is trained.

Parameters:

  • cf_module (<class 'relax.methods.base.CFModule'>) – CF explanation module
  • data (<class 'relax.data_module.DataModule'>) – Data module
  • pred_fn (typing.Callable[[jax.Array, ...], jax.Array]) – Predictive function
  • pred_fn_args (typing.Dict, default=None) – Auxiliary arguments for pred_fn

Returns:

    (typing.Callable[[jax.Array], jax.Array]) – The predictive function, with signature (x: Array) -> Array.
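A minimal sketch of the precedence described above, assuming pred_fn_args are bound as keyword arguments (here with functools.partial) so the result has the (x) -> Array signature. resolve_pred_fn_sketch and its train_model callback are hypothetical, not relax's implementation.

```python
import functools
from typing import Any, Callable, Dict, Optional

import jax.numpy as jnp


def resolve_pred_fn_sketch(
    cf_module: Any,
    pred_fn: Optional[Callable] = None,
    pred_fn_args: Optional[Dict[str, Any]] = None,
    train_model: Optional[Callable[[], Callable]] = None,
) -> Callable:
    """Hypothetical sketch of the pred_fn precedence rules."""
    module_pred_fn = getattr(cf_module, "pred_fn", None)
    if module_pred_fn is not None:
        # 1. A pred_fn stored on the CF module wins, whatever pred_fn argument was passed.
        return module_pred_fn
    if pred_fn is None:
        # 2. No predictive function anywhere: train one (train_model is a hypothetical callback).
        if train_model is None:
            raise ValueError("No pred_fn is available and no training callback was given.")
        return train_model()
    # 3. Use the supplied pred_fn, binding the auxiliary arguments so the result maps x -> y.
    return functools.partial(pred_fn, **(pred_fn_args or {}))


def linear_pred_fn(x, w, b=0.0):
    # Toy predictive function with auxiliary arguments w and b.
    return x @ w + b


class ModuleWithoutPredFn:
    pass


fn = resolve_pred_fn_sketch(ModuleWithoutPredFn(), linear_pred_fn, {"w": jnp.ones(3), "b": 1.0})
print(float(fn(jnp.ones(3))))  # 4.0
```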

relax.explain.generate_cf_explanations


relax.explain.generate_cf_explanations (cf_module, data, pred_fn=None, strategy=None, train_config=None, pred_fn_args=None, rng_key=None)

Generate CF explanations.

Parameters:

  • cf_module (<class 'relax.methods.base.CFModule'>) – CF Explanation Module
  • data (<class 'relax.data_module.DataModule'>) – Data Module
  • pred_fn (typing.Callable[[jax.Array, ...], jax.Array], default=None) – Predictive function
  • strategy (str | relax.strategy.BaseStrategy, default=None) – Parallelism strategy for generating CFs. Defaults to vmap.
  • train_config (typing.Dict[str, typing.Any], default=None) – Training configuration, used when a ParametricCFModule is trained
  • pred_fn_args (<class 'dict'>, default=None) – Auxiliary arguments for pred_fn
  • rng_key (jax.random.PRNGKey, default=None) – Random number generator key

Returns:

    (<class 'relax.explain.Explanation'>) – Counterfactual explanations.
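A minimal sketch of what the default vmap strategy amounts to, assuming each instance gets its own PRNG key and one CF per instance. toy_generate_cf and generate_cfs_vmap_sketch are hypothetical; relax's strategy classes are not used here.

```python
import jax
import jax.numpy as jnp


def toy_generate_cf(x, rng_key):
    # Hypothetical per-instance CF generator: perturb x with small Gaussian noise.
    return x + 0.1 * jax.random.normal(rng_key, x.shape)


def generate_cfs_vmap_sketch(generate_cf, xs, rng_key):
    # One PRNG key per instance, then map generate_cf over the batch axis.
    keys = jax.random.split(rng_key, xs.shape[0])
    cfs = jax.vmap(generate_cf)(xs, keys)  # shape (N, K)
    # Add a CF axis so the result follows the (N, n_cfs, K) layout with n_cfs = 1.
    return cfs[:, None, :]


xs = jnp.zeros((5, 4))
cfs = generate_cfs_vmap_sketch(toy_generate_cf, xs, jax.random.PRNGKey(0))
print(cfs.shape)  # (5, 1, 4)
```

Indexing cfs[:, 0, :] drops the CF axis again, which is what einops.rearrange(exp.cfs, 'N 1 K -> N K') does in the CounterNet test on this page.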

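from relax.data_module import load_data
from relax.explain import generate_cf_explanations
from relax.methods import VanillaCF
from relax.ml_model import load_ml_module
dm = load_data("adult")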
ml_model = load_ml_module("adult")
exps = generate_cf_explanations(
    VanillaCF(),
    dm, ml_model.pred_fn,
)
import einops
import jax
import numpy as np
from relax.methods import CounterNet
cfnet = CounterNet()
cfnet.train(dm, epochs=1)
# Check that a ParametricCFModule is not trained twice:
# if generate_cf_explanations retrained it, the CFs would differ.
cfs = jax.vmap(cfnet.generate_cf)(dm.xs)
assert cfnet.is_trained
exp = generate_cf_explanations(cfnet, dm)
assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)