# From the original dice implementation
# https://github.com/interpretml/DiCE/blob/a772c8d4fcd88d1cab7f2e02b0bcc045dc0e2eab/dice_ml/explainer_interfaces/dice_pytorch.py#L222-L227
def dpp_style_torch(cfs: torch.Tensor):
    """Reference (loop-based) DPP-style diversity score, mirroring DiCE's PyTorch code.

    Builds a square matrix whose (i, j) entry is ``1 / (1 + L1(cfs[i], cfs[j]))``,
    adds a small constant to the diagonal, and returns the determinant.

    Args:
        cfs: Tensor whose first axis indexes the counterfactuals.

    Returns:
        A scalar tensor holding ``torch.det`` of the similarity matrix.
    """
    # L1 distance between two counterfactuals.
    compute_dist = lambda x, y: torch.abs(x-y).sum()

    total_CFs = len(cfs)
    det_entries = torch.ones((total_CFs, total_CFs))
    # Intentionally kept as an explicit double loop: this is the baseline
    # that the vectorized implementation is benchmarked against.
    for i in range(total_CFs):
        for j in range(total_CFs):
            det_entries[(i,j)] = 1.0/(1.0 + compute_dist(cfs[i], cfs[j]))
            if i == j:
                # Small perturbation on the diagonal entries.
                det_entries[(i,j)] += 1e-8
    return torch.det(det_entries)
Diverse CF
Util Functions
relax.methods.dice.dpp_style_vmap
relax.methods.dice.dpp_style_vmap (cfs)
def jax2torch(x: Array):
    """Convert an array exposing ``__array__`` into a ``torch.Tensor``.

    The conversion goes through ``torch.from_numpy`` on the NumPy view of
    ``x``; PyTorch warns if that view is not writable.
    """
    return torch.from_numpy(x.__array__())
# Check that the PyTorch reference and the vmap implementation agree
# on a random 100x100 input.
cfs = jrand.normal(jrand.PRNGKey(0), (100, 100))
cfs_tensor = jax2torch(cfs)
assert np.allclose(
    dpp_style_torch(cfs_tensor).numpy(),
    dpp_style_vmap(cfs)
)
/tmp/ipykernel_11637/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(x.__array__())
Our JAX-based implementation is ~500X faster than DiCE’s PyTorch implementation (318 ms vs. 571 µs per call on the 100×100 input above).
torch_res = dpp_style_torch(cfs_tensor)
318 ms ± 4.24 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
jax_res = dpp_style_vmap(cfs)
571 µs ± 44.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
Config
relax.methods.dice.DiverseCFConfig
class relax.methods.dice.DiverseCFConfig (n_cfs=5, n_steps=1000, lr=0.001, lambda_1=1.0, lambda_2=1.0, lambda_3=1.0, lambda_4=0.1, validity_fn='KLDivergence', cost_fn='MeanSquaredError', seed=42)
Base class for all config classes.
relax.methods.dice.DiverseCF
class relax.methods.dice.DiverseCF (config=None, name=None)
Base class for all counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
# Load the 'dummy' data module and ML module, then unpack the train/test splits.
dm = load_data('dummy')
model = load_ml_module('dummy')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
x_shape = xs_test.shape
/home/birk/code/jax-relax/relax/data_module.py:234: UserWarning: Passing `config` will have no effect.
warnings.warn("Passing `config` will have no effect.")
# Override lambda_2 through a config dict; other settings keep their defaults.
dcf = DiverseCF({'lambda_2': 4.0})
dcf.set_apply_constraints_fn(dm.apply_constraints)
dcf.set_compute_reg_loss_fn(dm.compute_reg_loss)
# Generate counterfactuals for a single test instance.
cf = dcf.generate_cf(xs_test[0], model.pred_fn, rng_key=jrand.PRNGKey(0))
assert cf.shape == (5, x_shape[1])
# Batch generation over the test set with jax.vmap, one PRNG key per instance.
partial_gen = partial(dcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(partial_gen)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

assert cfs.shape == (x_shape[0], 5, x_shape[1])
# Compare the flipped, rounded predictions on the inputs with the
# predictions on the first counterfactual of each instance.
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs[:, 0, :])
).mean())
Validity: 1.0
# Save/load round trip: the reloaded module must produce the same
# counterfactuals for the same inputs and PRNG keys.
dcf.save('tmp/dice/')
dcf_1 = DiverseCF.load_from_path('tmp/dice/')
dcf_1.set_apply_constraints_fn(dm.apply_constraints)
partial_gen_1 = ft.partial(dcf_1.generate_cf, pred_fn=model.pred_fn)
cfs_1 = jax.vmap(partial_gen_1)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

assert jnp.allclose(cfs, cfs_1)
exp = relax.generate_cf_explanations(
    dcf, dm, model.pred_fn
)
relax.benchmark_cfs([exp])
| | | acc | validity | proximity |
|---|---|---|---|---|
| dummy | DiverseCF | 0.983 | 1.0 | 1.264459 |