```python
from relax.methods.base import CFModule, BaseConfig
from relax.utils import auto_reshaping, validate_configs
from relax.import_essentials import *
import relax
```

ReLax as a Recourse Library
ReLax contains implementations of various recourse methods that are decoupled from the rest of the ReLax library. This gives users flexibility in how to use ReLax:

- You can use the recourse pipeline in ReLax (a “one-liner” for easily benchmarking recourse methods; see this tutorial).
- You can use all of the recourse methods in ReLax without relying on the entire ReLax pipeline.

In this tutorial, we explore the second option: using the recourse methods under `relax.methods` directly for debugging, diagnosing, and interpreting your JAX models.
Types of Recourse Methods
- Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF`, and `GrowingSphere`. These methods inherit from `CFModule`.
- Semi-parametric methods: These methods learn some parameters to aid counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE`, and `CLUE`. These methods inherit from `ParametricCFModule`.
- Parametric methods: These methods learn a full parametric model for counterfactual generation; the generator is trained to produce counterfactuals that change the prediction of the model being explained. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.
| Method Type | Learned Parameters | Training Required | Example Methods |
|---|---|---|---|
| Non-parametric | None | No | VanillaCF, DiverseCF, GrowingSphere |
| Semi-parametric | Some (θ) | Modest amount | ProtoCF, CCHVAE, CLUE |
| Parametric | Full generator model (φ) | Substantial amount | CounterNet, VAECF |
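In practice, the main difference shows up at call time: parametric and semi-parametric modules must be trained before `generate_cf` is called, whereas non-parametric modules are ready to use immediately. The following is a minimal sketch of branching on the module type; it assumes `ParametricCFModule` can be imported from `relax.methods.base` alongside `CFModule`.

```python
from relax.methods.base import CFModule, ParametricCFModule  # import path for ParametricCFModule assumed

def prepare(cf_module: CFModule, train_data=None) -> CFModule:
    # Parametric and semi-parametric modules learn parameters, so train them first;
    # non-parametric modules have nothing to train and pass through unchanged.
    if isinstance(cf_module, ParametricCFModule):
        cf_module.train(train_data)
    return cf_module
```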
Basic Usage
At a high level, you can use the implemented methods in ReLax to generate one recourse explanation via three lines of code:
```python
from relax.methods import VanillaCF

vcf = VanillaCF()
# x is one data point. Shape: `(K)` or `(1, K)`
cf = vcf.generate_cf(x, pred_fn=pred_fn)
```

Or generate a batch of recourse explanations via the `jax.vmap` primitive:
```python
...
import functools as ft

vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)
# xs is batched data. Shape: `(N, K)`
cfs = jax.vmap(vcf_gen_fn)(xs)
```

To use parametric and semi-parametric methods, you first train the model by calling `ParametricCFModule.train`, and then generate recourse explanations. Here is an example of using ReLax for `CCHVAE`.
```python
from relax.methods import CCHVAE

cchvae = CCHVAE()
cchvae.train(train_data)  # Train the CVAE before generation
cf = cchvae.generate_cf(x, pred_fn=pred_fn)
```

Or generate a batch of recourse explanations via the `jax.vmap` primitive:
```python
...
import functools as ft

cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)
cfs = jax.vmap(cchvae_gen_fn)(xs)  # Generate counterfactuals
```
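If `generate_cf` is implemented with traceable JAX operations (an assumption; check the specific method), you can also compile the batched generation with `jax.jit` for repeated use. This is a sketch rather than part of the ReLax API:

```python
# Compile the vmapped generation function once, then reuse it on new batches.
# Assumes `cchvae_gen_fn` (defined above) only uses traceable JAX operations.
batched_gen_fn = jax.jit(jax.vmap(cchvae_gen_fn))
cfs = batched_gen_fn(xs)
```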
Configuring Recourse Methods

Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the `config` argument.

For example, to configure `VanillaCF`:
```python
from relax.methods import VanillaCF
from relax.methods.vanilla import VanillaCFConfig

config = VanillaCFConfig(
    n_steps=100,
    lr=0.1,
    lambda_=0.1
)
vcf = VanillaCF(config)
```

Each Config class inherits from a `BaseConfig` that defines common options like `n_steps`. Method-specific parameters are defined on the individual Config classes.
See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name].
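If you are unsure which parameters a method supports, one quick option is to instantiate its Config class with defaults and inspect it. This is a sketch, assuming the Config classes provide default values and a readable repr (as dataclass/pydantic-style configs typically do):

```python
from relax.methods.vanilla import VanillaCFConfig

# Construct a config with default values and inspect the supported fields.
# (Assumes all parameters of VanillaCFConfig have defaults.)
default_config = VanillaCFConfig()
print(default_config)
```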
Alternatively, we can also specify this config via a dictionary.
```python
from relax.methods import VanillaCF

config = {
    "n_steps": 10,
    "lambda_": 0.1,
    "lr": 0.1
}
vcf = VanillaCF(config)
```

This config dictionary is passed to `VanillaCF`'s `__init__` method, which sets the specified parameters. Now our `VanillaCF` instance is configured to:
- Run 10 optimization steps (`n_steps=10`)
- Weight the validity regularization for counterfactuals by 0.1 (`lambda_=0.1`)
- Use a learning rate of 0.1 for optimization (`lr=0.1`)
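Both the Config-class and dictionary styles produce the same configuration on the module. As a quick check (a sketch, assuming the parsed config is exposed as a `config` attribute, as it is used via `self.config` in the custom module example below), you can read the values back after construction:

```python
vcf = VanillaCF({"n_steps": 10, "lambda_": 0.1, "lr": 0.1})
# The dict is validated and parsed into the method's Config object.
assert vcf.config.n_steps == 10
assert vcf.config.lr == 0.1
```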
Implement Your Own Recourse Methods
You can easily implement your own recourse methods and leverage `jax-relax` to scale up recourse generation. In this section, we implement a mock “recourse method” which adds random perturbations to the input `x`.
First, we define a configuration class for the random counterfactual module. This class inherits from the BaseConfig class.
```python
class RandomCFConfig(BaseConfig):
    max_perturb: float = 0.2  # Max perturbation allowed for RandomCF
```

Next, we define the random counterfactual module. This class inherits from the `CFModule` class. Importantly, you should override `CFModule.generate_cf` and implement your CF generation procedure for each input (i.e., shape `(K,)`, where `K` is the number of features).
```python
class RandomCF(CFModule):
    def __init__(
        self,
        config: dict | RandomCFConfig = None,
        name: str = None,
    ):
        if config is None:
            config = RandomCFConfig()
        config = validate_configs(config, RandomCFConfig)
        name = "RandomCF" if name is None else name
        super().__init__(config, name=name)

    @auto_reshaping('x')
    def generate_cf(
        self,
        x: Array,  # Input data point
        pred_fn: Callable = None,  # Prediction function
        y_target: Array = None,  # Target label
        rng_key: jrand.PRNGKey = None,  # Random key
        **kwargs,
    ) -> Array:
        # Generate random perturbations in the range of [-max_perturb, max_perturb].
        x_cf = x + jrand.uniform(rng_key, x.shape,
                                 minval=-self.config.max_perturb,
                                 maxval=self.config.max_perturb)
        return x_cf
```
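Before plugging the module into the pipeline, you can smoke-test `generate_cf` on a single input. This is a sketch: the dummy input is purely illustrative, and an explicit PRNG key is passed because `RandomCF` draws random perturbations (`pred_fn` is unused by this module).

```python
import jax.numpy as jnp
import jax.random as jrand

rand_cf = RandomCF()
x = jnp.zeros(10)        # a dummy input with 10 features (illustrative)
key = jrand.PRNGKey(0)   # explicit PRNG key for the random perturbation
cf = rand_cf.generate_cf(x, rng_key=key)
print(cf)
```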
Finally, you can easily use `jax-relax` to generate recourse explanations at scale.

```python
rand_cf = RandomCF()
exps = relax.generate_cf_explanations(
    rand_cf, relax.load_data('dummy'), relax.load_ml_module('dummy').pred_fn,
)
relax.benchmark_cfs([exps])
```

| | acc | validity | proximity |
|:----------------------|------:|-----------:|------------:|
| ('dummy', 'RandomCF') | 0.983 | 0.0599999 | 0.997049 |
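The same `jax.vmap` pattern from the Basic Usage section also works for `RandomCF` outside the pipeline. Because the module draws random perturbations, each example needs its own PRNG key. The batch below is an illustrative dummy, and the positional call matches the `generate_cf` signature defined above:

```python
import jax
import jax.numpy as jnp
import jax.random as jrand

rand_cf = RandomCF()
xs = jnp.zeros((32, 10))                            # dummy batch: 32 examples, 10 features
keys = jrand.split(jrand.PRNGKey(0), xs.shape[0])   # one PRNG key per example

# Map over the inputs and keys; pred_fn and y_target are unused by RandomCF.
batched_gen = jax.vmap(rand_cf.generate_cf, in_axes=(0, None, None, 0))
cfs = batched_gen(xs, None, None, keys)
print(cfs.shape)
```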