Growing Sphere

relax.methods.sphere.hyper_sphere_coordindates

[source]

relax.methods.sphere.hyper_sphere_coordindates (rng_key, x, n_samples, high, low, p_norm=2)

Parameters:

  • rng_key (jax.random.PRNGKey) – Random number generator key
  • x (<class 'jax.Array'>) – Input instance with only continuous features. Shape: (1, n_features)
  • n_samples (<class 'int'>) – Number of samples
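  • high (<class 'float'>) – Upper bound of the sampling radius around `x`
  • low (<class 'float'>) – Lower bound of the sampling radius around `x`
  • p_norm (<class 'int'>, default=2) – Order of the norm used to measure distance (2 = Euclidean norm)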

relax.methods.sphere.sample_categorical

[source]

relax.methods.sphere.sample_categorical (rng_key, col_size, n_samples)

relax.methods.sphere.default_perturb_function

[source]

relax.methods.sphere.default_perturb_function (rng_key, x, n_samples, high, low, p_norm)

Parameters:

  • rng_key (jax.random.PRNGKey)
  • x (<class 'numpy.ndarray'>) – Shape: (1, k)
  • n_samples (<class 'int'>)
  • high (<class 'float'>)
  • low (<class 'float'>)
  • p_norm (<class 'int'>)

relax.methods.sphere.perturb_function_with_features

[source]

relax.methods.sphere.perturb_function_with_features (rng_key, x, n_samples, high, low, p_norm, cont_masks, immut_masks, num_categories, cat_perturb_fn)

Parameters:

  • rng_key (jax.random.PRNGKey)
  • x (<class 'numpy.ndarray'>) – Shape: (1, k)
  • n_samples (<class 'int'>)
  • high (<class 'float'>)
  • low (<class 'float'>)
  • p_norm (<class 'int'>)
  • cont_masks (<class 'jax.Array'>)
  • immut_masks (<class 'jax.Array'>)
  • num_categories (list[int])
  • cat_perturb_fn (typing.Callable)
# Sanity checks for the feature-aware and default perturbation functions on 'adult'.
dm = load_data('adult')
x_sliced = dm.xs[:1]  # slicing (rather than indexing) keeps the leading batch axis
assert x_sliced.ndim == 2

feats_info, perturb_fn = features_to_infos_and_perturb_fn(dm.features)
cont_masks, immut_masks, num_categories = feats_info
# 29 encoded columns: the first 2 are continuous, the remaining 27 are not.
assert np.array_equal(cont_masks, np.array([1] * 2 + [0] * 27))
# Expect 4 immutable encoded columns.
# NOTE(review): which features these belong to is not visible here.
assert immut_masks.sum() == 4

cfs = perturb_function_with_features(
    jrand.PRNGKey(0),
    x_sliced,
    1000,  # n_samples
    1,     # high
    0,     # low
    2,     # p_norm
    cont_masks,
    immut_masks,
    num_categories,
    perturb_fn,
)
assert cfs.shape == (1000, 29)
# The non-continuous columns sum to 6 per sample (presumably 6 one-hot groups,
# one active entry each).
assert cfs[:, 2:].sum() == 6 * 1000

assert default_perturb_function(
    jrand.PRNGKey(0), x_sliced, 100, 1, 0, 2
).shape == (100, 29)

relax.methods.sphere.GSConfig

[source]

class relax.methods.sphere.GSConfig (n_steps=100, n_samples=100, step_size=0.05, p_norm=2)

Configuration for `GrowingSphere` (fields: `n_steps`, `n_samples`, `step_size`, `p_norm`).

relax.methods.sphere.GrowingSphere

[source]

class relax.methods.sphere.GrowingSphere (config=None, name=None, perturb_fn=None)

Growing Sphere counterfactual explanation module.

Methods

[source]

set_apply_constraints_fn (apply_constraints_fn)

[source]

set_compute_reg_loss_fn (compute_reg_loss_fn)

[source]

apply_constraints (*args, **kwargs)

[source]

compute_reg_loss (*args, **kwargs)

[source]

save (path, save_data_module=True)

[source]

load_from_path (path)

[source]

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)

# End-to-end checks for GrowingSphere on the 'dummy' dataset and model.
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
x_shape = xs_test.shape

# A freshly constructed module has no data module until one is attached.
gs = GrowingSphere()
assert not gs.has_data_module()
gs.set_data_module(dm)
assert gs.has_data_module()
gs.set_apply_constraints_fn(dm.apply_constraints)
gs.before_generate_cf()

cf = gs.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))

# Save/load round trip (data module included): the same key must reproduce the same CF.
gs.save('tmp/gs/')
gs_1 = GrowingSphere.load_from_path('tmp/gs/')
assert gs_1.has_data_module()
gs_1.set_apply_constraints_fn(dm.apply_constraints)
gs_1.before_generate_cf()

cf_1 = gs_1.generate_cf(xs_test[0], pred_fn=model.pred_fn, rng_key=jax.random.PRNGKey(0))
assert jnp.allclose(cf, cf_1)

# Save again without the data module; the reloaded module must not carry one.
shutil.rmtree('tmp/gs/')
gs.save('tmp/gs/', save_data_module=False)
gs_2 = GrowingSphere.load_from_path('tmp/gs/')
assert not gs_2.has_data_module()
# Remove the second checkpoint too, so the test leaves no files behind.
shutil.rmtree('tmp/gs/')

# Batched generation: vmap over instances and their per-instance PRNG keys, then jit.
partial_gen = partial(gs.generate_cf, pred_fn=model.pred_fn)
cfs = jax.jit(jax.vmap(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))

assert cfs.shape == (x_shape[0], x_shape[1])
# Generated counterfactuals must stay within [0, 1].
assert cfs.min() >= 0 and cfs.max() <= 1

# Validity: fraction of counterfactuals whose predicted label equals the flipped
# (1 - rounded) prediction of the original instance.
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs),
).mean())
Validity:  1.0