```python
candidates = cat_sample(jrand.PRNGKey(0), [2, 3], 10)
assert candidates.shape == (10, 5)

# No categorical features
candidates = cat_sample(jrand.PRNGKey(0), [], 10)
assert jnp.concatenate([jnp.ones((10, 5)), candidates], axis=1).shape == (10, 5)
```
Growing Sphere
HYPER_SPHERE_COORDINDATES
relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)
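The sketch below illustrates the idea behind this sampler under the signature above: draw `n_samples` random directions around `x` and rescale each one so its `p_norm` distance from `x` falls in `[low, high)`. The helper name `hyper_sphere_sketch` and its body are assumptions for illustration, not the library's implementation.

```python
import jax.numpy as jnp
import jax.random as jrand

def hyper_sphere_sketch(rng_key, x, n_samples, high, low, p_norm=2):
    # Illustrative sketch (assumed), mirroring the signature of
    # `hyper_sphere_coordindates` above.
    key_dir, key_dist = jrand.split(rng_key)
    # Random directions around x
    delta = jrand.normal(key_dir, shape=(n_samples, x.shape[-1]))
    # Random radii drawn uniformly from [low, high)
    dist = jrand.uniform(key_dist, shape=(n_samples,)) * (high - low) + low
    # Rescale each direction so its p-norm equals the sampled radius
    norm = jnp.linalg.norm(delta, ord=p_norm, axis=1, keepdims=True)
    return x + delta * (dist[:, None] / norm)

# Usage: 100 candidates in an L2 shell of radius [0.1, 0.5) around x
x = jnp.zeros((1, 4))
cands = hyper_sphere_sketch(jrand.PRNGKey(0), x, 100, high=0.5, low=0.1)
assert cands.shape == (100, 4)
```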
CAT_SAMPLE
relax.methods.sphere.cat_sample (rng_key, cat_array_sizes, n_samples)
SAMPLE_CATEGORICAL
relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)
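A minimal sketch of what a single categorical-column sampler could look like under this signature: draw a category index uniformly in `[0, col_size)` for each of the `n_samples` rows and one-hot encode it. `sample_categorical_sketch` is a hypothetical stand-in, not the library function.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand

def sample_categorical_sketch(rng_key, col_size, n_samples):
    # Hypothetical sketch (assumed): uniform category index per sample ...
    idx = jrand.randint(rng_key, shape=(n_samples,), minval=0, maxval=col_size)
    # ... encoded as a one-hot row of width `col_size`
    return jax.nn.one_hot(idx, col_size)

onehots = sample_categorical_sketch(jrand.PRNGKey(0), col_size=3, n_samples=10)
assert onehots.shape == (10, 3)
assert jnp.all(onehots.sum(axis=1) == 1.0)
```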
APPLY_IMMUTABLE
relax.methods.sphere.apply_immutable (x, cf, immutable_idx)
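The following sketch shows the intended effect given the signature above: the columns listed in `immutable_idx` are copied from the original instance `x` back into the counterfactual `cf`, so immutable features never change. `apply_immutable_sketch` is a hypothetical illustration, not the library code.

```python
import jax.numpy as jnp

def apply_immutable_sketch(x, cf, immutable_idx):
    # Assumed behavior: restore the immutable columns of x inside cf
    return cf.at[:, immutable_idx].set(x[:, immutable_idx])

x = jnp.array([[1.0, 2.0, 3.0]])
cf = jnp.array([[9.0, 8.0, 7.0]])
# Feature 1 is immutable, so the counterfactual keeps its original value
out = apply_immutable_sketch(x, cf, immutable_idx=[1])
assert jnp.array_equal(out, jnp.array([[9.0, 2.0, 7.0]]))
```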
GSCONFIG
CLASS relax.methods.sphere.GSConfig (seed=42, n_steps=100, n_samples=1000, step_size=0.05, p_norm=2)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
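For example, a few search hyperparameters can be overridden while the remaining fields keep the defaults shown in the signature above (a minimal sketch; only fields from that signature are used):

```python
from relax.methods.sphere import GSConfig

# Override step count, sample count, and step size; seed and p_norm
# keep their defaults (42 and 2) from the signature above.
configs = GSConfig(n_steps=50, n_samples=100, step_size=0.1)
```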
GROWINGSPHERE
CLASS relax.methods.sphere.GrowingSphere (configs=None)
Base CF Explanation Module.
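As in the Test section below, `GrowingSphere` can be constructed from a plain dict of config overrides; constructing it from an explicit `GSConfig` is presumably equivalent (a sketch, assuming the module follows the config pattern above):

```python
from relax.methods.sphere import GrowingSphere, GSConfig

# Dict form, as used in the Test section below
gs = GrowingSphere({'n_steps': 50, 'n_samples': 100, 'step_size': 0.1})
# Assumed-equivalent construction from an explicit GSConfig
gs = GrowingSphere(GSConfig(n_steps=50, n_samples=100, step_size=0.1))
```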
Test
```python
import numpy as np
from jax import random

from relax.data import load_data
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs, load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.trainer import train_model
from relax.methods.sphere import GrowingSphere

# load data (10% sample of the adult dataset)
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
# dm = load_data('adult')

# load model
params, training_module = load_pred_model('adult')

# predict function
# pred_fn = lambda x: training_module.forward(params, x, is_training=False)
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)

gs = GrowingSphere({'n_steps': 50, 'n_samples': 100, 'step_size': 0.1})
# gs.hook_data_module(dm)

cf_exp = generate_cf_explanations(
    gs, dm, pred_fn=pred_fn,
    pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)

# counterfactuals should differ across instances
assert not np.array_equal(cf_exp.cfs[0], cf_exp.cfs[1])
```

```python
benchmark_cfs([cf_exp])
```
| Dataset | Method | acc | validity | proximity |
|---|---|---|---|---|
| adult | Growing Sphere | 0.8241 | 1.0 | 6.270596 |