Evaluate

Evaluating and benchmarking the quality of CF explanations.


EXPLANATION

CLASS relax.evaluate.Explanation (cf_name, data_module, cfs, total_time, pred_fn, dataset_name='', X=None, y=None)

Generated CF Explanations class.

Arguments to Explanation:

  • cf_name: name of the CF method
  • dataset_name: name of the dataset
  • X: input instances
  • y: labels
  • cfs: CF explanations generated for X
  • total_time: total runtime for generating the CFs
  • pred_fn: prediction function that takes a single input argument and outputs a label, i.e., y = pred_fn(x)
  • data_module: data module

Parallelism Strategy



BASEGENERATIONSTRATEGY

CLASS relax.evaluate.BaseGenerationStrategy ()

Base class for mapping strategies.



ITERATIVEGENERATIONSTRATEGY

CLASS relax.evaluate.IterativeGenerationStrategy ()

Iteratively generate counterfactuals.



VMAPGENERATIONSTRATEGY

CLASS relax.evaluate.VmapGenerationStrategy ()

Generate counterfactuals via jax.vmap.



PMAPGENERATIONSTRATEGY

CLASS relax.evaluate.PmapGenerationStrategy (n_devices=None, strategy='auto', **kwargs)

Generate counterfactuals via jax.pmap.

Parameters:
  • n_devices (int, default=None) – Number of devices. If None, use all available devices.
  • strategy (str, default='auto') – Strategy to generate counterfactuals
  • **kwargs
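To see why n_devices matters, here is a minimal NumPy sketch of the sharding idea behind a pmap-based strategy: the rows of X are padded so they split evenly into n_devices shards, each shard is processed independently, and the padding is dropped afterwards. The helper shard_and_map is hypothetical, and a Python loop stands in for the devices; relax's actual implementation may differ.

```python
import numpy as np

def shard_and_map(fn, X, n_devices):
    """Hypothetical helper: split X into n_devices equal shards (zero-padding if needed),
    apply the batched fn to each shard, then drop the padded rows.

    Under JAX, the per-shard loop would be replaced by jax.pmap over the leading axis.
    """
    n = X.shape[0]
    per_device = -(-n // n_devices)              # ceil(n / n_devices)
    pad = per_device * n_devices - n
    # Append zero rows so that every shard has exactly per_device rows
    X_padded = np.concatenate([X, np.zeros((pad,) + X.shape[1:], X.dtype)], axis=0)
    shards = X_padded.reshape((n_devices, per_device) + X.shape[1:])
    out = np.stack([fn(shard) for shard in shards])   # one "device" per shard
    out = out.reshape((n_devices * per_device,) + out.shape[2:])
    return out[:n]                               # remove the padding

# Usage: 10 rows do not divide evenly across 4 "devices", so 2 rows of padding are added
X = np.arange(20, dtype=float).reshape(10, 2)
w = np.array([[1.0], [2.0]])
out = shard_and_map(lambda batch: batch @ w, X, n_devices=4)
```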


BATCHEDVMAPGENERATIONSTRATEGY

CLASS relax.evaluate.BatchedVmapGenerationStrategy (batch_size)

Auto-batching for generating counterfactuals via jax.vmap.



BATCHEDPMAPGENERATIONSTRATEGY

CLASS relax.evaluate.BatchedPmapGenerationStrategy (batch_size, n_devices=None)

Auto-batching for generating counterfactuals via jax.pmap.

import os
# Must be set before JAX initializes: simulate 8 CPU devices so that pmap can be tested
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax.numpy as jnp
import jax.random as jrand
from jax import jit
from relax.evaluate import (
    IterativeGenerationStrategy, VmapGenerationStrategy, PmapGenerationStrategy,
    BatchedVmapGenerationStrategy, BatchedPmapGenerationStrategy
)

w = jrand.normal(jrand.PRNGKey(0), (100, 100))
X = jrand.normal(jrand.PRNGKey(0), (1000, 100))

@jit
def pred_fn(x): return jnp.dot(x, w.T)

# A stand-in "CF generator" that simply returns the prediction for x
def f(x, pred_fn=None, **kwargs):
    return pred_fn(x)

iter_gen = IterativeGenerationStrategy()
vmap_gen = VmapGenerationStrategy()
pmap_gen = PmapGenerationStrategy()
bvmap_gen = BatchedVmapGenerationStrategy(128)
bpmap_gen = BatchedPmapGenerationStrategy(128)
cf_iter = iter_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_vmap = vmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_pmap = pmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
cf_bvmap = bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
# check when batch_size > X.shape[0]
_bvmap_gen = BatchedVmapGenerationStrategy(1280)
_cf_bvmap = _bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_bvmap, _cf_bvmap, atol=1e-4)
cf_bpmap = bpmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
# All strategies should produce the same results (up to numerical tolerance)
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
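The batched strategies bound memory use by processing a fixed number of rows at a time. A minimal NumPy sketch of that idea follows. batched_map is a hypothetical helper, and padding the last batch to a full batch_size is an assumption; it keeps every call the same shape, which avoids retracing under jax.jit, and it also covers the batch_size > X.shape[0] case checked above.

```python
import numpy as np

def batched_map(fn, X, batch_size):
    """Hypothetical helper: apply a batched fn to X in chunks of batch_size rows.

    The last chunk is zero-padded to batch_size so every call sees the same shape;
    the padded rows are removed from the result.
    """
    n = X.shape[0]
    n_batches = -(-n // batch_size)              # ceil(n / batch_size)
    pad = n_batches * batch_size - n
    X_padded = np.concatenate([X, np.zeros((pad,) + X.shape[1:], X.dtype)], axis=0)
    outs = [fn(X_padded[i * batch_size:(i + 1) * batch_size]) for i in range(n_batches)]
    return np.concatenate(outs, axis=0)[:n]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 100))
w = rng.normal(size=(100, 100))

def f(batch):
    return batch @ w.T

out_128 = batched_map(f, X, 128)      # 8 batches, the last one padded with 24 rows
out_1280 = batched_map(f, X, 1280)    # batch_size > X.shape[0]: a single padded batch
```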


STRATEGYFACTORY

CLASS relax.evaluate.StrategyFactory ()

Factory class for parallelism strategies.
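A factory of this kind typically lets callers pass either a strategy name or a ready-made strategy object, which matches the strategy argument of generate_cf_explanations (str | BaseGenerationStrategy). The sketch below illustrates the pattern only; GenerationStrategy, IterativeStrategy, get_strategy, and the "iter" key are hypothetical names, not relax's API.

```python
class GenerationStrategy:
    """Hypothetical stand-in for BaseGenerationStrategy: maps fn over the rows of X."""
    def __call__(self, fn, X, **kwargs):
        raise NotImplementedError

class IterativeStrategy(GenerationStrategy):
    def __call__(self, fn, X, **kwargs):
        # One call per instance; a JAX version would stack the results into an array
        return [fn(x, **kwargs) for x in X]

_REGISTRY = {"iter": IterativeStrategy}   # hypothetical key names

def get_strategy(strategy):
    """Return a strategy instance from either a registered name or an existing instance."""
    if isinstance(strategy, GenerationStrategy):
        return strategy                    # already a strategy object: use it as-is
    if isinstance(strategy, str) and strategy in _REGISTRY:
        return _REGISTRY[strategy]()       # build a fresh instance from its name
    raise ValueError(f"Unknown strategy: {strategy!r}")

s1 = get_strategy("iter")
custom = IterativeStrategy()
s2 = get_strategy(custom)
doubled = s1(lambda x: x * 2, [1, 2, 3])
```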

Generating CF Explanation Results



GENERATE_CF_EXPLANATIONS

relax.evaluate.generate_cf_explanations (cf_module, datamodule, pred_fn=None, strategy='vmap', t_configs=None, pred_fn_args=None)

Generate CF explanations.

Parameters:
  • cf_module (BaseCFModule) – CF Explanation Module
  • datamodule (TabularDataModule) – Data Module
  • pred_fn (callable, default=None) – Predictive function
  • strategy (str | BaseGenerationStrategy, default=vmap) – Parallelism Strategy for generating CFs
  • t_configs (TrainingConfigs, default=None) – training configs for BaseParametricCFModule
  • pred_fn_args (dict, default=None) – auxiliary arguments for pred_fn
Returns:

    (Explanation)

The pred_fn in generate_cf_explanations is a model's prediction function with the general form y = pred_fn(x, **pred_fn_args). If pred_fn takes no arguments other than the input x, leave pred_fn_args as None (the default). Otherwise, pass the extra arguments as a dict via pred_fn_args.

For example, we have a simple linear function

import jax.numpy as jnp

def linear_pred_fn(x: jnp.DeviceArray, params: jnp.DeviceArray):
    return x @ params

To pass linear_pred_fn to generate_cf_explanations, we can either wrap it in an auxiliary function or pass params via pred_fn_args.

Assuming we now have the input x and params:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), shape=(5, 10)) # input
params = jnp.ones((10, 1)) # params
  1. Create an auxiliary function (not recommended)
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)
explanations = generate_cf_explanations(
    cf_module, datamodule, aux_linear_pred_fn
)

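This approach works, but the lambda looks up params only when it is called: if params is later reassigned (e.g., after retraining), explanations.pred_fn silently uses the new values rather than the ones used to generate the CFs.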
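The pitfall comes from Python's late-binding closures: the lambda resolves the name params when it is called, not when it is created. The sketch below shows the effect with the names from the example above, substituting NumPy arrays for JAX arrays.

```python
import numpy as np

def linear_pred_fn(x, params):
    return x @ params

params = np.ones((3, 1))
# The lambda looks up the global name `params` each time it is called (late binding)
aux_linear_pred_fn = lambda x: linear_pred_fn(x, params)

x = np.ones((2, 3))
before = aux_linear_pred_fn(x)   # uses the all-ones params -> [[3.], [3.]]
params = np.zeros((3, 1))        # rebind params, e.g. after retraining
after = aux_linear_pred_fn(x)    # silently uses the new params -> [[0.], [0.]]
```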

  2. Pass params into pred_fn_args
explanations = generate_cf_explanations(
    cf_module, datamodule, linear_pred_fn, 
    pred_fn_args=dict(params=params)
)

This is the recommended approach: generate_cf_explanations deep-copies params internally, so later changes to params do not affect explanations.pred_fn.
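A minimal sketch of why deep-copying helps: binding the extra arguments with functools.partial after copy.deepcopy gives the returned one-argument function its own private copy of params. bind_pred_fn is a hypothetical helper, not relax's implementation; NumPy is used because its arrays, unlike JAX arrays, can be modified in place.

```python
import copy
import functools
import numpy as np

def bind_pred_fn(pred_fn, pred_fn_args=None):
    """Hypothetical helper: return a one-argument pred_fn with a private copy of its extra args."""
    if pred_fn_args is None:
        return pred_fn
    frozen_args = copy.deepcopy(pred_fn_args)   # later changes to the caller's objects are not seen
    return functools.partial(pred_fn, **frozen_args)

def linear_pred_fn(x, params):
    return x @ params

params = np.ones((3, 1))
pred_fn = bind_pred_fn(linear_pred_fn, dict(params=params))
params[:] = 0.0            # in-place update of the caller's array
x = np.ones((2, 3))
y = pred_fn(x)             # still uses the copied all-ones params
```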

The pred_fn stored in explanations takes only x: jnp.DeviceArray as input. For example, to make predictions, we use

y = explanations.pred_fn(x)

Evaluating Metrics



BASEEVALMETRICS

CLASS relax.evaluate.BaseEvalMetrics (name=None)

Base evaluation metrics class.



PREDICTIVEACCURACY

CLASS relax.evaluate.PredictiveAccuracy (name='accuracy')

Compute the accuracy of the predict function.
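For a binary classifier whose pred_fn outputs probabilities, accuracy can be sketched as below. The 0.5 threshold and the helper name predictive_accuracy are assumptions, not relax's implementation.

```python
import numpy as np

def predictive_accuracy(y_prob, y_true, threshold=0.5):
    """Fraction of instances whose thresholded prediction matches the label (binary case)."""
    y_pred = (np.asarray(y_prob).reshape(-1) >= threshold).astype(int)
    return float(np.mean(y_pred == np.asarray(y_true).reshape(-1).astype(int)))

# Predictions [1, 0, 1, 0] vs labels [1, 0, 0, 0]: 3 of 4 correct
acc = predictive_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0])
```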



VALIDITY

CLASS relax.evaluate.Validity (name='validity')

Compute the fraction of input instances for which the CF explanation method outputs a valid CF example.

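import jax.numpy as jnp
from relax.evaluate import _compute_proximity  # private helper; assumed importable from relax.evaluate

# Sanity checks for the L1-based proximity helper _compute_proximity
inputs = jnp.array([
    [0, 1], [1, 0], [1, 1]])
cfs_1 = jnp.array([
    [0, 0], [0, 0], [0, 1]])
cfs_2 = jnp.array([
    [1, 0], [1, -2], [0, 0]])
assert _compute_proximity(inputs, cfs_1) == 1.0
assert _compute_proximity(inputs, cfs_2) == 2.0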
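For binary classification, a CF is commonly counted as valid when the model's predicted class for the CF differs from its predicted class for the original input. A NumPy sketch of that definition follows; the helper validity, the 0.5 threshold, and the toy logistic model are assumptions for illustration, not relax's implementation.

```python
import numpy as np

def validity(pred_fn, X, cfs, threshold=0.5):
    """Fraction of instances whose CF flips the binary prediction of the input."""
    y_x = np.asarray(pred_fn(X)).reshape(-1) >= threshold
    y_cf = np.asarray(pred_fn(cfs)).reshape(-1) >= threshold
    return float(np.mean(y_x != y_cf))

# Hypothetical logistic model
w = np.array([1.0, 1.0])
pred_fn = lambda x: 1 / (1 + np.exp(-(x @ w)))

X = np.array([[1.0, 1.0], [-1.0, -1.0], [2.0, 0.0]])
cfs = np.array([[-1.0, -1.0], [1.0, 1.0], [1.0, 0.0]])
v = validity(pred_fn, X, cfs)   # the first two CFs flip the prediction, the third does not
```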


PROXIMITY

CLASS relax.evaluate.Proximity (name='proximity')

Compute L1 norm distance between input datasets and CF examples divided by the number of features.
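A NumPy sketch of the proximity computation. With per_feature=False it reproduces the sanity-check values asserted in the Validity section above (1.0 and 2.0); per_feature=True additionally divides by the number of features, as the description above states. The helper and its flag are hypothetical.

```python
import numpy as np

def proximity(X, cfs, per_feature=False):
    """Mean L1 distance between each input and its CF, optionally divided by the number of features."""
    X = np.asarray(X)
    dist = np.abs(X - np.asarray(cfs)).sum(axis=1)   # L1 distance per instance
    if per_feature:
        dist = dist / X.shape[1]
    return float(dist.mean())

inputs = np.array([[0, 1], [1, 0], [1, 1]])
cfs_1 = np.array([[0, 0], [0, 0], [0, 1]])
cfs_2 = np.array([[1, 0], [1, -2], [0, 0]])
```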



SPARSITY

CLASS relax.evaluate.Sparsity (name='sparsity')

Compute the number of feature changes between input datasets and CF examples.
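A NumPy sketch that counts, for each instance, the features whose value changed, then averages over instances. The helper name, the tolerance atol, and averaging (rather than summing) over instances are assumptions; relax may also treat one-hot encoded categorical features differently.

```python
import numpy as np

def sparsity(X, cfs, atol=1e-8):
    """Mean number of features that differ between each input and its CF."""
    changed = ~np.isclose(np.asarray(X, float), np.asarray(cfs, float), atol=atol)
    return float(changed.sum(axis=1).mean())

inputs = np.array([[0, 1], [1, 0], [1, 1]])
cfs_1 = np.array([[0, 0], [0, 0], [0, 1]])    # one changed feature per instance
cfs_2 = np.array([[1, 0], [1, -2], [0, 0]])   # 2, 1 and 2 changed features
```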



MANIFOLDDIST

CLASS relax.evaluate.ManifoldDist (n_neighbors=1, p=2, name='manifold_dist')

Compute the L1 distance to the n-nearest neighbor for all CF examples.
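ManifoldDist measures how close the CFs lie to the data through nearest neighbors. The brute-force NumPy sketch below assumes that p is the order of the Minkowski distance (p=1 gives the L1 distance mentioned above, p=2 the Euclidean distance), that the reference points are the training inputs, and that distances are averaged over neighbors and CFs; manifold_dist is a hypothetical helper.

```python
import numpy as np

def manifold_dist(X_ref, cfs, n_neighbors=1, p=2):
    """Mean Minkowski-p distance from each CF to its n_neighbors nearest reference points."""
    X_ref = np.asarray(X_ref, float)
    cfs = np.asarray(cfs, float)
    # Pairwise distances, shape (n_cfs, n_ref)
    diff = np.abs(cfs[:, None, :] - X_ref[None, :, :])
    dists = (diff ** p).sum(axis=-1) ** (1.0 / p)
    nearest = np.sort(dists, axis=1)[:, :n_neighbors]   # k smallest distances per CF
    return float(nearest.mean())

X_ref = np.array([[0.0, 0.0], [3.0, 4.0]])
cfs = np.array([[0.0, 0.0], [3.0, 0.0]])
d1 = manifold_dist(X_ref, cfs)                  # nearest distances 0 and 3
d2 = manifold_dist(X_ref, cfs, n_neighbors=2)   # distances (0, 5) and (3, 4)
```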



RUNTIME

CLASS relax.evaluate.Runtime (name='runtime')

Get the running time to generate CF examples.
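Explanation records the generation time in total_time. A minimal sketch of how such a runtime could be measured is shown below; timed is a hypothetical helper, and whether relax measures the time exactly this way is not confirmed here. The comment about block_until_ready reflects JAX's asynchronous dispatch.

```python
import time

def timed(generate_fn, *args, **kwargs):
    """Hypothetical helper: run generate_fn and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = generate_fn(*args, **kwargs)
    # With JAX, call result.block_until_ready() here: dispatch is asynchronous,
    # so without it the timer may stop before the computation has finished.
    elapsed = time.perf_counter() - start
    return result, elapsed

result, total_time = timed(sorted, [3, 1, 2])
```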



COMPUTE_SO_SPARSITY

relax.evaluate.compute_so_sparsity (cf_results, threshold=2.0)



COMPUTE_SO_PROXIMITY

relax.evaluate.compute_so_proximity (cf_results, threshold=2.0)



COMPUTE_SO_VALIDITY

relax.evaluate.compute_so_validity (cf_results, threshold=2.0)

Benchmarking



EVALUATE_CFS

relax.evaluate.evaluate_cfs (cf_exp, metrics=None, return_dict=True, return_df=False)

Parameters:
  • cf_exp (Explanation) – CF Explanations
  • metrics (Iterable[Union[str, BaseEvalMetrics]], default=None) – A list of Metrics. Can be str or a subclass of BaseEvalMetrics
  • return_dict (bool, default=True) – Whether to return a dictionary
  • return_df (bool, default=False) – Whether to return a pandas DataFrame


BENCHMARK_CFS

relax.evaluate.benchmark_cfs (cf_results_list, metrics=None)
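Conceptually, benchmarking applies each metric to each Explanation and indexes the results by dataset and method, like the (dataset, method) index in the result tables at the end of this page. The sketch below uses plain dicts; evaluate, benchmark, and the SimpleNamespace stand-ins for Explanation (with made-up method names and runtimes) are hypothetical.

```python
from types import SimpleNamespace

def evaluate(cf_exp, metrics):
    """Hypothetical: apply each metric (name -> callable) to one explanation."""
    return {name: metric(cf_exp) for name, metric in metrics.items()}

def benchmark(cf_exps, metrics):
    """Hypothetical: one row of metric values per (dataset_name, cf_name) pair."""
    return {(e.dataset_name, e.cf_name): evaluate(e, metrics) for e in cf_exps}

# Toy stand-ins for Explanation objects
exp_a = SimpleNamespace(dataset_name="adult", cf_name="MethodA", total_time=1.5)
exp_b = SimpleNamespace(dataset_name="adult", cf_name="MethodB", total_time=0.5)
table = benchmark([exp_a, exp_b], {"runtime": lambda e: e.total_time})
```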

How to evaluate a CF Explanation Module

from jax import random
from relax.data import TabularDataModule
from relax.evaluate import generate_cf_explanations, evaluate_cfs, benchmark_cfs
from relax.module import PredictiveTrainingModule
from relax.trainer import train_model
from relax.utils import load_json

configs = load_json('assets/configs/data_configs/adult.json')
m_configs = configs['mlp_configs']
data_configs = configs['data_configs']
data_configs['sample_frac'] = 0.1

t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'seed': 42,
    'batch_size': 256
}

We first train a predictive model:

training_module = PredictiveTrainingModule(m_configs)
dm = TabularDataModule(data_configs)

params, opt_state = train_model(
    training_module, 
    dm, 
    t_configs
)
pred_fn = lambda x, params, prng_key: \
    training_module.forward(params, prng_key, x, is_training=False)

Now we can benchmark different methods:

from relax.methods import VanillaCF, CounterNet

Generate CF explanations for VanillaCF

vanillacf = VanillaCF(dict(n_steps=1000, lr=0.001))
vanillacf_exp = generate_cf_explanations(
    vanillacf, dm, pred_fn,
    pred_fn_args=dict(params=params, prng_key=random.PRNGKey(0))
)
assert vanillacf_exp.cf_name == vanillacf.name
assert vanillacf_exp.dataset_name == dm.data_name
assert vanillacf_exp.X.shape == vanillacf_exp.cfs.shape
assert vanillacf_exp.pred_fn(vanillacf_exp.X).shape == vanillacf_exp.y.shape

Generate CF explanations for CounterNet

counternet = CounterNet()
counternet_exp = generate_cf_explanations(counternet, dm, pred_fn=None)
CounterNet contains parametric models. Starts training before generating explanations...

Note that CounterNet contains its own predictive module, so we set pred_fn=None.

assert counternet_exp.cf_name == counternet.name
assert counternet_exp.dataset_name == dm.data_name
assert counternet_exp.X.shape == counternet_exp.cfs.shape
assert counternet_exp.pred_fn(counternet_exp.X).shape == counternet_exp.y.shape
assert counternet_exp.pred_fn == counternet.pred_fn

If cf_module is a subclass of BasePredFnCFModule (e.g., CounterNet), the pred_fn in the Explanation is set to cf_module.pred_fn, and any pred_fn argument passed to generate_cf_explanations is ignored.

counternet_exp_1 = generate_cf_explanations(counternet, dm, pred_fn=pred_fn)
assert counternet_exp_1.pred_fn != pred_fn
assert counternet_exp_1.pred_fn == counternet.pred_fn
CounterNet contains parametric models. Starts training before generating explanations...

Now, we can compute metrics for benchmarking different CF explanation methods.

evaluate_cfs(vanillacf_exp, return_df=True)[1]
                 acc       validity  proximity
adult VanillaCF  0.822012  0.93674   7.62256
evaluate_cfs(counternet_exp, return_df=True)[1]
                  acc       validity  proximity
adult CounterNet  0.831347  0.958605  5.9374576
benchmark_cfs([vanillacf_exp, counternet_exp])
                  acc       validity  proximity
adult VanillaCF   0.822012  0.936740  7.62256
      CounterNet  0.831347  0.958605  5.9374576