# Load the 'dummy' data module and its companion ML model, then unpack the
# train/test splits (each split indexes to a (features, labels) pair).
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
Vanilla CF
relax.methods.vanilla.VanillaCFConfig
class relax.methods.vanilla.VanillaCFConfig (n_steps=100, lr=0.1, lambda_=0.1, validity_fn='KLDivergence')
Configuration for `VanillaCF`.
relax.methods.vanilla.VanillaCF
class relax.methods.vanilla.VanillaCF (config=None, name=None)
Counterfactual module implementing the vanilla CF method, configured by `VanillaCFConfig`.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
# Build a VanillaCF module with its default config and generate a
# counterfactual for a single test instance; the CF must keep the input's shape.
vcf = VanillaCF()
cf = vcf.generate_cf(xs_test[0], model.pred_fn)
assert cf.shape == xs_test[0].shape
# Batch CF generation: bind the model's prediction function, then vmap
# generate_cf over the leading (instance) axis of xs_test.
partial_gen = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(partial_gen)(xs_test)

# Validity: agreement between the flipped, rounded original prediction and the
# prediction on each CF (assumes pred_fn returns positive-class probabilities;
# TODO confirm against the model module).
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs),
).mean())
Validity: 0.99600005
def apply_constraint_fn(x, cf, hard=False):
    """Optionally clip a counterfactual into the [0, 1] range.

    Args:
        x: The original input. It is not used here; it is kept so the
            signature matches what `set_apply_constraints_fn` is given
            (presumably the module passes it; verify against relax).
        cf: The counterfactual array to constrain.
        hard: When true, clip `cf` elementwise to [0, 1]. Otherwise `cf` is
            returned untouched. It is routed through `jax.lax.cond`, so it may
            also be a traced boolean.

    Returns:
        `cf` clipped to [0, 1] if `hard`, else `cf` itself.
    """
    def _clip_to_unit_range():
        return jnp.clip(cf, 0, 1)

    def _passthrough():
        return cf

    return jax.lax.cond(hard, _clip_to_unit_range, _passthrough)
# Register the clipping constraint on vcf, then regenerate. partial_gen wraps
# the bound method vcf.generate_cf, so it presumably sees the new constraint.
vcf.set_apply_constraints_fn(apply_constraint_fn)
cfs = jax.vmap(partial_gen)(xs_test)

print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs),
).mean())

# With the constraint installed, every CF feature is expected to lie in [0, 1].
assert (cfs >= 0).all() and (cfs <= 1).all()
Validity: 0.98800004
# Round-trip check: save the configured module, reload it, and confirm the
# reloaded copy reproduces the same counterfactuals.
vcf.save('tmp/vanillacf/')
vcf_1 = VanillaCF.load_from_path('tmp/vanillacf/')
# The constraint fn is registered again on the reloaded module before
# generating. NOTE(review): whether save() persists it is not shown here.
vcf_1.set_apply_constraints_fn(apply_constraint_fn)
partial_gen_1 = ft.partial(vcf_1.generate_cf, pred_fn=model.pred_fn)
cfs_1 = jax.vmap(partial_gen_1)(xs_test)
assert jnp.allclose(cfs, cfs_1)