# ReLax
Overview | Installation | Tutorials | Documentation | Citing ReLax
## Overview
`ReLax` (Recourse Explanation Library in Jax) is an efficient and scalable benchmarking library for recourse and counterfactual explanations, built on top of jax. By leveraging language primitives such as vectorization, parallelization, and just-in-time compilation in jax, `ReLax` offers massive speed improvements in generating individual (or local) explanations for predictions made by Machine Learning algorithms.
Some of the key features are as follows:
- Fast and scalable recourse generation.
- Accelerated over `cpu`, `gpu`, and `tpu`.
- Comprehensive set of recourse methods implemented for benchmarking.
- Customizable API to enable the building of entire modeling and interpretation pipelines for new recourse algorithms.
## Installation
```bash
pip install jax-relax
# Or install the latest version of `jax-relax`
pip install git+https://github.com/BirkhoffG/jax-relax.git
```
To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via `pip install jax-relax`, and then following the steps in the official JAX installation guidelines to install the version of `jax` that matches your GPU or TPU.
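For example, a CUDA-enabled build of `jax` can typically be installed as shown below. This is only a sketch: the exact extra (e.g. `cuda12`) depends on your CUDA version and platform, so check the official JAX install guide for the command matching your setup.

```bash
# Example only: install a CUDA 12 build of jax on Linux.
# Consult the official JAX installation guide for your platform.
pip install -U "jax[cuda12]"
```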
## Dive into ReLax
`ReLax` is a recourse explanation library for explaining (any) JAX-based ML models. We believe that it is important to give users the flexibility to choose how to use `ReLax`. You can

- only use methods implemented in `ReLax` (as a recourse methods library);
- build a pipeline using `ReLax` to define a data module, train ML models, and generate CF explanations (for constructing a recourse benchmarking pipeline).
### ReLax as a Recourse Explanation Library
We introduce basic use cases of using methods in `ReLax` to generate recourse explanations. For more advanced usage of the methods in `ReLax`, see this tutorial.
Let's first generate synthetic data:
```python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax

xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```
Next, we fit an MLP model to this data. Note that this model can be any model implemented in JAX. We will use the `MLModule` in `ReLax` as an example.
```python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
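Because the recourse methods only need a prediction function, a hand-written JAX model can stand in for `MLModule`. The sketch below is an illustration under assumptions: it assumes the prediction function takes a row (or batch) of features and returns class probabilities, mirroring how `model.pred_fn` is used in the examples here.

```python
import jax
import jax.numpy as jnp

# Hypothetical hand-written predictor (not part of ReLax): a simple linear
# model over the 10 synthetic features that returns class probabilities.
w = jnp.zeros((10, 2))
b = jnp.zeros((2,))

def my_pred_fn(x):
    logits = x @ w + b
    return jax.nn.softmax(logits, axis=-1)
```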
Generating recourse explanations is straightforward. We can simply call `generate_cf` of an implemented recourse method to generate one recourse explanation:
```python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
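As a quick sanity check, you can compare the model's predictions on the original input and its counterfactual. This sketch assumes `model.pred_fn` accepts a batch of inputs and returns class probabilities, so the predicted label is the argmax along the last axis.

```python
import jax.numpy as jnp

# Assumed: pred_fn takes a batch of inputs and returns class probabilities.
orig_label = jnp.argmax(model.pred_fn(test_xs[0][None, :]), axis=-1)
cf_label = jnp.argmax(model.pred_fn(cf[None, :]), axis=-1)
print(orig_label, cf_label)  # a valid counterfactual flips the predicted label
```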
Or generate a bunch of recourse explanations with `jax.vmap`:
```python
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
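Since `jax.vmap` returns an ordinary JAX-transformable function, it can also be composed with `jax.jit` to compile the whole batch-generation step. Treat this as a sketch: it assumes the method's `generate_cf` is fully jit-compatible, which may depend on the particular recourse method and its configuration.

```python
# Compile the vectorized generation function once, then reuse it.
jit_generate_fn = jax.jit(jax.vmap(generate_fn))
cfs = jit_generate_fn(test_xs)
assert cfs.shape == test_xs.shape
```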
### ReLax for Building Recourse Explanation Pipelines
The above example illustrates the usage of the decoupled `relax.methods` to generate recourse explanations. However, users are required to write boilerplate code for tasks such as data preprocessing, model training, and generating recourse explanations with feature constraints.

`ReLax` additionally offers a one-liner framework that streamlines this process and helps users build a standardized pipeline for generating recourse explanations. You can write three lines of code to benchmark recourse explanations:
```python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
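Because `benchmark_cfs` takes a list of explanations, multiple recourse methods can be compared side by side. The following sketch assumes that `DiverseCF` (listed in the table below) can be constructed with its default configuration:

```python
from relax.methods import DiverseCF

# Generate explanations with a second method and benchmark both together.
dcf_exps = generate_cf_explanations(DiverseCF(), data_module, model.pred_fn)
benchmark_cfs([exps, dcf_exps])
```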
See Getting Started with ReLax for an end-to-end example of using `ReLax`.
## Supported Recourse Methods
`ReLax` currently provides implementations of 9 recourse explanation methods.
| Method | Type | Paper Title | Ref |
|---|---|---|---|
| `VanillaCF` | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [1] |
| `DiverseCF` | Non-Parametric | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. | [2] |
| `ProtoCF` | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes. | [3] |
| `CounterNet` | Parametric | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations. | [4] |
| `GrowingSphere` | Non-Parametric | Inverse Classification for Comparison-based Interpretability in Machine Learning. | [5] |
| `CCHVAE` | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [6] |
| `VAECF` | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [7] |
| `CLUE` | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [8] |
| `L2C` | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [9] |
## Citing ReLax
To cite this repository:
```bibtex
@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
  year = {2023},
}
```