Getting started

An end-to-end tutorial which demonstrates key features in ReLax.

This tutorial introduces the basics of ReLax and shows how to use it to generate counterfactual (or recourse) explanations for JAX-based implementations of ML models.

In particular, we will cover the following:

  1. Loading datasets with DataModule;
  2. Training machine learning classifiers;
  3. Generating counterfactual (or recourse) explanations;
  4. Benchmarking different recourse methods.

Preparation

We assume that you have already installed ReLax. If not, follow the steps in this installation tutorial, or simply run pip install jax-relax.

We also import some libraries used in this tutorial.

import jax

Load Dataset with DataModule

DataModule is a Python class that modularizes tabular dataset loading. It loads a .csv file specified by the following attributes:

  • data_name is the name of your dataset.
  • data_dir is the (relative) path to the .csv file containing your dataset.
  • continous_cols specifies a list of feature names representing all the continuous/numeric features in our dataset.
  • discret_cols specifies a list of feature names representing all discrete features in our dataset. By default, all discrete features are converted via one-hot encoding for training purposes.
  • imutable_cols specifies a list of feature names for immutable features, i.e., features we do not wish to change in the generated recourse.
from relax.data_module import DataModuleConfig, DataModule, load_data

For example, to load the adult dataset, we can specify the DataModuleConfig as

data_config = DataModuleConfig(
    # The name of this dataset is "adult"
    data_name="adult",
    # The data file is located at `../assets/adult/data/data.csv`.
    data_dir="../assets/adult/data/data.csv",
    # Contains 2 features with continuous variables
    continous_cols=["age","hours_per_week"],
    # Contains 6 features with categorical (discrete) variables
    discret_cols=["workclass","education","marital_status","occupation","race","gender"],
    # Contains 2 features that we do not wish to change
    imutable_cols=["race", "gender"]
)

We can then pass data_config to DataModule.from_config.

datamodule = DataModule.from_config(data_config)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Alternatively, we can also specify this config via a dictionary.

# This approach is equivalent to using `DataModuleConfig`
data_config_dict = {
    "data_name": "adult",
    "data_dir": "../assets/adult/data/data.csv",
    "continous_cols": ["age","hours_per_week"],
    "discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
    "imutable_cols": ["race","gender"]
}
datamodule = DataModule.from_config(data_config_dict)

For datasets supported by ReLax, we can simply call load_data:

# This is equivalent to specifying configs for `DataModule`
datamodule = load_data('adult')

For more usage of loading datasets in ReLax, check out the data module documentation.
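
To get a feel for what DataModule has prepared, we can inspect the pre-processed arrays. The xs and ys attributes used below are assumed names for the transformed feature matrix and labels; check the data module documentation for the exact interface.

# NOTE: `xs` and `ys` are assumed attribute names for the pre-processed
# (e.g., one-hot encoded) features and labels.
xs, ys = datamodule.xs, datamodule.ys
print(xs.shape, ys.shape)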

Train the Classifier

To showcase the full functionality of the framework, we will train the model using the built-in functions in ReLax, which uses haiku for building neural network blocks. However, the recourse algorithms in ReLax can generate explanations for any JAX-based framework (e.g., Flax, Haiku, vanilla JAX).

Warning

The recourse algorithms in ReLax currently only support binary classification. The output of the classifier must be a probability score (bounded by [0, 1]). Support for multi-class classification is planned for a future release.
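
To make this requirement concrete, below is a minimal sketch (in vanilla JAX) of the kind of prediction function the recourse algorithms expect: a function mapping a batch of inputs to probability scores in [0, 1]. The weights, shapes, and the name my_pred_fn are purely illustrative and not part of the ReLax API.

import jax
import jax.numpy as jnp

# Illustrative only: a tiny logistic model with dummy weights. Any
# JAX-compatible function with this signature (batch of inputs ->
# probabilities in [0, 1]) can be plugged into the recourse algorithms.
dummy_w = jnp.full((8, 1), 0.1)  # hypothetical weights for 8 input features
dummy_b = jnp.zeros((1,))

def my_pred_fn(xs: jnp.ndarray) -> jnp.ndarray:
    logits = xs @ dummy_w + dummy_b
    return jax.nn.sigmoid(logits)  # probability of the positive class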

Training a classifier using the built-in functions in ReLax is very simple. We first specify the classifier via MLModule, which defines the model structure and the optimization procedure (e.g., the loss function used to optimize the model). Next, we call MLModule.train to train the model on a DataModule.

Define the Model

from relax.ml_model import MLModuleConfig, MLModule

Defining MLModule is similar to defining DataModule. We first specify the configuration via MLModuleConfig, and then pass this config to MLModule.

model_config = MLModuleConfig(
    lr=0.01, # Learning rate
    sizes=[50, 10, 50], # The sizes of the hidden layers
    dropout_rate=0.3 # Dropout rate
)

# specify the predictive model
module = MLModule(model_config)

Train the Model

To train MLModule on the entire dataset (specified in DataModule), we can simply call MLModule.train:

module.train(datamodule, batch_size=128, epochs=5)
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.7807 - loss: 0.4597
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8135 - loss: 0.3945
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8202 - loss: 0.3769
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8205 - loss: 0.3784
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8167 - loss: 0.3817
<relax.ml_model.MLModule>

Make Predictions

We can directly use module.pred_fn to make predictions.

pred_fn = module.pred_fn
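
As a quick sanity check, we can apply pred_fn to a few pre-processed instances. The xs attribute below is an assumed name for the transformed feature matrix held by DataModule; adjust it if the actual API differs.

# `datamodule.xs` is an assumed attribute holding the pre-processed features.
probs = pred_fn(datamodule.xs[:5])
print(probs)  # probability scores for the first five instances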

Generate Counterfactual Explanations

Now, it is time to use ReLax to generate counterfactual explanations (or recourse).

from relax.methods import VanillaCF, VanillaCFConfig

We use VanillaCF (a popular recourse generation algorithm) as an example in this tutorial. Defining VanillaCF is similar to defining DataModule and MLModule.

cf_config = VanillaCFConfig(
    n_steps=1000, # Number of steps
    lr=0.001 # Learning rate
)
cf_exp = VanillaCF(cf_config)
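
As with the data module, the configuration can presumably also be supplied as a plain dictionary; whether VanillaCF accepts a dict directly is an assumption here, so prefer VanillaCFConfig if in doubt.

# Assumed to be equivalent to the `VanillaCFConfig` version above.
cf_exp = VanillaCF({'n_steps': 1000, 'lr': 0.001})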

Next, we generate counterfactual examples:

from relax.explain import generate_cf_explanations
cf_results = generate_cf_explanations(
    cf_exp, datamodule, pred_fn, 
)

Benchmark the Counterfactual Method

After we obtain the counterfactual results, we can use benchmark_cfs to evaluate the accuracy, validity, and proximity of the counterfactual examples.

from relax.evaluate import benchmark_cfs
benchmark_cfs([cf_results])
                      acc  validity  proximity
adult VanillaCF  0.828261  0.814963    4.79361
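
Because benchmark_cfs takes a list of results, several recourse methods can be compared in one table. The sketch below assumes a second method has been run through generate_cf_explanations in the same way; DiverseCF and its default configuration are used purely as an illustration.

# Hypothetical comparison of two recourse methods.
from relax.methods import DiverseCF

diverse_cf_results = generate_cf_explanations(DiverseCF(), datamodule, pred_fn)
benchmark_cfs([cf_results, diverse_cf_results])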