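import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # expose 8 CPU devices for pmap; set before importing jax
import jax.numpy as jnp
import jax.random as jrand
from jax import jit
from relax.evaluate import (
    IterativeGenerationStrategy, VmapGenerationStrategy, PmapGenerationStrategy,
    BatchedVmapGenerationStrategy, BatchedPmapGenerationStrategy,
)
w = jrand.normal(jrand.PRNGKey(0), (100, 100))
X = jrand.normal(jrand.PRNGKey(0), (1000, 100))
@jit
def pred_fn(x):
    return jnp.dot(x, w.T)
def f(x, pred_fn=None, **kwargs):
    return pred_fn(x)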
iter_gen = IterativeGenerationStrategy()
vmap_gen = VmapGenerationStrategy()
pmap_gen = PmapGenerationStrategy()
bvmap_gen = BatchedVmapGenerationStrategy(128)
bpmap_gen = BatchedPmapGenerationStrategy(128)
Evaluate
EXPLANATION
CLASS relax.evaluate.Explanation (cf_name, data_module, cfs, total_time, pred_fn, dataset_name='', X=None, y=None)
Container for generated CF explanations.
Arguments to Explanation:
- cf_name: name of the CF method
- dataset_name: dataset name
- X: input
- y: label
- cfs: generated CF explanations of X
- total_time: total runtime
- pred_fn: prediction function that takes a single input argument and outputs a label (i.e., its format is y=pred_fn(x))
- data_module: data module
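To make the fields concrete, here is a minimal stand-in for the container, assuming plain Python lists for X, y and cfs. `ExplanationSketch` and `toy_pred_fn` are hypothetical names, not the relax class; the real Explanation may store JAX arrays and extra state.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

@dataclass
class ExplanationSketch:
    """Hypothetical stand-in for relax.evaluate.Explanation (not the real class)."""
    cf_name: str                      # name of the CF method
    data_module: Any                  # data module the CFs were generated for
    cfs: List[List[float]]            # one CF example per row of X
    total_time: float                 # total runtime in seconds
    pred_fn: Callable[[List[List[float]]], List[int]]  # y = pred_fn(x)
    dataset_name: str = ''
    X: Optional[List[List[float]]] = None
    y: Optional[List[int]] = None

# A toy one-argument predictor: label 1 if the feature sum is positive.
toy_pred_fn = lambda xs: [int(sum(row) > 0) for row in xs]

exp = ExplanationSketch(
    cf_name='toy_cf', data_module=None,
    cfs=[[1.0, 0.5], [-1.0, -0.5]], total_time=0.01,
    pred_fn=toy_pred_fn, dataset_name='toy',
    X=[[-1.0, 0.5], [1.0, 0.5]], y=[0, 1],
)
print(exp.pred_fn(exp.cfs))  # [1, 0]
```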
Parallelism Strategy
BASEGENERATIONSTRATEGY
CLASS relax.evaluate.BaseGenerationStrategy ()
Base class for mapping strategy.
ITERATIVEGENERATIONSTRATEGY
CLASS relax.evaluate.IterativeGenerationStrategy ()
Iteratively generate counterfactuals.
VMAPGENERATIONSTRATEGY
CLASS relax.evaluate.VmapGenerationStrategy ()
Generate counterfactuals via jax.vmap.
PMAPGENERATIONSTRATEGY
CLASS relax.evaluate.PmapGenerationStrategy (n_devices=None, strategy='auto', **kwargs)
Generate counterfactuals via jax.pmap.
BATCHEDVMAPGENERATIONSTRATEGY
CLASS relax.evaluate.BatchedVmapGenerationStrategy (batch_size)
Auto-batching for generating counterfactuals via jax.vmap.
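The idea behind auto-batching can be sketched in plain Python: split the inputs into chunks of at most batch_size, pad the last chunk so every call sees the same batch size (a plausible reason being that a compiled, vmapped function prefers fixed shapes), map over each chunk, and drop the padded results. `batched_map` is a hypothetical helper, and padding by repeating the last element is an assumption; the library's padding scheme may differ.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar('T')
R = TypeVar('R')

def batched_map(fn: Callable[[T], R], xs: Sequence[T], batch_size: int) -> List[R]:
    """Apply fn to every element of xs, batch_size elements at a time (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    n = len(xs)
    if n == 0:
        return []
    # Cap the batch size at the number of inputs so batch_size > len(xs) still works.
    batch_size = min(batch_size, n)
    # Pad to a multiple of batch_size by repeating the last element.
    n_pad = (-n) % batch_size
    padded = list(xs) + [xs[-1]] * n_pad
    out: List[R] = []
    for start in range(0, len(padded), batch_size):
        batch = padded[start:start + batch_size]
        # In JAX, this inner loop would be a single jax.vmap(fn)(batch) call.
        out.extend(fn(x) for x in batch)
    # Drop the results computed on padding elements.
    return out[:n]

print(batched_map(lambda x: x * 2, list(range(5)), batch_size=2))  # [0, 2, 4, 6, 8]
```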
BATCHEDPMAPGENERATIONSTRATEGY
CLASS relax.evaluate.BatchedPmapGenerationStrategy (batch_size, n_devices=None)
Auto-batching for generating counterfactuals via jax.pmap.
cf_iter = iter_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_vmap = vmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_pmap = pmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_bvmap = bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
# check when batch_size > X.shape[0]
_bvmap_gen = BatchedVmapGenerationStrategy(1280)
_cf_bvmap = _bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_bvmap, _cf_bvmap, atol=1e-4)
cf_bpmap = bpmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
STRATEGYFACTORY
CLASS relax.evaluate.StrategyFactory ()
Factory class for Parallelism Strategy.
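A factory like this typically maps a strategy name to a strategy object. The sketch below is hypothetical: only 'vmap' is confirmed on this page as a strategy string (it is the default of generate_cf_explanations); the 'iter' key, the `get_strategy` method, and the stub classes are assumptions made for illustration.

```python
from typing import Callable, Dict, List, Union

class BaseStrategyStub:
    """Hypothetical base: map fn over a list of inputs."""
    def __call__(self, fn: Callable, xs: List) -> List:
        raise NotImplementedError

class IterativeStub(BaseStrategyStub):
    def __call__(self, fn, xs):
        # Apply fn to one input at a time.
        return [fn(x) for x in xs]

class VmapStub(BaseStrategyStub):
    def __call__(self, fn, xs):
        # A JAX implementation would return jax.vmap(fn)(xs); a list comprehension stands in.
        return [fn(x) for x in xs]

class StrategyFactorySketch:
    # Assumed registry: only 'vmap' is a documented strategy string on this page.
    _registry: Dict[str, type] = {'iter': IterativeStub, 'vmap': VmapStub}

    @classmethod
    def get_strategy(cls, strategy: Union[str, BaseStrategyStub]) -> BaseStrategyStub:
        if isinstance(strategy, BaseStrategyStub):
            return strategy  # already a strategy object; use it as-is
        if strategy not in cls._registry:
            raise ValueError(f'Unknown strategy {strategy!r}; expected one of {sorted(cls._registry)}')
        return cls._registry[strategy]()

gen = StrategyFactorySketch.get_strategy('vmap')
print(type(gen).__name__, gen(lambda x: x + 1, [1, 2, 3]))  # VmapStub [2, 3, 4]
```

Accepting either a name or a ready-made strategy object lets callers pass a configured instance, such as a batched strategy with a custom batch size, through the same entry point.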
Generating CF Explanation Results
GENERATE_CF_EXPLANATIONS
relax.evaluate.generate_cf_explanations (cf_module, datamodule, pred_fn=None, strategy='vmap', t_configs=None, pred_fn_args=None)
Generate CF explanations.
The pred_fn in generate_cf_explanations is a model's prediction function. Its general format is y = pred_fn(x, **pred_fn_args). If pred_fn is not parameterized by any variables other than the input x, pred_fn_args should be None, which is the default. Otherwise, pass these arguments as a dict via pred_fn_args.
For example, consider a simple linear function:
def linear_pred_fn(x: jnp.DeviceArray, params: jnp.DeviceArray):
    return x @ params
To pass linear_pred_fn to generate_cf_explanations, we can either create an auxiliary function that wraps linear_pred_fn, or pass params via pred_fn_args.
Assume we have the following input x and params:
x = jax.random.normal(jax.random.PRNGKey(0), shape=(5, 10))  # input
params = jnp.ones((10, 1))  # params
- Create an auxiliary function (not recommended)
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)
explanations = generate_cf_explanations(
    cf_module, datamodule, aux_linear_pred_fn
)
This approach works, but because the lambda looks up params only when it is called, explanations.pred_fn may not behave as expected if params is changed later.
- Pass params into pred_fn_args (recommended)
explanations = generate_cf_explanations(
    cf_module, datamodule, linear_pred_fn,
    pred_fn_args=dict(params=params)
)
This is the recommended approach, since generate_cf_explanations deep-copies params internally.
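The difference between the two approaches comes down to Python's late binding of closure variables. In this plain-Python sketch (lists stand in for JAX arrays, and the snapshot only imitates the deep copy described above), the lambda reads params at call time, whereas a copy taken up front is unaffected when params is later rebound.

```python
import copy

def linear_pred_fn(x, params):
    # Dot product of one input row with a weight vector.
    return sum(xi * wi for xi, wi in zip(x, params))

x = [1.0, 2.0]
params = [1.0, 1.0]

# Approach 1: auxiliary function; the lambda reads the variable `params` when called.
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)

# Approach 2: snapshot params up front, as a deep copy inside the generator would.
frozen = copy.deepcopy(params)
bound_pred_fn = lambda x: linear_pred_fn(x, frozen)

before = aux_linear_pred_fn(x)       # 3.0
params = [10.0, 10.0]                # params changes after the explanations were generated
after_aux = aux_linear_pred_fn(x)    # 30.0: predictions silently changed
after_bound = bound_pred_fn(x)       # 3.0: unchanged
print(before, after_aux, after_bound)
```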
The pred_fn stored in explanations takes only x: jnp.DeviceArray as input. For example, to make predictions, we use:
y = explanations.pred_fn(x)
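Turning y = pred_fn(x, **pred_fn_args) into a function of x alone can be sketched with functools.partial. `bind_pred_fn` is a hypothetical helper that imitates the behavior described on this page (deep copy, then bind); it is not the library's code, and lists stand in for arrays.

```python
import copy
import functools
from typing import Any, Callable, Dict, Optional

def bind_pred_fn(pred_fn: Callable, pred_fn_args: Optional[Dict[str, Any]] = None) -> Callable:
    """Return a prediction function of x only (hypothetical helper)."""
    if pred_fn_args is None:
        # pred_fn already takes only x.
        return pred_fn
    # Deep-copy so later changes to the caller's objects do not leak in.
    frozen_args = copy.deepcopy(pred_fn_args)
    return functools.partial(pred_fn, **frozen_args)

def linear_pred_fn(x, params):
    # Dot product of a single input row with a weight vector.
    return sum(xi * wi for xi, wi in zip(x, params))

one_arg_pred_fn = bind_pred_fn(linear_pred_fn, dict(params=[1.0, 1.0, 1.0]))
print(one_arg_pred_fn([1.0, 2.0, 3.0]))  # 6.0
```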
Evaluating Metrics
BASEEVALMETRICS
CLASS relax.evaluate.BaseEvalMetrics (name=None)
Base evaluation metrics class.
PREDICTIVEACCURACY
CLASS relax.evaluate.PredictiveAccuracy (name='accuracy')
Compute the accuracy of the prediction function.
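As a rough sketch of what predictive accuracy measures, assuming binary labels and a pred_fn that returns probabilities thresholded at 0.5 (an assumption; the library may round or take an argmax instead). `accuracy_sketch` is a hypothetical helper.

```python
from typing import Callable, List, Sequence

def accuracy_sketch(pred_fn: Callable[[Sequence[List[float]]], List[float]],
                    X: Sequence[List[float]], y: Sequence[int]) -> float:
    """Fraction of inputs whose thresholded prediction matches the label (hypothetical helper)."""
    probs = pred_fn(X)
    y_hat = [int(p >= 0.5) for p in probs]
    return sum(int(a == b) for a, b in zip(y_hat, y)) / len(y)

pred_fn = lambda xs: [0.9 if row[0] > 0 else 0.1 for row in xs]
X = [[1.0], [-1.0], [2.0], [-3.0]]
y = [1, 0, 0, 0]
print(accuracy_sketch(pred_fn, X, y))  # 0.75
```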
VALIDITY
CLASS relax.evaluate.Validity (name='validity')
Compute fraction of input instances on which CF explanation methods output valid CF examples.
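A common definition, assumed here, is that a CF example is valid when its predicted class differs from the predicted class of its input. A plain-Python sketch with a hypothetical `validity_sketch` helper and a 0.5 threshold (also an assumption):

```python
def validity_sketch(pred_fn, X, cfs, threshold=0.5):
    """Fraction of rows whose CF flips the predicted class (hypothetical helper)."""
    y_x = [int(p >= threshold) for p in pred_fn(X)]
    y_cf = [int(p >= threshold) for p in pred_fn(cfs)]
    return sum(int(a != b) for a, b in zip(y_x, y_cf)) / len(X)

pred_fn = lambda xs: [0.9 if sum(row) > 0 else 0.1 for row in xs]
X = [[1.0, 0.0], [-1.0, 0.0], [2.0, 1.0], [-2.0, -1.0]]
cfs = [[-1.0, 0.0], [1.0, 0.5], [1.0, 1.0], [-1.0, -1.0]]
print(validity_sketch(pred_fn, X, cfs))  # 0.5
```

Here the first two CFs flip the prediction and the last two do not, so half of the rows are valid.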
inputs = jnp.array([[0, 1], [1, 0], [1, 1]])
cfs_1 = jnp.array([[0, 0], [0, 0], [0, 1]])
cfs_2 = jnp.array([[1, 0], [1, -2], [0, 0]])
assert _compute_proximity(inputs, cfs_1) == 1.0
assert _compute_proximity(inputs, cfs_2) == 2.0
PROXIMITY
CLASS relax.evaluate.Proximity (name='proximity')
Compute the L1 distance between input instances and CF examples, averaged over all instances.
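The assertions on _compute_proximity earlier in this section pin down the computation: the per-row L1 distance, averaged over rows, gives 1.0 and 2.0 for the two example sets. A plain-Python reimplementation (the hypothetical `proximity_sketch` mirrors _compute_proximity only as far as those assertions show):

```python
def proximity_sketch(xs, cfs):
    """Mean over rows of the L1 distance between each input and its CF (hypothetical helper)."""
    dists = [sum(abs(a - b) for a, b in zip(x, cf)) for x, cf in zip(xs, cfs)]
    return sum(dists) / len(dists)

inputs = [[0, 1], [1, 0], [1, 1]]
cfs_1 = [[0, 0], [0, 0], [0, 1]]
cfs_2 = [[1, 0], [1, -2], [0, 0]]
print(proximity_sketch(inputs, cfs_1), proximity_sketch(inputs, cfs_2))  # 1.0 2.0
```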
SPARSITY
CLASS relax.evaluate.Sparsity (name='sparsity')
Compute the number of feature changes between input datasets and CF examples.
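A sketch of counting feature changes, assuming the count is averaged over instances and that two values are treated as equal when they match within a small tolerance (both assumptions; the library may sum instead, or compare exactly). `sparsity_sketch` is a hypothetical helper.

```python
def sparsity_sketch(xs, cfs, atol=1e-6):
    """Average number of features that differ between each input and its CF (hypothetical helper)."""
    counts = [sum(int(abs(a - b) > atol) for a, b in zip(x, cf)) for x, cf in zip(xs, cfs)]
    return sum(counts) / len(counts)

inputs = [[0, 1], [1, 0], [1, 1]]
cfs = [[1, 0], [1, -2], [0, 0]]
# Per-row changes are 2, 1 and 2, so the average is 5/3.
print(sparsity_sketch(inputs, cfs))
```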
MANIFOLDDIST
CLASS relax.evaluate.ManifoldDist (n_neighbors=1, p=2, name='manifold_dist')
Compute the L1 distance to the n-nearest neighbor for all CF examples.
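Following the description above (L1 distance to the nearest neighbors in the data), here is a brute-force plain-Python sketch. Averaging over the n_neighbors nearest points and then over CFs is an assumption, and the p argument of the real class is not modeled; `manifold_dist_sketch` is a hypothetical helper.

```python
def manifold_dist_sketch(train_xs, cfs, n_neighbors=1):
    """For each CF, average L1 distance to its n_neighbors nearest training points,
    then average over CFs (hypothetical helper; brute-force search)."""
    def l1(a, b):
        return sum(abs(ai - bi) for ai, bi in zip(a, b))
    per_cf = []
    for cf in cfs:
        nearest = sorted(l1(cf, x) for x in train_xs)[:n_neighbors]
        per_cf.append(sum(nearest) / len(nearest))
    return sum(per_cf) / len(per_cf)

train_xs = [[0.0, 0.0], [1.0, 1.0], [3.0, 0.0]]
cfs = [[0.5, 0.0], [3.0, 1.0]]
print(manifold_dist_sketch(train_xs, cfs))  # 0.75
```

A real implementation would use a spatial index (for example a k-d tree) rather than sorting every distance for every CF.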
RUNTIME
CLASS relax.evaluate.Runtime (name='runtime')
Get the running time to generate CF examples.
COMPUTE_SO_SPARSITY
relax.evaluate.compute_so_sparsity (cf_results, threshold=2.0)
COMPUTE_SO_PROXIMITY
relax.evaluate.compute_so_proximity (cf_results, threshold=2.0)
COMPUTE_SO_VALIDITY
relax.evaluate.compute_so_validity (cf_results, threshold=2.0)
Benchmarking
EVALUATE_CFS
relax.evaluate.evaluate_cfs (cf_exp, metrics=None, return_dict=True, return_df=False)
BENCHMARK_CFS
relax.evaluate.benchmark_cfs (cf_results_list, metrics=None)
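Judging from the result tables later on this page, both functions produce one row per (dataset, method) pair with one column per metric. A hypothetical plain-Python sketch of that aggregation, in which each metric is a (name, callable) pair applied to an explanation-like object; `evaluate_sketch`, `benchmark_sketch`, the toy metrics and the toy values are illustrative only.

```python
from types import SimpleNamespace

def evaluate_sketch(cf_exp, metrics):
    """One result row keyed by (dataset, method) (hypothetical helper)."""
    return {(cf_exp.dataset_name, cf_exp.cf_name): {name: fn(cf_exp) for name, fn in metrics}}

def benchmark_sketch(cf_exps, metrics):
    """Merge the rows of several explanations into one table-like dict (hypothetical helper)."""
    table = {}
    for exp in cf_exps:
        table.update(evaluate_sketch(exp, metrics))
    return table

metrics = [('runtime', lambda e: e.total_time), ('n_cfs', lambda e: len(e.cfs))]
exps = [
    SimpleNamespace(dataset_name='toy', cf_name='MethodA', total_time=1.5, cfs=[[0.0]] * 3),
    SimpleNamespace(dataset_name='toy', cf_name='MethodB', total_time=0.2, cfs=[[0.0]] * 3),
]
for key, row in benchmark_sketch(exps, metrics).items():
    print(key, row)
```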
How to evaluate a CF Explanation Module
from relax.data import TabularDataModule
from relax.module import PredictiveTrainingModule
from relax.trainer import train_model
from relax.utils import load_json
configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']
data_configs['sample_frac'] = 0.1
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    'batch_size': 256
}
We first train a model
training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)
params, opt_state = train_model(
    training_module,
    dm,
    t_configs
)
pred_fn = lambda x, params, prng_key: \
    training_module.forward(params, prng_key, x, is_training=False)
Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 53.81batch/s, train/train_loss_1=0.0791]
Now, we can start to benchmark different methods
from relax.methods import VanillaCF, CounterNet
Generate CF explanations for VanillaCF
vanillacf = VanillaCF(dict(n_steps=1000, lr=0.001))
vanillacf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn,
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0))
)
100%|██████████| 1000/1000 [00:10<00:00, 92.93it/s]
assert vanillacf_exp.cf_name == vanillacf.name
assert vanillacf_exp.dataset_name == dm.data_name
assert vanillacf_exp.X.shape == vanillacf_exp.cfs.shape
assert vanillacf_exp.pred_fn(vanillacf_exp.X).shape == vanillacf_exp.y.shape
Generate CF explanations for CounterNet
counternet = CounterNet()
counternet_exp = generate_cf_explanations(counternet, dm, pred_fn=None)
CounterNet contains parametric models. Starts training before generating explanations...
Epoch 99: 100%|██████████| 191/191 [00:03<00:00, 58.07batch/s, train/train_loss_1=0.0657, train/train_loss_2=0.000985, train/train_loss_3=0.0963]
Note that CounterNet contains a predictive module, so we set pred_fn=None.
assert counternet_exp.cf_name == counternet.name
assert counternet_exp.dataset_name == dm.data_name
assert counternet_exp.X.shape == counternet_exp.cfs.shape
assert counternet_exp.pred_fn(counternet_exp.X).shape == counternet_exp.y.shape
assert counternet_exp.pred_fn == counternet.pred_fn
If cf_module is a subclass of BasePredFnCFModule (e.g., CounterNet), the pred_fn in Explanation is set to cf_module.pred_fn, and the pred_fn argument passed to generate_cf_explanations is ignored.
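The selection rule can be expressed as a small sketch. `BaseCFModuleStub`, `BasePredFnCFModuleStub`, `CounterNetLike` and `select_pred_fn` are hypothetical stand-ins for the classes named above, and raising an error when no pred_fn is available is an assumption rather than documented behavior.

```python
class BaseCFModuleStub:
    """Hypothetical stand-in for a CF module without its own predictor."""

class BasePredFnCFModuleStub(BaseCFModuleStub):
    """Hypothetical stand-in for a CF module that ships its own pred_fn."""
    def pred_fn(self, x):
        raise NotImplementedError

class CounterNetLike(BasePredFnCFModuleStub):
    def pred_fn(self, x):
        # Toy built-in predictor: label 1 for positive values.
        return [int(v > 0) for v in x]

def select_pred_fn(cf_module, pred_fn=None):
    """Prefer the module's own pred_fn; otherwise fall back to the user's (hypothetical helper)."""
    if isinstance(cf_module, BasePredFnCFModuleStub):
        return cf_module.pred_fn  # the user-supplied pred_fn is ignored
    if pred_fn is None:
        # Assumption: a module without a built-in predictor needs an explicit pred_fn.
        raise ValueError('pred_fn is required for modules without a built-in predictor')
    return pred_fn

user_pred_fn = lambda x: [0 for _ in x]
module = CounterNetLike()
chosen = select_pred_fn(module, user_pred_fn)
print(chosen([-1.0, 2.0]))  # [0, 1]
```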
counternet_exp_1 = generate_cf_explanations(counternet, dm, pred_fn=pred_fn)
assert counternet_exp_1.pred_fn != pred_fn
assert counternet_exp_1.pred_fn == counternet.pred_fn
CounterNet contains parametric models. Starts training before generating explanations...
Epoch 99: 100%|██████████| 191/191 [00:02<00:00, 64.02batch/s, train/train_loss_1=0.0713, train/train_loss_2=0.000427, train/train_loss_3=0.0944]
Now, we can compute metrics for benchmarking different CF explanation methods.
evaluate_cfs(vanillacf_exp, return_df=True)[1]
dataset | method | acc | validity | proximity |
---|---|---|---|---|
adult | VanillaCF | 0.822012 | 0.93674 | 7.62256 |
evaluate_cfs(counternet_exp, return_df=True)[1]
dataset | method | acc | validity | proximity |
---|---|---|---|---|
adult | CounterNet | 0.831347 | 0.958605 | 5.9374576 |
benchmark_cfs([vanillacf_exp, counternet_exp])
dataset | method | acc | validity | proximity |
---|---|---|---|---|
adult | VanillaCF | 0.822012 | 0.936740 | 7.62256 |
adult | CounterNet | 0.831347 | 0.958605 | 5.9374576 |