CLUE



DECODER

CLASS relax.methods.vaecf.Decoder (sizes, input_size, dropout=0.1)

Base class for Haiku modules.

A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.

Modules must be initialized inside an hk.transform call.

For example:

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w

>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)

>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> print(forward.apply(params, None, x))
2.0



ENCODER

CLASS relax.methods.vaecf.Encoder (sizes, dropout=0.1)

Base class for Haiku modules.

A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.

Modules must be initialized inside an hk.transform call.

For example:

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w

>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)

>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> print(forward.apply(params, None, x))
2.0



KL_DIVERGENCE

relax.methods.clue.kl_divergence (p, q, eps=7.62939453125e-06)
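
The signature above carries no description, so the following is only a minimal sketch of what such a function might compute, not the library's implementation. It assumes `p` and `q` are discrete probability vectors of equal length and that `eps` (whose default equals 2**-17) is added inside the logarithms for numerical stability. The summation over entries is also an assumption; the library may return pointwise terms instead.

```python
import math

def kl_divergence_sketch(p, q, eps=2 ** -17):
    """Hypothetical stand-in: KL(p || q) for discrete distributions.

    eps is added inside each log so that zero probabilities do not
    produce log(0). This mirrors the default in the signature above
    but is an assumption about how eps is used.
    """
    return sum(
        pi * (math.log(pi + eps) - math.log(qi + eps))
        for pi, qi in zip(p, q)
    )

p = [0.5, 0.5, 0.0]
q = [0.25, 0.25, 0.5]
kl_pq = kl_divergence_sketch(p, q)  # close to log(2)
kl_pp = kl_divergence_sketch(p, p)  # exactly zero for identical inputs
```

The zero entry in `p` contributes nothing to the sum, and because of `eps` its logarithm stays finite.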



VAEGAUSSCATCONFIGS

CLASS relax.methods.clue.VAEGaussCatConfigs (lr=0.001, enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], dropout_rate=0.1)

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

Parameters:
  • lr (float, default=0.001) – Learning rate.
  • enc_sizes (List[int], default=[20, 16, 14, 12]) – Sequence of Encoder layer sizes.
  • dec_sizes (List[int], default=[12, 14, 16, 20]) – Sequence of Decoder layer sizes.
  • dropout_rate (float, default=0.1) – Dropout rate.


VAEGAUSSCAT

CLASS relax.methods.clue.VAEGaussCat (m_configs=None)

Helper class that provides a standard way to create an ABC using inheritance.



CLUECONFIGS

CLASS relax.methods.clue.CLUEConfigs (enc_sizes=[20, 16, 14, 12], dec_sizes=[12, 14, 16, 20], encoded_size=5, lr=0.001, max_steps=500, step_size=0.01, vae_n_epochs=10, vae_batch_size=128, seed=0)

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

Parameters:
  • enc_sizes (List[int], default=[20, 16, 14, 12]) – Sequence of Encoder layer sizes.
  • dec_sizes (List[int], default=[12, 14, 16, 20]) – Sequence of Decoder layer sizes.
  • encoded_size (int, default=5) – Encoded size.
  • lr (float, default=0.001) – Learning rate.
  • max_steps (int, default=500) – Maximum number of steps.
  • step_size (float, default=0.01) – Step size.
  • vae_n_epochs (int, default=10) – Number of epochs for VAE training.
  • vae_batch_size (int, default=128) – Batch size for VAE training.
  • seed (int, default=0) – Seed for the random number generator.
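
The parameter list does not say how `max_steps` and `step_size` are used. As a rough illustration only, the sketch below assumes they control a plain gradient-descent search of the kind CLUE-style methods run in the VAE latent space: a fixed number of updates, each scaled by the step size. The function `latent_search_sketch` and the toy objective are hypothetical and are not part of `relax`.

```python
def latent_search_sketch(grad_fn, z0, max_steps=500, step_size=0.01):
    """Hypothetical stand-in for an iterative latent-space search.

    grad_fn returns the gradient of the objective at z. The loop runs
    exactly max_steps updates with a fixed step_size.
    """
    z = z0
    for _ in range(max_steps):
        z = z - step_size * grad_fn(z)
    return z

# Toy objective f(z) = (z - 3)^2, whose gradient is 2 * (z - 3).
z_final = latent_search_sketch(lambda z: 2.0 * (z - 3.0), z0=0.0)
```

With the default values, the toy search ends very close to the minimizer at 3. A smaller `step_size` or fewer `max_steps` would leave it further away, which is the trade-off these two settings would control under this assumption.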


CLUE

CLASS relax.methods.clue.CLUE (m_config=None)

Base CF Explanation Module.

Test

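from jax import random
from relax.data import load_data  # assumed import path; not shown in the original
from relax.methods.clue import CLUE
from relax.module import PredictiveTrainingModule, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs

data_name = 'cancer'
dm = load_data(data_name)  # optionally: load_data(data_name, data_configs=dict(sample_frac=0.1))
# load the pretrained predictive model
params, training_module = load_pred_model(data_name)
pred_fn = training_module.pred_fn
dl = dm.train_dataloader(128)
X, y = next(iter(dl))
# train the CLUE module on the data module
clue = CLUE()
clue.train(dm)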
Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 125.86batch/s, train/loss=2.1]
clue_test = CLUE()
clue_test.params = clue.params
clue_test.module = clue.module
cf_exp = generate_cf_explanations(
    clue_test, dm, pred_fn, pred_fn_args=dict(
        params=params, rng_key=random.PRNGKey(0)
    ), t_configs=dict(
        n_epochs=5, batch_size=256
    )
)
benchmark_cfs([cf_exp])
                     acc       validity  proximity
breast cancer  CLUE  0.909091  0.608392  9.972657