# Load the 'dummy' dataset and the ML module registered under the same name.
dm = load_data('dummy')
model = load_ml_module('dummy')
# Each split unpacks into a pair; presumably (inputs, labels) — confirm
# against `load_data`.
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']

# ## Vanilla CF
relax.methods.vanilla.VanillaCFConfig
class relax.methods.vanilla.VanillaCFConfig (n_steps=100, lr=0.1, lambda_=0.1, validity_fn='KLDivergence')
Configuration for `VanillaCF` (a subclass of the base config class).
relax.methods.vanilla.VanillaCF
class relax.methods.vanilla.VanillaCF (config=None, name=None)
The vanilla counterfactual (CF) module, configured by `VanillaCFConfig` (a subclass of the base counterfactual-module class).
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
# Generate a counterfactual for a single test instance and check that it
# has the same shape as that instance.
vcf = VanillaCF()
cf = vcf.generate_cf(xs_test[0], model.pred_fn)
assert cf.shape == xs_test[0].shape

# Batch generation: bind `pred_fn` once, then vectorise over the test set.
partial_gen = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(partial_gen)(xs_test)

# Validity: accuracy of the counterfactuals' predictions against the
# flipped (1 - p, rounded) predictions on the original inputs.
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())
# Output (recorded notebook run):
# Validity:  0.99600005


def apply_constraint_fn(x, cf, hard=False):
    """Optionally clip a counterfactual element-wise into [0, 1].

    Args:
        x: The original input. Unused here; presumably kept so the signature
            matches what `set_apply_constraints_fn` expects — confirm.
        cf: The candidate counterfactual array.
        hard: When true, return `cf` clipped to [0, 1]; otherwise return
            `cf` unchanged.

    Returns:
        An array with the same shape as `cf`.
    """
    # `lax.cond` instead of a Python `if` keeps this traceable when `hard`
    # is a traced value under jit/vmap.
    return jax.lax.cond(
        hard,
        lambda: jnp.clip(cf, 0, 1),
        lambda: cf,
    )
# Register the clipping constraint and regenerate the counterfactuals.
# `partial_gen` wraps the same `vcf` instance, so it is reused here.
vcf.set_apply_constraints_fn(apply_constraint_fn)
cfs = jax.vmap(partial_gen)(xs_test)
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())
# With the constraint in place, every counterfactual must lie in [0, 1].
assert (cfs >= 0).all() and (cfs <= 1).all()
# Output (recorded notebook run):
# Validity:  0.98800004

# Round-trip check: save the module, reload it, and verify that the reloaded
# module reproduces the same counterfactuals.
vcf.save('tmp/vanillacf/')
vcf_1 = VanillaCF.load_from_path('tmp/vanillacf/')
# The constraint function is re-registered after loading.
# NOTE(review): presumably `save` does not serialize it — confirm.
vcf_1.set_apply_constraints_fn(apply_constraint_fn)
partial_gen_1 = ft.partial(vcf_1.generate_cf, pred_fn=model.pred_fn)
cfs_1 = jax.vmap(partial_gen_1)(xs_test)
assert jnp.allclose(cfs, cfs_1)