ReLax as a Recourse Library

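ReLax contains implementations of various recourse methods, which are decoupled from the rest of the ReLax library. We give users flexibility on how to use ReLax: you can use the end-to-end ReLax pipeline (e.g., for benchmarking recourse methods), or you can use the recourse methods in ReLax on their own, without relying on the rest of the pipeline.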

In this tutorial, we explore the second option: using the recourse methods under relax.methods to debug, diagnose, and interpret your JAX models.

Types of Recourse Methods

  1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely from the predictive model, using its predictions and, for gradient-based methods, its gradients. Examples in ReLax include VanillaCF, DiverseCF and GrowingSphere. These methods inherit from CFModule.

  2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include ProtoCF, CCHVAE and CLUE. These methods inherit from ParametricCFModule.

  3. Parametric methods: These methods learn a full parametric model for counterfactual generation. This generator is trained to produce counterfactuals that change the predictive model's prediction (e.g., to a desired target class). Examples in ReLax include CounterNet and VAECF. These methods inherit from ParametricCFModule.

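| Method Type     | Learned Parameters       | Training Required  | Example Methods                     |
|:----------------|:-------------------------|:-------------------|:------------------------------------|
| Non-parametric  | None                     | No                 | VanillaCF, DiverseCF, GrowingSphere |
| Semi-parametric | Some (θ)                 | Modest amount      | ProtoCF, CCHVAE, CLUE               |
| Parametric      | Full generator model (φ) | Substantial amount | CounterNet, VAECF                   |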

Basic Usage

At a high level, you can use the implemented methods in ReLax to generate one recourse explanation via three lines of code:

from relax.methods import VanillaCF

vcf = VanillaCF()
# x is one data point. Shape: `(K,)` or `(1, K)`
cf = vcf.generate_cf(x, pred_fn=pred_fn)
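The call above hides the optimization. As a rough illustration of what a gradient-based, non-parametric method does internally, the sketch below searches for a counterfactual by gradient descent on a prediction loss plus an L1 distance penalty. It is not ReLax's VanillaCF implementation: the toy logistic pred_fn, the loss, and the gradient_cf helper are hypothetical, and ReLax's own pred_fn may expect batched inputs rather than a single (K,) vector.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy classifier: logistic regression with fixed weights.
W = jnp.array([1.5, -2.0, 0.5])
B = -0.25

def pred_fn(x):
    """Return P(y=1 | x) for a single input of shape (K,)."""
    return jax.nn.sigmoid(jnp.dot(x, W) + B)

def gradient_cf(x, pred_fn, y_target=1.0, n_steps=100, lr=0.1, lambda_=0.1):
    """Hypothetical sketch: minimise
    BCE(pred_fn(cf), y_target) + lambda_ * ||cf - x||_1 by gradient descent."""
    def loss(cf):
        p = jnp.clip(pred_fn(cf), 1e-7, 1 - 1e-7)
        bce = -(y_target * jnp.log(p) + (1 - y_target) * jnp.log(1 - p))
        return bce + lambda_ * jnp.sum(jnp.abs(cf - x))  # validity + proximity

    grad_fn = jax.grad(loss)

    def step(cf, _):
        # One plain gradient-descent update on the counterfactual.
        return cf - lr * grad_fn(cf), None

    # Start the search from the original input x.
    cf, _ = jax.lax.scan(step, x, None, length=n_steps)
    return cf

x = jnp.array([-1.0, 1.0, 0.0])   # predicted as class 0
cf = gradient_cf(x, pred_fn)
print(pred_fn(x), pred_fn(cf))
```

The distance penalty (weighted by lambda_) is what keeps the counterfactual close to x; without it, the search would keep moving until the prediction saturates.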

Or generate a batch of recourse explanations via the jax.vmap primitive:

...
import functools as ft
import jax

vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)
# xs is a batched data. Shape: `(N, K)`
cfs = jax.vmap(vcf_gen_fn)(xs)
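The ft.partial call matters here. Keyword arguments passed directly to a vmapped function are mapped over their leading axis, so a callable such as pred_fn cannot be passed that way; binding it with ft.partial closes over it and leaves vmap to map only over x. The self-contained sketch below shows the pattern with a hypothetical generate_cf and toy pred_fn, not ReLax's own functions.

```python
import functools as ft
import jax
import jax.numpy as jnp

def generate_cf(x, pred_fn=None, step=0.5):
    """Hypothetical per-example generator: nudge a single input x (shape (K,))
    along the gradient of pred_fn."""
    return x + step * jax.grad(pred_fn)(x)

def pred_fn(x):
    # Toy scalar score for a single input of shape (K,).
    return jax.nn.sigmoid(jnp.sum(x))

# Bind the non-array argument first so vmap only maps over x.
gen_fn = ft.partial(generate_cf, pred_fn=pred_fn)

xs = jnp.zeros((4, 3))        # batch of N=4 inputs with K=3 features
cfs = jax.vmap(gen_fn)(xs)    # maps gen_fn over axis 0 of xs
print(cfs.shape)              # (4, 3)
```

Because generate_cf is written for a single input, the same function serves both the single-example and the batched case.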

To use parametric and semi-parametric methods, first train the model by calling ParametricCFModule.train, and then generate recourse explanations. Here is an example of using ReLax with CCHVAE.

from relax.methods import CCHVAE

cchvae = CCHVAE()
cchvae.train(train_data) # Train CVAE before generation
cf = cchvae.generate_cf(x, pred_fn=pred_fn) 

Or generate a batch of recourse explanations via the jax.vmap primitive:

...
import functools as ft
import jax

cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)
cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals
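The train-then-generate split can be illustrated with a deliberately tiny, hypothetical method. ToyPrototypeCF is not part of ReLax and does not work like CCHVAE, which trains a generative model; it only shows why train has to run before generate_cf: generation depends on parameters learned from data (here, one mean vector per class).

```python
import jax.numpy as jnp

class ToyPrototypeCF:
    """Hypothetical semi-parametric method: learn one prototype (class mean)
    per class in train(), then move x toward the target-class prototype."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha        # fraction of the way to move toward the prototype
        self.prototypes = None    # learned parameters, shape (n_classes, K)

    def train(self, xs, ys, n_classes=2):
        # "Training" here is just computing per-class feature means.
        self.prototypes = jnp.stack(
            [xs[ys == c].mean(axis=0) for c in range(n_classes)]
        )
        return self

    def generate_cf(self, x, y_target):
        if self.prototypes is None:
            raise RuntimeError("call train() before generate_cf()")
        proto = self.prototypes[y_target]
        return x + self.alpha * (proto - x)

xs = jnp.array([[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [4.0, 6.0]])
ys = jnp.array([0, 0, 1, 1])
method = ToyPrototypeCF(alpha=0.5).train(xs, ys)

# Moves halfway from [0, 0] toward the class-1 prototype [4, 5].
cf = method.generate_cf(jnp.array([0.0, 0.0]), y_target=1)
print(cf)
```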

Configuring Recourse Methods

Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.

For example, to configure VanillaCF:

from relax.methods import VanillaCF 
from relax.methods.vanilla import VanillaCFConfig

config = VanillaCFConfig(
  n_steps=100,
  lr=0.1,
  lambda_=0.1
)

vcf = VanillaCF(config)

Each Config class inherits from a BaseConfig that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.

See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name].

Alternatively, we can specify the config via a dictionary.

from relax.methods import VanillaCF

config = {
  "n_steps": 10,  
  "lambda_": 0.1,
  "lr": 0.1   
}

vcf = VanillaCF(config)

This config dictionary is passed to VanillaCF’s __init__ method, which sets the specified parameters. Now our VanillaCF instance is configured to:

  • Run 10 optimization steps (n_steps=10)
  • Use 0.1 validity regularization for counterfactuals (lambda_=0.1)
  • Use a learning rate of 0.1 for optimization (lr=0.1)
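Accepting either a dictionary or a Config object means the constructor has to normalize its input. ReLax handles this with validate_configs (used in the RandomCF example later in this tutorial); the sketch below is a hypothetical stand-in, to_config, built on Python dataclasses rather than whatever ReLax uses internally. ToyVanillaConfig is likewise hypothetical and only mirrors the field names shown above.

```python
from dataclasses import dataclass, fields

@dataclass
class ToyVanillaConfig:
    # Hypothetical config mirroring the fields used above.
    n_steps: int = 100
    lr: float = 0.1
    lambda_: float = 0.1

def to_config(config, config_cls):
    """Accept None, a dict, or a config instance; return a config instance."""
    if config is None:
        return config_cls()  # fall back to defaults
    if isinstance(config, config_cls):
        return config
    if isinstance(config, dict):
        known = {f.name for f in fields(config_cls)}
        unknown = set(config) - known
        if unknown:
            # Reject typos such as "lamda_" instead of silently ignoring them.
            raise ValueError(f"unknown config keys: {sorted(unknown)}")
        return config_cls(**config)
    raise TypeError(
        f"expected dict or {config_cls.__name__}, got {type(config).__name__}"
    )

cfg = to_config({"n_steps": 10, "lambda_": 0.1, "lr": 0.1}, ToyVanillaConfig)
print(cfg)  # ToyVanillaConfig(n_steps=10, lr=0.1, lambda_=0.1)
```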

Implement your Own Recourse Methods

You can easily implement your own recourse methods and leverage jax-relax to scale recourse generation. In this section, we implement a mock “Recourse Method” that adds random perturbations to the input x.

from relax.methods.base import CFModule, BaseConfig
from relax.utils import auto_reshaping, validate_configs
from relax.import_essentials import *
import relax

First, we define a configuration class for the random counterfactual module. This class inherits from the BaseConfig class.

class RandomCFConfig(BaseConfig):
    max_perturb: float = 0.2 # Max perturbation allowed for RandomCF

Next, we define the random counterfactual module. This class inherits from the CFModule class. Importantly, you should override CFModule.generate_cf and implement your CF generation procedure for a single input (i.e., shape=(k,), where k is the number of features).

class RandomCF(CFModule):

    def __init__(
        self,
        config: dict | RandomCFConfig = None,
        name: str = None,
    ):
        if config is None:
            config = RandomCFConfig()
        config = validate_configs(config, RandomCFConfig)
        name = "RandomCF" if name is None else name
        super().__init__(config, name=name)

    @auto_reshaping('x')
    def generate_cf(
        self,
        x: Array, # Input data point
        pred_fn: Callable = None, # Prediction function
        y_target: Array = None,   # Target label
        rng_key: jrand.PRNGKey = None, # Random key
        **kwargs,
    ) -> Array:
        # Generate random perturbations in the range of [-max_perturb, max_perturb].
        x_cf = x + jrand.uniform(rng_key, x.shape, 
                                 minval=-self.config.max_perturb, 
                                 maxval=self.config.max_perturb)
        return x_cf
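Because generate_cf handles one input at a time, scaling it to a dataset comes down to vmapping it. For a random method, each row needs its own PRNG key; otherwise every counterfactual receives the same noise. The sketch below uses a standalone function, random_cf (hypothetical, mirroring the body of RandomCF.generate_cf), and does not claim to show how relax.generate_cf_explanations actually distributes keys.

```python
import functools as ft
import jax
import jax.numpy as jnp
import jax.random as jrand

def random_cf(x, rng_key, max_perturb=0.2):
    """Hypothetical per-example random perturbation for one input of shape (K,)."""
    noise = jrand.uniform(rng_key, x.shape, minval=-max_perturb, maxval=max_perturb)
    return x + noise

xs = jnp.zeros((5, 3))                              # N=5 inputs, K=3 features
keys = jrand.split(jrand.PRNGKey(0), xs.shape[0])   # one independent key per row
cfs = jax.vmap(ft.partial(random_cf, max_perturb=0.2))(xs, keys)
print(cfs.shape)  # (5, 3)

# Pitfall: a single shared key gives every (identical) row the same noise.
same = jax.vmap(lambda x: random_cf(x, jrand.PRNGKey(0)))(xs)
```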

Finally, you can easily use jax-relax to generate recourse explanations at scale.

rand_cf = RandomCF()
exps = relax.generate_cf_explanations(
    rand_cf, relax.load_data('dummy'), relax.load_ml_module('dummy').pred_fn, 
)
relax.benchmark_cfs([exps])
|                       |   acc |   validity |   proximity |
|:----------------------|------:|-----------:|------------:|
| ('dummy', 'RandomCF') | 0.983 |  0.0599999 |    0.997049 |
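To read the validity and proximity columns, it helps to see how such metrics are commonly defined. ReLax's exact definitions live in its benchmarking code and may differ (for example, in how the target label is chosen or which distance proximity uses). The functions below are hypothetical sketches of two common definitions: validity as the fraction of counterfactuals predicted as the target class, and proximity as the mean L1 distance to the original inputs.

```python
import jax.numpy as jnp

def validity(cf_probs, y_targets, threshold=0.5):
    """Fraction of counterfactuals whose predicted label equals the target label."""
    cf_labels = (cf_probs >= threshold).astype(jnp.int32)
    return jnp.mean(cf_labels == y_targets)

def proximity(xs, cfs):
    """Mean L1 distance between each input and its counterfactual."""
    return jnp.mean(jnp.sum(jnp.abs(cfs - xs), axis=-1))

xs = jnp.zeros((4, 2))
cfs = jnp.array([[0.1, -0.1], [0.2, 0.0], [0.0, 0.0], [-0.3, 0.1]])
cf_probs = jnp.array([0.9, 0.4, 0.6, 0.2])   # model outputs on the counterfactuals
y_targets = jnp.array([1, 1, 1, 0])

print(validity(cf_probs, y_targets))  # 3 of 4 hit the target
print(proximity(xs, cfs))             # mean of L1 distances 0.2, 0.2, 0.0, 0.4
```

Under definitions like these, small uniform noise that rarely flips the prediction is consistent with the low validity RandomCF scores above.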