Diverse CF

Util Functions

relax.methods.dice.dpp_style_vmap

[source]

relax.methods.dice.dpp_style_vmap (cfs)

# From the original dice implementation
# https://github.com/interpretml/DiCE/blob/a772c8d4fcd88d1cab7f2e02b0bcc045dc0e2eab/dice_ml/explainer_interfaces/dice_pytorch.py#L222-L227
def dpp_style_torch(cfs: torch.Tensor):
    """Reference (loop-based) DPP-style diversity score, mirroring DiCE's PyTorch code.

    Builds a square kernel whose entry (i, j) is ``1 / (1 + L1(cfs[i], cfs[j]))``,
    adds ``1e-8`` to every diagonal entry, and returns the determinant of
    that kernel.

    Args:
        cfs: Tensor whose first axis indexes the counterfactuals.

    Returns:
        A scalar tensor holding ``torch.det`` of the kernel.
    """
    def l1_distance(a, b):
        # Sum of absolute element-wise differences.
        return torch.abs(a - b).sum()

    n_cfs = len(cfs)
    kernel = torch.ones((n_cfs, n_cfs))
    for row in range(n_cfs):
        for col in range(n_cfs):
            kernel[row, col] = 1.0 / (1.0 + l1_distance(cfs[row], cfs[col]))
        # Diagonal jitter; applied after the row is filled, which is
        # equivalent to adding it when row == col inside the inner loop.
        kernel[row, row] += 1e-8
    return torch.det(kernel)
def jax2torch(x: Array):
    """Convert an array to a ``torch.Tensor`` through its NumPy representation.

    The array's ``__array__()`` result is handed to ``torch.from_numpy``.
    NOTE: when that NumPy array is not writable, PyTorch emits a
    ``UserWarning`` and writing to the returned tensor is undefined
    behaviour, so treat the result as read-only.
    """
    numpy_view = x.__array__()
    return torch.from_numpy(numpy_view)
# Sanity check: on the same (100, 100) random input (PRNG seed 0), the
# JAX implementation must match the reference PyTorch implementation.
cfs = jrand.normal(jrand.PRNGKey(0), (100, 100))
cfs_tensor = jax2torch(cfs)
assert np.allclose(
    dpp_style_torch(cfs_tensor).numpy(),
    dpp_style_vmap(cfs)
)
/tmp/ipykernel_11637/3412149913.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(x.__array__())

Our JAX-based implementation is ~500x faster than DiCE's PyTorch implementation on the 100 × 100 input above (318 ms vs. 571 µs per call).

# Timed cell: reference PyTorch implementation.
torch_res = dpp_style_torch(cfs_tensor)
318 ms ± 4.24 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
# Timed cell: JAX implementation on the same input.
jax_res = dpp_style_vmap(cfs)
571 µs ± 44.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Config

relax.methods.dice.DiverseCFConfig

[source]

class relax.methods.dice.DiverseCFConfig (n_cfs=5, n_steps=1000, lr=0.001, lambda_1=1.0, lambda_2=1.0, lambda_3=1.0, lambda_4=0.1, validity_fn='KLDivergence', cost_fn='MeanSquaredError', seed=42)

Base class for all config classes.

relax.methods.dice.DiverseCF

[source]

class relax.methods.dice.DiverseCF (config=None, name=None)

Base class for all counterfactual modules.

Methods

[source]

set_apply_constraints_fn (apply_constraints_fn)

[source]

set_compute_reg_loss_fn (compute_reg_loss_fn)

[source]

apply_constraints (*args, **kwargs)

[source]

compute_reg_loss (*args, **kwargs)

[source]

save (path)

[source]

load_from_path (path)

[source]

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)

# Load the data module and the ML module registered under the name 'dummy'.
dm = load_data('dummy')
model = load_ml_module('dummy')
# Each split is unpacked as an (inputs, labels) pair.
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
# Kept for the shape assertions in the following cells.
x_shape = xs_test.shape
/home/birk/code/jax-relax/relax/data_module.py:234: UserWarning: Passing `config` will have no effect.
  warnings.warn("Passing `config` will have no effect.")
# Build the explainer from a dict config, overriding lambda_2
# (1.0 in the DiverseCFConfig signature) with 4.0.
dcf = DiverseCF({'lambda_2': 4.0})
# Attach the data module's constraint and regularization-loss callbacks.
dcf.set_apply_constraints_fn(dm.apply_constraints)
dcf.set_compute_reg_loss_fn(dm.compute_reg_loss)
# Generate counterfactuals for a single test instance.
cf = dcf.generate_cf(xs_test[0], model.pred_fn, rng_key=jrand.PRNGKey(0))
# Expect 5 counterfactuals (presumably the default n_cfs=5), each with the input's feature width.
assert cf.shape == (5, x_shape[1])

# Bind the prediction function so only the input and rng_key remain to be mapped.
partial_gen = partial(dcf.generate_cf, pred_fn=model.pred_fn)
# Vectorize over the test set, giving each instance its own PRNG key.
cfs = jax.vmap(partial_gen)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

# One group of 5 counterfactuals per test instance.
assert cfs.shape == (x_shape[0], 5, x_shape[1])

# Validity: compare the flipped, rounded prediction of each original input
# with the model's prediction on the first counterfactual of that input.
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs[:, 0, :])
).mean())
Validity:  1.0
# Round-trip check: a module restored from disk must reproduce the
# counterfactuals generated by the original `dcf`.
dcf.save('tmp/dice/')
dcf_1 = DiverseCF.load_from_path('tmp/dice/')
# Re-attach BOTH data-module callbacks so `dcf_1` is configured exactly like
# `dcf`; presumably these callables are not persisted by `save` (verify).
dcf_1.set_apply_constraints_fn(dm.apply_constraints)
dcf_1.set_compute_reg_loss_fn(dm.compute_reg_loss)
partial_gen_1 = ft.partial(dcf_1.generate_cf, pred_fn=model.pred_fn)
# Same PRNG keys as the original batched run, so outputs are comparable.
cfs_1 = jax.vmap(partial_gen_1)(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), xs_test.shape[0]))

assert jnp.allclose(cfs, cfs_1)
# Generate explanations with the explainer, data module and prediction function,
# then benchmark them (result table shown below).
exp = relax.generate_cf_explanations(
    dcf, dm, model.pred_fn
)
relax.benchmark_cfs([exp])
                   acc  validity  proximity
dummy DiverseCF  0.983       1.0   1.264459