Growing Sphere



HYPER_SPHERE_COORDINDATES

relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)

Parameters:
  • rng_key (jrand.PRNGKey) – Random number generator key
  • x (Array) – Input instance with only continuous features. Shape: (1, n_features)
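  • n_samples (int) – Number of candidate points to sample
  • high (float) – Upper bound of the sampling radius around x
  • low (float) – Lower bound of the sampling radius around x
  • p_norm (int, default=2) – Order of the norm used to measure distance from x (2 is the Euclidean norm)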
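A minimal sketch of this sampling step, assuming (as in the Growing Sphere method) that candidates are drawn from the spherical layer whose p-norm radius around x lies between low and high. The helper name sample_sphere_layer is hypothetical, and the choice of Gaussian directions with a uniformly drawn radius is an assumption, not necessarily what hyper_sphere_coordindates does internally.

```python
import jax.numpy as jnp
import jax.random as jrand


def sample_sphere_layer(rng_key, x, n_samples, high, low, p_norm=2):
    """Hypothetical helper: n_samples points whose p-norm distance to x lies in [low, high)."""
    key_dir, key_rad = jrand.split(rng_key)
    # Random directions from Gaussian noise, one row per sample.
    direction = jrand.normal(key_dir, shape=(n_samples, x.shape[-1]))
    # Rescale every direction to unit p-norm.
    direction = direction / jnp.linalg.norm(direction, ord=p_norm, axis=1, keepdims=True)
    # Radius drawn uniformly from [low, high), one per sample.
    radius = jrand.uniform(key_rad, shape=(n_samples, 1), minval=low, maxval=high)
    # x has shape (1, n_features) and broadcasts against (n_samples, n_features).
    return x + radius * direction
```

Because each direction has unit p-norm, the distance of every candidate from x equals its sampled radius, so all candidates land inside the layer regardless of p_norm.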


CAT_SAMPLE

relax.methods.sphere.cat_sample (rng_key, cat_array_sizes, n_samples)

Parameters:
  • rng_key (jrand.PRNGKey) – Random number generator key
  • cat_array_sizes (List[int]) – A list of the number of categories for each categorical feature
  • n_samples (int) – Number of samples to draw
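A minimal sketch of categorical sampling, assuming each categorical feature is encoded as a one-hot block and that a category is drawn uniformly for each sample. The name sample_one_hot is hypothetical. The output width is the sum of cat_array_sizes, matching the (10, 5) shape asserted for [2, 3] in the cat_sample example.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def sample_one_hot(rng_key, cat_array_sizes, n_samples):
    """Hypothetical sketch: one uniformly drawn one-hot block per categorical feature."""
    blocks = []
    for size in cat_array_sizes:
        rng_key, subkey = jrand.split(rng_key)
        # Category index in [0, size) for every sample.
        idx = jrand.randint(subkey, shape=(n_samples,), minval=0, maxval=size)
        # One-hot row of length `size` per sample.
        blocks.append(jax.nn.one_hot(idx, size))
    if not blocks:
        # No categorical features: an (n_samples, 0) array concatenates as a no-op.
        return jnp.zeros((n_samples, 0))
    return jnp.concatenate(blocks, axis=1)
```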


SAMPLE_CATEGORICAL

relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)

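import jax.numpy as jnp
import jax.random as jrand
from relax.methods.sphere import cat_sample

# Two categorical features with 2 and 3 categories give 2 + 3 = 5 columns
candidates = cat_sample(jrand.PRNGKey(0), [2, 3], 10)
assert candidates.shape == (10, 5)
# No categorical features: the result has zero columns, so concatenating it changes nothing
candidates = cat_sample(jrand.PRNGKey(0), [], 10)
assert jnp.concatenate([jnp.ones((10, 5)), candidates], axis=1).shape == (10, 5)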


APPLY_IMMUTABLE

relax.methods.sphere.apply_immutable (x, cf, immutable_idx)
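apply_immutable has no description here. Judging from its signature, it presumably resets the immutable features of the counterfactuals cf to their values in x. A minimal sketch under that assumption (keep_immutable is a hypothetical name, and immutable_idx is assumed to be a list of column indices):

```python
import jax.numpy as jnp


def keep_immutable(x, cf, immutable_idx):
    """Hypothetical sketch: overwrite immutable features of every row of cf with x's values."""
    if immutable_idx is None or len(immutable_idx) == 0:
        return cf
    idx = jnp.asarray(immutable_idx)
    # x has shape (1, n_features); broadcast its immutable values to all rows of cf.
    fixed = jnp.broadcast_to(x[:, idx], (cf.shape[0], idx.shape[0]))
    return cf.at[:, idx].set(fixed)
```

JAX arrays are immutable, so the sketch uses the functional .at[...].set(...) update, which returns a new array instead of modifying cf in place.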



GSCONFIG

CLASS relax.methods.sphere.GSConfig (seed=42, n_steps=100, n_samples=1000, step_size=0.05, p_norm=2)

Configuration for GrowingSphere. Fields are parsed and validated from keyword arguments.

Raises ValidationError if the input cannot be parsed into a valid configuration.



GROWINGSPHERE

CLASS relax.methods.sphere.GrowingSphere (configs=None)

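Growing Sphere counterfactual explanation module, built on the base CF explanation module. configs takes the GSConfig fields, for example as a dict such as {'n_steps': 50}.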
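A simplified sketch of the growing-sphere search that the GSConfig fields suggest: sample n_samples points in a layer of width step_size around x, check whether any of them flips the predicted label, and otherwise grow the layer outward, for up to n_steps layers. Assumptions: continuous features only, a binary classifier pred_fn that returns probabilities, and early stopping at the first layer containing a flipped candidate. growing_sphere_sketch is a hypothetical name and is not the library's implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def growing_sphere_sketch(rng_key, x, pred_fn, n_steps=100, n_samples=1000,
                          step_size=0.05, p_norm=2):
    """Hypothetical sketch of a Growing Sphere search (continuous features only).

    x: input of shape (1, n_features).
    pred_fn: maps (n, n_features) to (n,) probabilities of the positive class.
    Returns a (1, n_features) counterfactual, or x if none is found.
    """
    y_x = jnp.round(pred_fn(x))  # predicted label of the input, shape (1,)
    low, high = 0.0, step_size
    for _ in range(n_steps):
        rng_key, key_dir, key_rad = jrand.split(rng_key, 3)
        # Sample candidates in the layer low <= ||c - x||_p < high.
        delta = jrand.normal(key_dir, (n_samples, x.shape[-1]))
        delta = delta / jnp.linalg.norm(delta, ord=p_norm, axis=1, keepdims=True)
        radius = jrand.uniform(key_rad, (n_samples, 1), minval=low, maxval=high)
        candidates = x + radius * delta
        # A candidate is valid if its predicted label differs from the input's.
        flipped = jnp.round(pred_fn(candidates)) != y_x
        if bool(jnp.any(flipped)):
            dist = jnp.linalg.norm(candidates - x, ord=p_norm, axis=1)
            dist = jnp.where(flipped, dist, jnp.inf)
            best = int(jnp.argmin(dist))
            return candidates[best:best + 1]
        # No valid candidate yet: grow the sphere by one layer.
        low, high = high, high + step_size
    return x


# Toy classifier: predicts class 1 when the two features sum to more than 1.
pred_fn = lambda z: jax.nn.sigmoid(z.sum(axis=1) - 1.0)
x = jnp.zeros((1, 2))
cf = growing_sphere_sketch(jrand.PRNGKey(0), x, pred_fn, n_steps=40, n_samples=500)
```

Stopping at the first layer that contains a valid candidate is safe here: the layers are disjoint and grow outward, so no later layer can hold a closer counterfactual. For the toy classifier, the returned point should lie just beyond the decision boundary x1 + x2 = 1, at a Euclidean distance of roughly 1/sqrt(2) from the origin.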

Test

import numpy as np
from jax import random
from relax.data import load_data
from relax.module import load_pred_model
from relax.evaluate import generate_cf_explanations, benchmark_cfs
from relax.methods.sphere import GrowingSphere

# Load a 10% sample of the adult dataset
dm = load_data('adult', data_configs=dict(sample_frac=0.1))
# Load the pretrained predictive model for adult
params, training_module = load_pred_model('adult')

# Prediction function: forward pass in inference mode
pred_fn = lambda x, params, key: training_module.forward(
    params, key, x, is_training=False
)
gs = GrowingSphere({'n_steps': 50, 'n_samples': 100, 'step_size': 0.1})
cf_exp = generate_cf_explanations(
    gs, dm, pred_fn=pred_fn,
    pred_fn_args=dict(
        params=params, key=random.PRNGKey(0)
    )
)
# Counterfactuals generated for different inputs should differ
assert not np.array_equal(cf_exp.cfs[0], cf_exp.cfs[1])
benchmark_cfs([cf_exp])
                       acc     validity  proximity
adult  Growing Sphere  0.8241  1.0       6.270596
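The benchmark output reports validity and proximity. A minimal sketch of common definitions of these two metrics, assuming validity is the fraction of counterfactuals whose rounded prediction differs from the input's, and proximity is the mean L1 distance between inputs and their counterfactuals. benchmark_cfs may define them differently, and the helper names here are hypothetical.

```python
import jax.numpy as jnp


def validity(y_pred_x, y_pred_cf):
    """Fraction of counterfactuals whose rounded prediction differs from the input's."""
    return jnp.mean(jnp.round(y_pred_x) != jnp.round(y_pred_cf))


def proximity(xs, cfs):
    """Mean L1 distance between each input row and its counterfactual row."""
    return jnp.mean(jnp.abs(xs - cfs).sum(axis=1))
```

Under these definitions, a validity of 1.0 means every generated counterfactual flipped the model's prediction, while lower proximity means the counterfactuals stay closer to the original inputs.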