# Two categorical features with 2 and 3 categories, 10 samples requested.
candidates = cat_sample(jrand.PRNGKey(0), [2, 3], 10)
# One row per sample; the column count is the sum of the category sizes (2 + 3).
assert candidates.shape == (10, 5)

# No categorical features
candidates = cat_sample(jrand.PRNGKey(0), [], 10)
# Concatenating the result onto a (10, 5) array must leave its shape unchanged,
# i.e. the empty case contributes no columns.
assert jnp.concatenate([jnp.ones((10, 5)), candidates], axis=1).shape == (10, 5)

# Growing Sphere
HYPER_SPHERE_COORDINDATES
relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)
CAT_SAMPLE
relax.methods.sphere.cat_sample (rng_key, cat_array_sizes, n_samples)
SAMPLE_CATEGORICAL
relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)
APPLY_IMMUTABLE
relax.methods.sphere.apply_immutable (x, cf, immutable_idx)
GSCONFIG
CLASS relax.methods.sphere.GSConfig (seed=42, n_steps=100, n_samples=1000, step_size=0.05, p_norm=2)
Configuration for `GrowingSphere`. Create a new model by parsing and validating input data from keyword arguments.
Raises `ValidationError` if the input data cannot be parsed to form a valid model.
GROWINGSPHERE
CLASS relax.methods.sphere.GrowingSphere (configs=None)
Base CF (counterfactual) Explanation Module.
Test
from relax.data import load_data
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.trainer import train_model

# sample_frac=0.1 presumably loads only a fraction of the data to keep the test fast;
# the commented line below loads the dataset without that option.
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
# dm = load_data('adult',)

# load model
params, training_module = load_pred_model('adult')


# predict function
# pred_fn = lambda x: training_module.forward(params, x, is_training=False)
def pred_fn(x, params, key):
    """Return the loaded model's forward pass on ``x`` with ``is_training=False``.

    ``params`` and ``key`` are supplied through ``pred_fn_args`` below.
    """
    return training_module.forward(params, key, x, is_training=False)


gs = GrowingSphere({'n_steps': 50, 'n_samples': 100, 'step_size': 0.1})
# gs.hook_data_module(dm)

cf_exp = generate_cf_explanations(
    gs, dm, pred_fn=pred_fn,
    pred_fn_args=dict(
        params=params, key=jrand.PRNGKey(0)
    ),
)
# The first two counterfactuals must differ.
assert not np.array_equal(cf_exp.cfs[0], cf_exp.cfs[1])

benchmark_cfs([cf_exp])
# Recorded output of benchmark_cfs from the original run:
# |       |                | acc    | validity | proximity |
# |-------|----------------|--------|----------|-----------|
# | adult | Growing Sphere | 0.8241 | 1.0      | 6.270596  |