ReLax

Recourse Explanation Library in Jax.



Overview

ReLax (Recourse Explanation Library in Jax) is a library built on top of jax to generate counterfactual and recourse explanations for Machine Learning algorithms. By leveraging vectorization through vmap/pmap and just-in-time compilation in jax (a high-performance auto-differentiation library), ReLax offers significant speed improvements in generating individual (or local) explanations for predictions made by Machine Learning algorithms.
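To illustrate why `jax.jit` and `jax.vmap` speed up local explanation generation, here is a minimal sketch in plain JAX. It does not use the ReLax API. It assumes a hypothetical fixed logistic-regression model (`W`, `B`) and a Wachter-style gradient-based counterfactual objective (squared error toward the desired class plus a squared L2 proximity penalty). The names `predict`, `cf_loss`, and `generate_cf` are illustrative only. The search for a single instance is written once. `vmap` then vectorizes it over a batch, and `jit` compiles the whole batched search into one XLA program.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (assumption, not part of ReLax): a fixed logistic regression.
W = jnp.array([1.5, -2.0, 0.5])
B = -0.25

def predict(x):
    """Probability of the positive class for x (works for one row or a batch)."""
    return jax.nn.sigmoid(jnp.dot(x, W) + B)

def cf_loss(cf, x, target=1.0, lam=0.1):
    """Wachter-style objective: move the prediction toward `target`
    while keeping the counterfactual `cf` close to the original `x`."""
    validity = (predict(cf) - target) ** 2
    proximity = jnp.sum((cf - x) ** 2)
    return validity + lam * proximity

def generate_cf(x, n_steps=100, lr=0.1):
    """Plain gradient descent on the counterfactual for ONE instance."""
    grad_fn = jax.grad(cf_loss)  # gradient w.r.t. the first argument, `cf`

    def step(cf, _):
        return cf - lr * grad_fn(cf, x), None

    # Start the search from the original instance and run a fixed number of steps.
    cf, _ = jax.lax.scan(step, x, None, length=n_steps)
    return cf

# vmap maps the per-instance search over the leading batch axis;
# jit compiles the batched computation once with XLA.
batch_generate_cf = jax.jit(jax.vmap(generate_cf))

xs = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
cfs = batch_generate_cf(xs)
```

Because each instance's search is independent, `vmap` removes the Python loop over instances, and `jax.pmap` can similarly split a batch across several devices.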

Some of the key features are as follows:

  • 🏃 Fast recourse generation via jax.jit, jax.vmap/jax.pmap.

  • 🚀 Accelerated on CPU, GPU, and TPU.

  • 🪓 Comprehensive set of recourse methods implemented for benchmarking.

  • 👐 Customizable API to enable the building of entire modeling and interpretation pipelines for new recourse algorithms.

Installation

The latest ReLax release can directly be installed from PyPI:

pip install jax-relax

or installed directly from the repository:

pip install git+https://github.com/BirkhoffG/ReLax.git 

To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via pip install jax-relax. Then, follow the steps in the official JAX installation guide to install the right jax version for your GPU or TPU.
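After installing an accelerator-enabled jax build, a quick sanity check can confirm that JAX actually sees the device. This sketch uses only core JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`); the helper `square_sum` is a hypothetical example function, not part of ReLax.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX picked by default (typically 'cpu', 'gpu', or 'tpu').
print("Default backend:", jax.default_backend())
print("Devices:", jax.devices())

# Hypothetical helper: a tiny jit-compiled computation that runs on the default device.
@jax.jit
def square_sum(x):
    return (x ** 2).sum()

print(square_sum(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14.0
```

If the backend still reports `cpu` on a machine with a GPU, the installed jax build is most likely the CPU-only one.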

An Example of Using ReLax

See Getting Started with ReLax.

Citing ReLax

To cite this repository:

@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/ReLax},
  version = {0.1.0},
  year = {2023},
}