Vanilla CF

relax.methods.vanilla.VanillaCFConfig

class relax.methods.vanilla.VanillaCFConfig (n_steps=100, lr=0.1, lambda_=0.1, validity_fn='KLDivergence')

Base class for all config classes.
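
The config collects the gradient-descent hyperparameters used by VanillaCF. Below is a minimal sketch of overriding the defaults shown in the signature above; the field meanings are inferred from their names and should be treated as assumptions:

from relax.methods.vanilla import VanillaCF, VanillaCFConfig

config = VanillaCFConfig(
    n_steps=200,                 # assumed: number of optimization steps
    lr=0.05,                     # assumed: step size of the optimizer
    lambda_=0.5,                 # assumed: validity/proximity trade-off weight
    validity_fn='KLDivergence',  # name of the validity loss
)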

relax.methods.vanilla.VanillaCF

class relax.methods.vanilla.VanillaCF (config=None, name=None)

Base class for all counterfactual modules.
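
Per the signature above, a config object (or config=None for the defaults) is passed at construction. A short sketch, reusing the config built earlier:

vcf_custom = VanillaCF(config=config)
vcf_default = VanillaCF()  # falls back to the default VanillaCFConfig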

Methods

set_apply_constraints_fn (apply_constraints_fn)

set_compute_reg_loss_fn (compute_reg_loss_fn)

apply_constraints (*args, **kwargs)

compute_reg_loss (*args, **kwargs)

save (path)

load_from_path (path)

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)

# Imports assumed for these examples (module paths follow the jax-relax package layout).
import functools as ft
import jax, jax.numpy as jnp, keras
from relax.data_module import load_data
from relax.ml_model import load_ml_module
from relax.methods.vanilla import VanillaCF

dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
vcf = VanillaCF()
cf = vcf.generate_cf(xs_test[0], model.pred_fn)
assert cf.shape == xs_test[0].shape

# Vectorize counterfactual generation over the whole test set with jax.vmap.
partial_gen = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(partial_gen)(xs_test)

print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())
Validity:  0.99600005
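Because the vmapped generator is applied as a pure function here, it can also be compiled with jax.jit. This assumes the optimization loop inside generate_cf is jit-compatible, which is not verified here:

cfs_jit = jax.jit(jax.vmap(partial_gen))(xs_test)  # assumed jit-compatible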
def apply_constraint_fn(x, cf, hard=False):
    # Project the counterfactual into the valid feature range [0, 1] when hard=True;
    # during optimization (hard=False) the candidate is left untouched.
    return jax.lax.cond(
        hard,
        lambda: jnp.clip(cf, 0, 1),
        lambda: cf,
    )

vcf.set_apply_constraints_fn(apply_constraint_fn)
cfs = jax.vmap(partial_gen)(xs_test)

# Validity drops slightly once counterfactuals are projected into [0, 1].
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())
assert (cfs >= 0).all() and (cfs <= 1).all()
Validity:  0.98800004
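set_compute_reg_loss_fn plugs in a custom regularization term in the same way. The exact signature the module expects is not documented above, so the sketch below assumes an (x, cf) interface mirroring apply_constraint_fn:

# Hypothetical regularizer: L1 distance between the input and its counterfactual.
# The (x, cf, **kwargs) signature is an assumption, not taken from the library docs.
def compute_reg_loss_fn(x, cf, **kwargs):
    return jnp.abs(cf - x).sum()

vcf.set_compute_reg_loss_fn(compute_reg_loss_fn)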
vcf.save('tmp/vanillacf/')
vcf_1 = VanillaCF.load_from_path('tmp/vanillacf/')
# Re-attach the constraint function before generating with the restored module.
vcf_1.set_apply_constraints_fn(apply_constraint_fn)
partial_gen_1 = ft.partial(vcf_1.generate_cf, pred_fn=model.pred_fn)
cfs_1 = jax.vmap(partial_gen_1)(xs_test)

# The restored module should reproduce the same counterfactuals.
assert jnp.allclose(cfs, cfs_1)