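import os

# Expose 8 virtual CPU devices. Set this before importing JAX so the flag takes effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax.numpy as jnp
import jax.random as jrand
from jax import jit
from relax.evaluate import (
    IterativeGenerationStrategy, VmapGenerationStrategy, PmapGenerationStrategy,
    BatchedVmapGenerationStrategy, BatchedPmapGenerationStrategy,
)

w = jrand.normal(jrand.PRNGKey(0), (100, 100))
X = jrand.normal(jrand.PRNGKey(0), (1000, 100))

@jit
def pred_fn(x): return jnp.dot(x, w.T)

def f(x, pred_fn=None, **kwargs):
    return pred_fn(x)

iter_gen = IterativeGenerationStrategy()
vmap_gen = VmapGenerationStrategy()
pmap_gen = PmapGenerationStrategy()
bvmap_gen = BatchedVmapGenerationStrategy(128)
bpmap_gen = BatchedPmapGenerationStrategy(128)
Evaluate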
EXPLANATION
CLASS relax.evaluate.Explanation (cf_name, data_module, cfs, total_time, pred_fn, dataset_name='', X=None, y=None)
Generated CF Explanations class.
Arguments to Explanation:
- cf_name: CF method's name
- dataset_name: dataset name
- X: input
- y: label
- cfs: generated CF explanations of X
- total_time: total runtime
- pred_fn: predict function that takes a single input argument and outputs a label (i.e., its format is y = pred_fn(x))
- data_module: data module
Parallelism Strategy
BASEGENERATIONSTRATEGY
CLASS relax.evaluate.BaseGenerationStrategy ()
Base class for mapping strategy.
ITERATIVEGENERATIONSTRATEGY
CLASS relax.evaluate.IterativeGenerationStrategy ()
Iteratively generate counterfactuals.
VMAPGENERATIONSTRATEGY
CLASS relax.evaluate.VmapGenerationStrategy ()
Generate counterfactuals via jax.vmap.
PMAPGENERATIONSTRATEGY
CLASS relax.evaluate.PmapGenerationStrategy (n_devices=None, strategy='auto', **kwargs)
Generate counterfactuals via jax.pmap.
BATCHEDVMAPGENERATIONSTRATEGY
CLASS relax.evaluate.BatchedVmapGenerationStrategy (batch_size)
Auto-batching for generating counterfactuals via jax.vmap.
BATCHEDPMAPGENERATIONSTRATEGY
CLASS relax.evaluate.BatchedPmapGenerationStrategy (batch_size, n_devices=None)
Auto-batching for generating counterfactuals via jax.pmap.
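The idea behind the batched strategies can be sketched without JAX. The snippet below is only an illustration, not the library's implementation: `batched_map` is a hypothetical helper that splits the input into chunks of at most `batch_size` rows, applies the function to each chunk, and concatenates the results. NumPy is used here to keep the sketch dependency-light.

```python
import numpy as np

def batched_map(fn, X, batch_size):
    """Apply `fn` to `X` in chunks of at most `batch_size` rows and stitch the results."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # The last chunk may be smaller; a batch_size larger than X yields a single chunk.
    outputs = [fn(X[i:i + batch_size]) for i in range(0, X.shape[0], batch_size)]
    return np.concatenate(outputs, axis=0)

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100))
X = rng.normal(size=(1000, 100))
pred_fn = lambda x: x @ w.T

full = pred_fn(X)
small_batches = batched_map(pred_fn, X, 128)   # 8 chunks, the last with 104 rows
one_batch = batched_map(pred_fn, X, 1280)      # batch_size > X.shape[0]
```

Both calls should agree with the unbatched result up to floating-point tolerance, which is the same property the assertions in this section check for the library's strategies.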
cf_iter = iter_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_vmap = vmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_pmap = pmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_bvmap = bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
# check when batch_size > X.shape[0]
_bvmap_gen = BatchedVmapGenerationStrategy(1280)
_cf_bvmap = _bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_bvmap, _cf_bvmap, atol=1e-4)
cf_bpmap = bpmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
STRATEGYFACTORY
CLASS relax.evaluate.StrategyFactory ()
Factory class for Parallelism Strategy.
Generating CF Explanation Results
GENERATE_CF_EXPLANATIONS
relax.evaluate.generate_cf_explanations (cf_module, datamodule, pred_fn=None, strategy='vmap', t_configs=None, pred_fn_args=None)
Generate CF explanations.
The pred_fn in generate_cf_explanations is the model's prediction function. Its general format is y = pred_fn(x, **pred_fn_args). If pred_fn takes no arguments other than the input x, leave pred_fn_args as None (the default). Otherwise, pass the extra arguments as a dict.
For example, consider a simple linear function:
def linear_pred_fn(x: jnp.DeviceArray, params: jnp.DeviceArray):
    return x @ params
To pass linear_pred_fn to generate_cf_explanations, we can either wrap linear_pred_fn in an auxiliary function or pass params via pred_fn_args.
Assuming we now have the input x and params:
x = jax.random.normal(jax.random.PRNGKey(0), shape=(5, 10)) # input
params = jnp.ones((10, 1)) # params
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
- Create an auxiliary function (not recommended)
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)
explanations = generate_cf_explanations(
    cf_module, datamodule, aux_linear_pred_fn
)
This approach can work, but if params is changed later, explanations.pred_fn might not behave as expected.
- Pass params into pred_fn_args
explanations = generate_cf_explanations(
    cf_module, datamodule, linear_pred_fn,
    pred_fn_args=dict(params=params)
)
This is the recommended approach, as params is deep-copied inside generate_cf_explanations.
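The difference between the two approaches can be illustrated in plain Python. This is a sketch under assumptions: it does not call generate_cf_explanations, and `bound_pred_fn` is a hypothetical stand-in for a prediction function whose extra arguments were deep-copied and bound at creation time. The closure, by contrast, looks up `params` only when it is called.

```python
import copy
from functools import partial

import numpy as np

def linear_pred_fn(x, params):
    return x @ params

x = np.ones((5, 10))
params = np.ones((10, 1))

# Approach 1: a closure. `params` is resolved when the lambda is *called*.
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)

# Approach 2: deep-copy the extra arguments and bind them immediately
# (an assumed analogue of what passing pred_fn_args allows).
pred_fn_args = dict(params=params)
bound_pred_fn = partial(linear_pred_fn, **copy.deepcopy(pred_fn_args))

# Later, `params` is rebound, for example after further training.
params = np.zeros((10, 1))

closure_out = aux_linear_pred_fn(x)  # silently follows the new params
bound_out = bound_pred_fn(x)         # still uses the original params
```

Here `closure_out` is all zeros while `bound_out` keeps the value computed from the original parameters, which is why the bound form is the safer choice for a stored explanation.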
The pred_fn in explanations takes only x: jnp.DeviceArray as input. For example, to make predictions, we use
y = explanations.pred_fn(x)
Evaluating Metrics
BASEEVALMETRICS
CLASS relax.evaluate.BaseEvalMetrics (name=None)
Base evaluation metrics class.
PREDICTIVEACCURACY
CLASS relax.evaluate.PredictiveAccuracy (name='accuracy')
Compute the accuracy of the predict function.
VALIDITY
CLASS relax.evaluate.Validity (name='validity')
Compute fraction of input instances on which CF explanation methods output valid CF examples.
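As a rough illustration of this metric, the sketch below assumes that a CF example counts as "valid" when the rounded prediction on the CF example differs from the rounded prediction on the original input. The `validity` helper and the toy `pred_fn` are hypothetical and are not the library's code.

```python
import numpy as np

def validity(pred_fn, X, cfs):
    """Fraction of instances whose predicted label flips on the CF example."""
    y_x = np.round(pred_fn(X))
    y_cf = np.round(pred_fn(cfs))
    return float(np.mean(y_x != y_cf))

# Toy classifier (assumption): label 1 when the feature sum exceeds 1.
pred_fn = lambda x: (x.sum(axis=1, keepdims=True) > 1).astype(float)

X = np.array([[0, 1], [1, 0], [1, 1]])
cfs = np.array([[1, 1], [1, 0], [0, 0]])

score = validity(pred_fn, X, cfs)  # labels flip for the 1st and 3rd instance
```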
inputs = jnp.array([
    [0, 1], [1, 0], [1, 1]])
cfs_1 = jnp.array([
    [0, 0], [0, 0], [0, 1]])
cfs_2 = jnp.array([
    [1, 0], [1, -2], [0, 0]])
assert _compute_proximity(inputs, cfs_1) == 1.0
assert _compute_proximity(inputs, cfs_2) == 2.0
PROXIMITY
CLASS relax.evaluate.Proximity (name='proximity')
Compute L1 norm distance between input datasets and CF examples divided by the number of features.
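The assertions shown for _compute_proximity are consistent with a per-instance L1 distance averaged over instances. A minimal sketch under that assumption follows; `proximity` is a hypothetical helper, and the library's implementation may normalize differently (for example, by the number of features).

```python
import numpy as np

def proximity(X, cfs):
    """Mean over instances of the L1 distance between each input and its CF example."""
    return float(np.abs(X - cfs).sum(axis=1).mean())

inputs = np.array([[0, 1], [1, 0], [1, 1]])
cfs_1 = np.array([[0, 0], [0, 0], [0, 1]])
cfs_2 = np.array([[1, 0], [1, -2], [0, 0]])

prox_1 = proximity(inputs, cfs_1)  # every row differs by 1 in total
prox_2 = proximity(inputs, cfs_2)  # every row differs by 2 in total
```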
SPARSITY
CLASS relax.evaluate.Sparsity (name='sparsity')
Compute the number of feature changes between input datasets and CF examples.
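One way to read this description is as a count of features that differ between each input and its CF example, averaged over instances. The sketch below assumes that reading; `sparsity` is a hypothetical helper, and the library may treat continuous and categorical features differently.

```python
import numpy as np

def sparsity(X, cfs, atol=1e-8):
    """Average number of features changed per instance (changes below `atol` are ignored)."""
    changed = ~np.isclose(X, cfs, atol=atol)
    return float(changed.sum(axis=1).mean())

inputs = np.array([[0, 1], [1, 0], [1, 1]])
cfs_1 = np.array([[0, 0], [0, 0], [0, 1]])   # one feature changed in every row
cfs_2 = np.array([[1, 0], [1, -2], [0, 0]])  # 2, 1, and 2 features changed

spars_1 = sparsity(inputs, cfs_1)
spars_2 = sparsity(inputs, cfs_2)
```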
MANIFOLDDIST
CLASS relax.evaluate.ManifoldDist (n_neighbors=1, p=2, name='manifold_dist')
Compute the L1 distance to the n-nearest neighbor for all CF examples.
RUNTIME
CLASS relax.evaluate.Runtime (name='runtime')
Get the running time to generate CF examples.
COMPUTE_SO_SPARSITY
relax.evaluate.compute_so_sparsity (cf_results, threshold=2.0)
COMPUTE_SO_PROXIMITY
relax.evaluate.compute_so_proximity (cf_results, threshold=2.0)
COMPUTE_SO_VALIDITY
relax.evaluate.compute_so_validity (cf_results, threshold=2.0)
Benchmarking
EVALUATE_CFS
relax.evaluate.evaluate_cfs (cf_exp, metrics=None, return_dict=True, return_df=False)
BENCHMARK_CFS
relax.evaluate.benchmark_cfs (cf_results_list, metrics=None)
How to evaluate a CF Explanation Module
from jax import random
from relax.data import TabularDataModule
from relax.evaluate import generate_cf_explanations, evaluate_cfs, benchmark_cfs
from relax.module import PredictiveTrainingModule
from relax.trainer import train_model
from relax.utils import load_json

configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']
data_configs['sample_frac'] = 0.1
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    'batch_size': 256
}
We first train a model:
training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)
params, opt_state = train_model(
    training_module,
    dm,
    t_configs
)
pred_fn = lambda x, params, prng_key: \
    training_module.forward(params, prng_key, x, is_training=False)
Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 53.81batch/s, train/train_loss_1=0.0791]
Now, we can start to benchmark different methods.
from relax.methods import VanillaCF, CounterNet
Generate CF explanations for VanillaCF:
vanillacf = VanillaCF(dict(n_steps=1000, lr=0.001))
vanillacf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn,
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0))
)
100%|██████████| 1000/1000 [00:10<00:00, 92.93it/s]
assert vanillacf_exp.cf_name == vanillacf.name
assert vanillacf_exp.dataset_name == dm.data_name
assert vanillacf_exp.X.shape == vanillacf_exp.cfs.shape
assert vanillacf_exp.pred_fn(vanillacf_exp.X).shape == vanillacf_exp.y.shape
Generate CF explanations for CounterNet:
counternet = CounterNet()
counternet_exp = generate_cf_explanations(counternet, dm, pred_fn=None)
CounterNet contains parametric models. Starts training before generating explanations...
Epoch 99: 100%|██████████| 191/191 [00:03<00:00, 58.07batch/s, train/train_loss_1=0.0657, train/train_loss_2=0.000985, train/train_loss_3=0.0963]
Note that CounterNet contains a predictive module, so we set pred_fn=None.
assert counternet_exp.cf_name == counternet.name
assert counternet_exp.dataset_name == dm.data_name
assert counternet_exp.X.shape == counternet_exp.cfs.shape
assert counternet_exp.pred_fn(counternet_exp.X).shape == counternet_exp.y.shape
assert counternet_exp.pred_fn == counternet.pred_fn
If cf_module is a subclass of BasePredFnCFModule (e.g., CounterNet), the pred_fn in Explanation is set to cf_module.pred_fn, and the pred_fn argument passed to generate_cf_explanations is ignored.
counternet_exp_1 = generate_cf_explanations(counternet, dm, pred_fn=pred_fn)
assert counternet_exp_1.pred_fn != pred_fn
assert counternet_exp_1.pred_fn == counternet.pred_fn
CounterNet contains parametric models. Starts training before generating explanations...
Epoch 99: 100%|██████████| 191/191 [00:02<00:00, 64.02batch/s, train/train_loss_1=0.0713, train/train_loss_2=0.000427, train/train_loss_3=0.0944]
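The precedence rule for pred_fn described above can be sketched in plain Python. Everything in this snippet is a hypothetical stand-in: the `BasePredFnCFModule` class here only mimics the role of the library's base class, `resolve_pred_fn` is an invented helper name, and raising an error when no prediction function is available is an assumption of the sketch.

```python
class BasePredFnCFModule:
    """Stand-in for a CF module that ships its own predictive model."""
    def pred_fn(self, x):
        raise NotImplementedError

class ToyCounterNet(BasePredFnCFModule):
    def pred_fn(self, x):
        return [1 for _ in x]  # toy prediction: always label 1

class ToyVanillaCF:
    """Stand-in for a CF module without a built-in predictor."""

def resolve_pred_fn(cf_module, pred_fn=None):
    # A module with a built-in predictor wins; the user-supplied pred_fn is ignored.
    if isinstance(cf_module, BasePredFnCFModule):
        return cf_module.pred_fn
    if pred_fn is None:
        raise ValueError("pred_fn is required when cf_module has no predictor")
    return pred_fn

user_pred_fn = lambda x: [0 for _ in x]

from_module = resolve_pred_fn(ToyCounterNet(), user_pred_fn)
from_user = resolve_pred_fn(ToyVanillaCF(), user_pred_fn)
```

With these toy classes, `from_module` predicts with the module's own model even though `user_pred_fn` was supplied, while `from_user` is exactly the function that was passed in.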
Now, we can compute metrics for benchmarking different CF explanation methods.
evaluate_cfs(vanillacf_exp, return_df=True)[1]
| | | acc | validity | proximity |
|---|---|---|---|---|
| adult | VanillaCF | 0.822012 | 0.93674 | 7.62256 |
evaluate_cfs(counternet_exp, return_df=True)[1]
| | | acc | validity | proximity |
|---|---|---|---|---|
| adult | CounterNet | 0.831347 | 0.958605 | 5.9374576 |
benchmark_cfs([vanillacf_exp, counternet_exp])
| | | acc | validity | proximity |
|---|---|---|---|---|
| adult | VanillaCF | 0.822012 | 0.936740 | 7.62256 |
| adult | CounterNet | 0.831347 | 0.958605 | 5.9374576 |