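from jax import random
from relax.data import load_data  # assumed import path for load_data
from relax.methods.clue import CLUE
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs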
CLUE
DECODER
CLASS relax.methods.vaecf.Decoder (sizes, input_size, dropout=0.1)
Base class for Haiku modules.
A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more "forward" methods (e.g. __call__) which apply operations combining user input and module parameters.
Modules must be initialized inside a hk.transform call.
For example:
>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w
>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)
>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> print(forward.apply(params, None, x))
2.0
ENCODER
CLASS relax.methods.vaecf.Encoder (sizes, dropout=0.1)
Base class for Haiku modules.
KL_DIVERGENCE
relax.methods.clue.kl_divergence (p, q, eps=7.62939453125e-06)
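The signature lists only the inputs. As a hedged illustration, the sketch below shows what an eps-stabilized KL divergence typically computes, assuming p and q are probability arrays over the same support. It uses NumPy in place of jax.numpy, and the name kl_divergence_sketch is hypothetical; whether relax's function returns the pointwise terms or their sum is not stated on this page, so the reduction here is an assumption. Note that the default eps=7.62939453125e-06 is exactly 2**-17.

```python
import numpy as np

# Default eps from the signature above: 7.62939453125e-06 == 2 ** -17
EPS = 2.0 ** -17

def kl_divergence_sketch(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> np.ndarray:
    """Pointwise KL terms p * (log p - log q).

    eps is added inside both logs so that zero probabilities never
    produce log(0); a zero entry in p contributes exactly 0.
    """
    return p * (np.log(p + eps) - np.log(q + eps))

p = np.array([0.5, 0.5, 0.0])
q = np.array([0.25, 0.25, 0.5])

pointwise = kl_divergence_sketch(p, q)
total = pointwise.sum()  # summing the pointwise terms gives KL(p || q)
print(round(float(total), 4))  # 0.6931, i.e. close to log(2)
```

Adding eps rather than clipping keeps the function smooth, which matters when it is differentiated inside a gradient-based search.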
VAEGAUSSCATCONFIGS
CLASS relax.methods.clue.VAEGaussCatConfigs (lr=0.001, enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], dropout_rate=0.1)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
VAEGAUSSCAT
CLASS relax.methods.clue.VAEGaussCat (m_configs=None)
Helper class that provides a standard way to create an ABC using inheritance.
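VAEGaussCat is the VAE whose latent space CLUE searches, but its internals are not documented here. Assuming it follows the standard Gaussian-latent VAE recipe, two of its core pieces are the reparameterization trick and the closed-form KL term against a standard normal prior. The NumPy sketch below shows only those two pieces (the reconstruction term is omitted), and both function names are hypothetical.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2).

    Writing the sample this way keeps z differentiable w.r.t. mu and logvar.
    """
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def gaussian_kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return -0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=-1)

rng = np.random.default_rng(0)
mu = np.ones((4, 5))       # batch of 4 codes, latent size 5 (CLUEConfigs' encoded_size default)
logvar = np.zeros((4, 5))  # unit variance
z = reparameterize(mu, logvar, rng)
kl = gaussian_kl_to_standard_normal(mu, logvar)
print(z.shape)  # (4, 5)
print(kl)       # 0.5 per dimension for mean 1, variance 1: [2.5 2.5 2.5 2.5]
```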
CLUECONFIGS
CLASS relax.methods.clue.CLUEConfigs (enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], encoded_size=5, lr=0.001, max_steps=500, step_size=0.01, vae_n_epochs=10, vae_batch_size=128, seed=0)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
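Judging by the field names, CLUEConfigs mixes VAE settings (enc_sizes, dec_sizes, encoded_size, vae_n_epochs, vae_batch_size) with search settings (max_steps, step_size). To illustrate how max_steps and step_size can drive a latent-space counterfactual search, the sketch below runs gradient descent on a latent code, decodes it, and trades off a validity term against distance to the original input. It is a simplified toy, not relax's implementation: the linear decoder, the logistic classifier, the loss terms and dist_weight are all hypothetical, and gradients are derived by hand in NumPy rather than with jax.grad. The original CLUE formulation targets predictive uncertainty, so relax's actual loss may differ.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Hypothetical stand-ins: a linear "decoder" x = W @ z and a logistic
# "classifier" p(y=1 | x) = sigmoid(w @ x).
W = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])     # maps a 2-d latent code to a 3-d input
w = np.array([1.0, 1.0, 1.0])  # classifier weights

def decode(z):
    return W @ z

def predict(x):
    return sigmoid(w @ x)

def latent_cf_search(z0, target, max_steps=500, step_size=0.01, dist_weight=0.1):
    """Gradient descent on the latent code z (defaults mirror CLUEConfigs).

    Assumed loss for this sketch:
        BCE(predict(decode(z)), target) + dist_weight * ||decode(z) - decode(z0)||^2
    """
    x0 = decode(z0)
    z = z0.copy()
    for _ in range(max_steps):
        x = decode(z)
        p = predict(x)
        grad_validity = (p - target) * (W.T @ w)        # dBCE/dz via the chain rule
        grad_dist = 2.0 * dist_weight * W.T @ (x - x0)  # d(squared distance)/dz
        z = z - step_size * (grad_validity + grad_dist)
    return decode(z)

z0 = np.array([-0.5, -0.5])
x0 = decode(z0)                           # original input, predicted class 0
x_cf = latent_cf_search(z0, target=1.0)   # search toward class 1
print(predict(x0), predict(x_cf))         # about 0.12 before, above 0.5 after
```

Searching over z rather than x keeps candidates close to what the decoder can produce, which is the motivation for training the VAE first.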
CLUE
CLASS relax.methods.clue.CLUE (m_config=None)
Base CF Explanation Module.
Test
data_name = 'cancer'
dm = load_data(data_name)  # optionally: load_data(data_name, data_configs=dict(sample_frac=0.1))
# load model
params, training_module = load_pred_model(data_name)
pred_fn = training_module.pred_fn
dl = dm.train_dataloader(128)
X, y = next(iter(dl))
clue = CLUE()
clue.train(dm)
/home/birk/mambaforge-pypy3/envs/nbdev2/lib/python3.8/site-packages/haiku/_src/base.py:515: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
param = init(shape, dtype)
/home/birk/code/ReLax/relax/_ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 125.86batch/s, train/loss=2.1]
# copy the trained state into a fresh CLUE instance
clue_test = CLUE()
clue_test.params = clue.params
clue_test.module = clue.module
cf_exp = generate_cf_explanations(
    clue_test, dm, pred_fn,
    pred_fn_args=dict(params=params),
    rng_key=random.PRNGKey(0),
    t_configs=dict(n_epochs=5, batch_size=256)
)
benchmark_cfs([cf_exp])
dataset | method | acc | validity | proximity
---|---|---|---|---
breast cancer | CLUE | 0.909091 | 0.608392 | 9.972657
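benchmark_cfs reports predictive accuracy (acc), validity and proximity. The functions below are a hedged sketch of how such metrics are commonly computed for a binary classifier, assuming validity is the fraction of counterfactuals whose predicted class differs from the original prediction and proximity is the mean L1 distance between each input and its counterfactual. relax's exact definitions (thresholds, normalization) may differ, and the arrays are made up purely to exercise the functions.

```python
import numpy as np

def accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of rows where the thresholded prediction matches the label."""
    return float(np.mean((y_pred >= threshold) == y_true))

def validity(y_pred_x, y_pred_cf, threshold=0.5):
    """Fraction of rows whose predicted class flips between input and counterfactual."""
    return float(np.mean((y_pred_x >= threshold) != (y_pred_cf >= threshold)))

def proximity(x, cf):
    """Mean over rows of the L1 distance between input and counterfactual."""
    return float(np.mean(np.sum(np.abs(x - cf), axis=1)))

x = np.array([[0.0, 1.0], [1.0, 1.0], [0.5, 0.0]])
cf = np.array([[0.5, 1.0], [1.0, 0.0], [0.5, 0.0]])
y_true = np.array([0, 1, 1])
y_pred_x = np.array([0.2, 0.9, 0.4])
y_pred_cf = np.array([0.7, 0.3, 0.45])

print(accuracy(y_true, y_pred_x))     # 2 of 3 correct
print(validity(y_pred_x, y_pred_cf))  # 2 of 3 flipped
print(proximity(x, cf))               # (0.5 + 1.0 + 0.0) / 3 = 0.5
```

Under these definitions higher validity is better and lower proximity is better, so the two are usually read together.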