dm = load_data('adult')
x_sliced = dm.xs[:1]
feats_info, perturb_fn = features_to_infos_and_perturb_fn(dm.features)
cont_masks, immut_masks, num_categories = feats_info
assert np.array_equal(cont_masks, np.array([1, 1] + [0] * 27))
assert immut_masks.sum() == 2 + 2
assert x_sliced.ndim == 2
cfs = perturb_function_with_features(
jrand.PRNGKey(0), x_sliced, 1000, 1, 0, 2, *feats_info, perturb_fn
)
assert cfs.shape == (1000, 29)
assert cfs[:, 2:].sum() == 1000 * 6
assert default_perturb_function(
jrand.PRNGKey(0), x_sliced, 100, 1, 0, 2,
).shape == (100, 29)

Growing Sphere
relax.methods.sphere.hyper_sphere_coordindates
relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)
Parameters:
- rng_key (<function PRNGKey at 0x7fcaaae82d40>) – Random number generator key
- x (<class 'jax.Array'>) – Input instance with only continuous features. Shape: (1, n_features)
- n_samples (<class 'int'>) – Number of samples
- high (<class 'float'>) – Upper bound
- low (<class 'float'>) – Lower bound
- p_norm (<class 'int'>, default=2) – Norm
relax.methods.sphere.sample_categorical
relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)
relax.methods.sphere.default_perturb_function
relax.methods.sphere.default_perturb_function (rng_key, x, n_samples, high, low, p_norm)
Parameters:
- rng_key (<function PRNGKey at 0x7fcaaae82d40>)
- x (<class 'numpy.ndarray'>) – Shape: (1, k)
- n_samples (<class 'int'>)
- high (<class 'float'>)
- low (<class 'float'>)
- p_norm (<class 'int'>)
relax.methods.sphere.perturb_function_with_features
relax.methods.sphere.perturb_function_with_features (rng_key, x, n_samples, high, low, p_norm, cont_masks, immut_masks, num_categories, cat_perturb_fn)
Parameters:
- rng_key (<function PRNGKey at 0x7fcaaae82d40>)
- x (<class 'numpy.ndarray'>) – Shape: (1, k)
- n_samples (<class 'int'>)
- high (<class 'float'>)
- low (<class 'float'>)
- p_norm (<class 'int'>)
- cont_masks (<class 'jax.Array'>)
- immut_masks (<class 'jax.Array'>)
- num_categories (list[int])
- cat_perturb_fn (typing.Callable)
relax.methods.sphere.GSConfig
class relax.methods.sphere.GSConfig (n_steps=100, n_samples=100, step_size=0.05, p_norm=2)
Base class for all config classes.
relax.methods.sphere.GrowingSphere
class relax.methods.sphere.GrowingSphere (config=None, name=None, perturb_fn=None)
Base class for all counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path, save_data_module=True)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
x_shape = xs_test.shape

gs = GrowingSphere()
assert not gs.has_data_module()
gs.set_data_module(dm)
assert gs.has_data_module()
gs.set_apply_constraints_fn(dm.apply_constraints)
gs.before_generate_cf()
cf = gs.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))

gs.save('tmp/gs/')
gs_1 = GrowingSphere.load_from_path('tmp/gs/')
assert gs_1.has_data_module()
gs_1.set_apply_constraints_fn(dm.apply_constraints)
gs_1.before_generate_cf()
cf_1 = gs_1.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))
assert jnp.allclose(cf, cf_1)
shutil.rmtree('tmp/gs/')
gs.save('tmp/gs/', save_data_module=False)
gs_2 = GrowingSphere.load_from_path('tmp/gs/')
assert not gs_2.has_data_module()

partial_gen = partial(gs.generate_cf, pred_fn=model.pred_fn)
cfs = jax.jit(jax.vmap(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))
assert cfs.shape == (x_shape[0], x_shape[1])
assert cfs.min() >= 0 and cfs.max() <= 1
print("Validity: ", keras.metrics.binary_accuracy(
(1 - model.pred_fn(xs_test)).round(),
model.pred_fn(cfs[:, :])
).mean())

Validity: 1.0