dm = load_data('adult')
x_sliced = dm.xs[:1]
feats_info, perturb_fn = features_to_infos_and_perturb_fn(dm.features)
cont_masks, immut_masks, num_categories = feats_info
assert np.array_equal(cont_masks, np.array([1, 1] + [0] * 27))
assert immut_masks.sum() == 2 + 2
assert x_sliced.ndim == 2

cfs = perturb_function_with_features(
    jrand.PRNGKey(0), x_sliced, 1000, 1, 0, 2, *feats_info, perturb_fn
)
assert cfs.shape == (1000, 29)
assert cfs[:, 2:].sum() == 1000 * 6

assert default_perturb_function(
    jrand.PRNGKey(0), x_sliced, 100, 1, 0, 2,
).shape == (100, 29)
Growing Sphere
relax.methods.sphere.hyper_sphere_coordindates
relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)
Parameters:
- rng_key (PRNGKey) – Random number generator key
- x (jax.Array) – Input instance with only continuous features. Shape: (1, n_features)
- n_samples (int) – Number of samples
- high (float) – Upper bound
- low (float) – Lower bound
- p_norm (int, default=2) – Norm
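A minimal sketch of calling this sampler, assuming it returns the candidate points directly with shape (n_samples, n_features); the instance x and the radii below are made up for illustration and are not taken from the library's tests.

import jax.numpy as jnp
import jax.random as jrand
from relax.methods.sphere import hyper_sphere_coordindates

x = jnp.zeros((1, 2))  # one instance with two continuous features
# Draw 500 candidates from the l2 shell with radii in [0.1, 0.5) around x.
candidates = hyper_sphere_coordindates(jrand.PRNGKey(0), x, 500, 0.5, 0.1, p_norm=2)
# Expected under the assumption above: candidates.shape == (500, 2)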
relax.methods.sphere.sample_categorical
relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)
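No parameter docs are rendered for this helper; below is a hedged sketch of how it is presumably called, assuming it draws n_samples one-hot rows over col_size categories.

import jax.random as jrand
from relax.methods.sphere import sample_categorical

# col_size=4, n_samples=10 (passed positionally)
cat = sample_categorical(jrand.PRNGKey(0), 4, 10)
# Assumed result: shape (10, 4), each row a valid one-hot vector (rows sum to 1).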
relax.methods.sphere.default_perturb_function
relax.methods.sphere.default_perturb_function (rng_key, x, n_samples, high, low, p_norm)
Parameters:
- rng_key (PRNGKey)
- x (numpy.ndarray) – Shape: (1, k)
- n_samples (int)
- high (float)
- low (float)
- p_norm (int)
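A rough usage sketch, mirroring the test snippet at the top of this page (dm and x_sliced come from load_data('adult') there, so k == 29):

# 100 hyper-sphere perturbations of x_sliced with radii in [0, 1) under the l2 norm.
perturbed = default_perturb_function(jrand.PRNGKey(0), x_sliced, 100, 1, 0, 2)
assert perturbed.shape == (100, 29)  # (n_samples, k)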
relax.methods.sphere.perturb_function_with_features
relax.methods.sphere.perturb_function_with_features (rng_key, x, n_samples, high, low, p_norm, cont_masks, immut_masks, num_categories, cat_perturb_fn)
Parameters:
- rng_key (PRNGKey)
- x (numpy.ndarray) – Shape: (1, k)
- n_samples (int)
- high (float)
- low (float)
- p_norm (int)
- cont_masks (jax.Array)
- immut_masks (jax.Array)
- num_categories (list[int])
- cat_perturb_fn (typing.Callable)
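The masks and the categorical sampler are typically produced by features_to_infos_and_perturb_fn, as in the test snippet at the top of this page. The same call is sketched below with the parameters spelled out instead of unpacking *feats_info; the comments map positional arguments to the parameter names above.

cfs = perturb_function_with_features(
    jrand.PRNGKey(0),    # rng_key
    x_sliced,            # x, shape (1, 29) for 'adult'
    1000,                # n_samples
    1, 0, 2,             # high, low, p_norm
    cont_masks, immut_masks, num_categories,  # from features_to_infos_and_perturb_fn
    perturb_fn,          # cat_perturb_fn
)
assert cfs.shape == (1000, 29)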
relax.methods.sphere.GSConfig
class relax.methods.sphere.GSConfig (n_steps=100, n_samples=100, step_size=0.05, p_norm=2)
Base class for all config classes.
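A short sketch of overriding the defaults; the values below are arbitrary, and the config is passed to the GrowingSphere class documented next via its config argument.

from relax.methods.sphere import GSConfig, GrowingSphere

config = GSConfig(n_steps=50, n_samples=200, step_size=0.1, p_norm=1)
gs = GrowingSphere(config=config)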
relax.methods.sphere.GrowingSphere
class relax.methods.sphere.GrowingSphere (config=None, name=None, perturb_fn=None)
Base class for all counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path, save_data_module=True)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
x_shape = xs_test.shape

gs = GrowingSphere()
assert not gs.has_data_module()
gs.set_data_module(dm)
assert gs.has_data_module()

gs.set_apply_constraints_fn(dm.apply_constraints)
gs.before_generate_cf()
cf = gs.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))

gs.save('tmp/gs/')
gs_1 = GrowingSphere.load_from_path('tmp/gs/')
assert gs_1.has_data_module()
gs_1.set_apply_constraints_fn(dm.apply_constraints)
gs_1.before_generate_cf()
cf_1 = gs_1.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))
assert jnp.allclose(cf, cf_1)
shutil.rmtree('tmp/gs/')

gs.save('tmp/gs/', save_data_module=False)
gs_2 = GrowingSphere.load_from_path('tmp/gs/')
assert not gs_2.has_data_module()

partial_gen = partial(gs.generate_cf, pred_fn=model.pred_fn)
cfs = jax.jit(jax.vmap(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))

assert cfs.shape == (x_shape[0], x_shape[1])
assert cfs.min() >= 0 and cfs.max() <= 1
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs[:, :])
).mean())
Validity: 1.0