```python
from relax.methods.base import CFModule, BaseConfig
from relax.utils import auto_reshaping, validate_configs
from relax.import_essentials import *
import relax
```
ReLax as a Recourse Library

ReLax contains implementations of various recourse methods, which are decoupled from the rest of the ReLax library. This gives users flexibility in how to use ReLax:

- You can use the recourse pipeline in ReLax (a "one-liner" for easy benchmarking of recourse methods; see this tutorial).
- You can use all of the recourse methods in ReLax without relying on the entire ReLax pipeline.
In this tutorial, we explore the second option: using the recourse methods under `relax.methods` directly for debugging, diagnosing, and interpreting your JAX models.
Types of Recourse Methods

- Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF`, and `GrowingSphere`. These methods inherit from `CFModule`.
- Semi-parametric methods: These methods learn some parameters to aid counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE`, and `CLUE`. These methods inherit from `ParametricCFModule`.
- Parametric methods: These methods learn a full parametric model for counterfactual generation; the model is trained to produce counterfactuals that change the predictive model's output. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.
| Method Type | Learned Parameters | Training Required | Example Methods |
|---|---|---|---|
| Non-parametric | None | No | `VanillaCF`, `DiverseCF`, `GrowingSphere` |
| Semi-parametric | Some (θ) | Modest amount | `ProtoCF`, `CCHVAE`, `CLUE` |
| Parametric | Full generator model (φ) | Substantial amount | `CounterNet`, `VAECF` |
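This distinction matters at call time: parametric and semi-parametric methods must be trained before generating counterfactuals, while non-parametric methods can be used directly. Below is a minimal dispatch sketch; it assumes `ParametricCFModule` is importable from `relax.methods.base` alongside `CFModule`, as the imports at the top of this tutorial suggest.

```python
# Assumption: ParametricCFModule lives in relax.methods.base alongside CFModule.
from relax.methods.base import CFModule, ParametricCFModule

def generate_with(method: CFModule, x, pred_fn, train_data=None):
    # (Semi-)parametric methods carry learned parameters and must be
    # trained first; non-parametric methods can generate directly.
    if isinstance(method, ParametricCFModule):
        method.train(train_data)
    return method.generate_cf(x, pred_fn=pred_fn)
```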
Basic Usage

At a high level, you can use the implemented methods in ReLax to generate one recourse explanation in three lines of code:
```python
from relax.methods import VanillaCF

vcf = VanillaCF()
# x is one data point. Shape: `(K)` or `(1, K)`
cf = vcf.generate_cf(x, pred_fn=pred_fn)
```
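Here, `pred_fn` is not defined above; it stands for your model's prediction function, i.e., any JAX-compatible callable that maps inputs to predicted class probabilities. A hypothetical stand-in with a hand-written logistic model (the weights, the feature dimension, and the two-column probability output are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

# A hypothetical logistic model standing in for `pred_fn`.
w = jnp.full((10,), 0.1)  # assumed feature dimension K = 10
b = 0.0

def pred_fn(x):
    prob = jax.nn.sigmoid(x @ w + b)
    # Return per-class probabilities with shape (..., 2).
    return jnp.stack([1 - prob, prob], axis=-1)
```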
Or generate a batch of recourse explanations via the `jax.vmap` primitive:
```python
import functools as ft

vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)
# xs is a batch of data points. Shape: `(N, K)`
cfs = jax.vmap(vcf_gen_fn)(xs)
```
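Since the non-parametric methods are written in pure JAX, the batched generator can typically also be wrapped in `jax.jit` for repeated use. This is a sketch under the assumption that `pred_fn` and the method's generation step are JAX-traceable:

```python
# Optionally JIT-compile the batched generator (assumes `pred_fn` and the
# method's generation step are JAX-traceable).
batched_gen_fn = jax.jit(jax.vmap(vcf_gen_fn))
cfs = batched_gen_fn(xs)
```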
To use parametric and semi-parametric methods, you first train the model by calling `ParametricCFModule.train`, and then generate recourse explanations. Here is an example of using ReLax for `CCHVAE`.

```python
from relax.methods import CCHVAE

cchvae = CCHVAE()
# Train the CVAE before generation
cchvae.train(train_data)
cf = cchvae.generate_cf(x, pred_fn=pred_fn)
```
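Here, `train_data` stands for a training dataset. One way to obtain it is through ReLax's data-loading utilities; the `'dummy'` name below simply mirrors the benchmarking example at the end of this tutorial, so substitute your own dataset as needed.

```python
# Obtain `train_data` with ReLax's built-in loader; 'dummy' mirrors the
# benchmarking example at the end of this tutorial.
train_data = relax.load_data('dummy')
```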
Or generate a batch of recourse explanations via the `jax.vmap` primitive:
```python
import functools as ft

cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)
# Generate counterfactuals for the whole batch
cfs = jax.vmap(cchvae_gen_fn)(xs)
```
Configuring Recourse Methods

Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the `config` parameter.
For example, to configure `VanillaCF`:
```python
from relax.methods import VanillaCF
from relax.methods.vanilla import VanillaCFConfig

config = VanillaCFConfig(
    n_steps=100,
    lr=0.1,
    lambda_=0.1
)
vcf = VanillaCF(config)
```
Each Config class inherits from `BaseConfig`, which defines common options like `n_steps`. Method-specific parameters are defined on the individual Config classes.

See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from `relax.methods.[method_name]`.
Alternatively, we can specify the config via a dictionary.
```python
from relax.methods import VanillaCF

config = {
    "n_steps": 10,
    "lambda_": 0.1,
    "lr": 0.1
}
vcf = VanillaCF(config)
```
This config dictionary is passed to `VanillaCF`'s `__init__` method, which sets the specified parameters. Now our `VanillaCF` instance is configured to:

- Run 10 optimization steps (`n_steps=10`)
- Use a validity regularization weight of 0.1 for counterfactuals (`lambda_=0.1`)
- Use a learning rate of 0.1 for optimization (`lr=0.1`)
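The two configuration styles are interchangeable. The dict is validated against the method's Config class internally (via `validate_configs`, the same pattern used in the `RandomCF` example below), so a partially specified dict is assumed to fall back to the Config class defaults for any omitted keys. A small sketch of this assumption:

```python
# Sketch (assumption): a partially specified dict and a typed config with
# the same overrides are treated equivalently; omitted keys use defaults.
vcf_from_dict = VanillaCF({"lr": 0.05})
vcf_from_config = VanillaCF(VanillaCFConfig(lr=0.05))
```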
Implement your Own Recourse Methods

You can easily implement your own recourse methods and leverage `jax-relax` to scale up recourse generation. In this section, we implement a mock "recourse method" which adds random perturbations to the input `x`.

First, we define a configuration class for the random counterfactual module. This class inherits from the `BaseConfig` class.
```python
class RandomCFConfig(BaseConfig):
    max_perturb: float = 0.2  # Max perturbation allowed for RandomCF
```
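`validate_configs` (imported at the top of this tutorial) converts a plain dict into the typed config, which is how `RandomCF` below accepts either form. A quick sketch of the new config class used with it:

```python
# validate_configs accepts either a dict or a RandomCFConfig instance
# and returns a validated RandomCFConfig.
cfg = validate_configs({"max_perturb": 0.5}, RandomCFConfig)
assert cfg.max_perturb == 0.5
```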
Next, we define the random counterfactual module. This class inherits from the `CFModule` class. Importantly, you should override `CFModule.generate_cf` and implement your CF generation procedure for each input (i.e., `shape=(k,)`, where `k` is the number of features).
```python
class RandomCF(CFModule):
    def __init__(
        self,
        config: dict | RandomCFConfig = None,
        name: str = None,
    ):
        if config is None:
            config = RandomCFConfig()
        config = validate_configs(config, RandomCFConfig)
        name = "RandomCF" if name is None else name
        super().__init__(config, name=name)

    @auto_reshaping('x')
    def generate_cf(
        self,
        x: Array,  # Input data point
        pred_fn: Callable = None,  # Prediction function
        y_target: Array = None,  # Target label
        rng_key: jrand.PRNGKey = None,  # Random key
        **kwargs,
    ) -> Array:
        # Generate random perturbations in the range [-max_perturb, max_perturb].
        x_cf = x + jrand.uniform(rng_key, x.shape,
                                 minval=-self.config.max_perturb,
                                 maxval=self.config.max_perturb)
        return x_cf
```
Finally, you can easily use `jax-relax` to generate recourse explanations at scale.

```python
rand_cf = RandomCF()
exps = relax.generate_cf_explanations(
    rand_cf, relax.load_data('dummy'), relax.load_ml_module('dummy').pred_fn,
)
relax.benchmark_cfs([exps])
```
| | acc | validity | proximity |
|:----------------------|------:|-----------:|------------:|
| ('dummy', 'RandomCF') | 0.983 | 0.0599999 | 0.997049 |