Getting started

A basic tutorial on the key features of ReLax

This tutorial introduces the basics of ReLax and shows how to use it to generate counterfactual (or recourse) explanations for JAX-based implementations of machine learning models.

In particular, this tutorial covers the following topics:

  1. Loading datasets with TabularDataModule;
  2. Training machine learning classifiers;
  3. Generating counterfactual (or recourse) explanations;
  4. Benchmarking different recourse methods.

Preparation

We assume that you have already installed ReLax. If not, follow the steps in this installation tutorial, or simply run pip install jax-relax.

We also import the libraries used in this tutorial.

import jax

Load Dataset with TabularDataModule

TabularDataModule is a Python class that modularizes tabular dataset loading. It loads a .csv file configured by the following attributes:

  • data_name is the name of your dataset.
  • data_dir is the relative path to the dataset's .csv file (for example, ../assets/data/s_adult.csv).
  • continous_cols specifies a list of feature names representing all the continuous/numeric features in our dataset.
  • discret_cols specifies a list of feature names representing all discrete features in our dataset. By default, all discrete features are converted via one-hot encoding for training purposes.
  • imutable_cols specifies a list of feature names that represent immutable features that we do not wish to change in the generated recourse.
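As a rough illustration of what these attributes control, the sketch below encodes a toy table in plain JAX: continuous columns stay numeric and each discrete column is expanded into one-hot columns. The toy values, the min-max rescaling of continuous features, and the column ordering are assumptions for illustration, not a description of TabularDataModule's internals.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: two continuous columns and one discrete column
# whose categories have already been mapped to integer codes 0..2.
age = jnp.array([25.0, 40.0, 60.0])
hours_per_week = jnp.array([20.0, 40.0, 60.0])
workclass_code = jnp.array([0, 2, 1])


def min_max_scale(col):
    # Assumption: continuous features are rescaled to [0, 1].
    return (col - col.min()) / (col.max() - col.min())


continuous = jnp.stack([min_max_scale(age), min_max_scale(hours_per_week)], axis=1)
# One-hot encoding: each category becomes its own 0/1 column.
discrete = jax.nn.one_hot(workclass_code, num_classes=3)

# Assumed layout: continuous columns first, then the one-hot block.
X = jnp.concatenate([continuous, discrete], axis=1)
print(X.shape)  # (3, 5)
```

Each row of the one-hot block contains exactly one 1, which is why a single discrete feature occupies several encoded columns.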

from relax.data import TabularDataModuleConfigs, TabularDataModule, load_data

For example, to load the adult dataset, we can specify TabularDataModuleConfigs as follows:

data_config = TabularDataModuleConfigs(
    # The name of this dataset is "adult"
    data_name="adult",
    # The data file is located in `../assets/data/s_adult.csv`.
    data_dir="../assets/data/s_adult.csv",
    # Two continuous features
    continous_cols=["age", "hours_per_week"],
    # Six categorical (discrete) features
    discret_cols=["workclass", "education", "marital_status", "occupation", "race", "gender"],
    # Two features that we do not wish to change
    imutable_cols=["race", "gender"]
)
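imutable_cols tells the recourse method which features must not change. One simple way to honor such a constraint, sketched here under assumptions and not necessarily how ReLax implements it, is to build a 0/1 mask over the encoded columns and multiply every counterfactual update by it. The feature widths below (including two categories each for race and gender) are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical encoded layout: number of encoded columns per original feature.
# Continuous features use one column; discrete ones use one column per category.
feature_widths = {"age": 1, "hours_per_week": 1, "race": 2, "gender": 2}
immutable = {"race", "gender"}

# 1.0 marks a column that may change, 0.0 marks a frozen (immutable) column.
mask = []
for name, width in feature_widths.items():
    mask += [0.0 if name in immutable else 1.0] * width
mask = jnp.array(mask)

x = jnp.array([0.3, 0.5, 1.0, 0.0, 0.0, 1.0])        # encoded input row
step = jnp.array([0.1, -0.2, 0.4, -0.4, 0.3, -0.3])  # a proposed update
x_new = x + step * mask  # immutable columns keep their original values
```

Only the age and hours_per_week columns move; the race and gender columns are left exactly as they were.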

We can then pass data_config to TabularDataModule.

datamodule = TabularDataModule(data_config)

Alternatively, we can specify the same configuration as a dictionary.

# This approach is equivalent to using `TabularDataModuleConfigs`
data_config_dict = {
    "data_name": "adult",
    "data_dir": "../assets/data/s_adult.csv",
    "continous_cols": ["age","hours_per_week"],
    "discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
    "imutable_cols": ["race","gender"]
}
datamodule = TabularDataModule(data_config_dict)

For datasets supported by ReLax, we can simply call load_data:

# This is equivalent to specifying configs for `TabularDataModule`
datamodule = load_data('adult')

For more details on loading datasets in ReLax, see the data module documentation.

Train the Classifier

To demonstrate the full functionality of the framework, we train the model with ReLax's built-in functions, which use Haiku to build neural network blocks. However, the recourse algorithms in ReLax can generate explanations for models written in any JAX-based framework (e.g., Flax, Haiku, or vanilla JAX).

Warning

The recourse algorithms in ReLax currently support only binary classification. The output of the classifier must be a probability score bounded in [0, 1]. Support for multi-class classification is planned.
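Because any JAX-based model can be explained, a model that produces raw logits only needs a sigmoid on top to satisfy the [0, 1] requirement. In the sketch below, logit_fn, pred_fn, and the parameter dictionary are hypothetical names in vanilla JAX; ReLax's own pred_fn may use a different signature.

```python
import jax
import jax.numpy as jnp


def logit_fn(params, x):
    # Hypothetical vanilla-JAX logistic model that returns raw logits.
    return x @ params["w"] + params["b"]


def pred_fn(params, x):
    # Squash logits into probability scores in [0, 1] for binary classification.
    return jax.nn.sigmoid(logit_fn(params, x))


params = {"w": jnp.array([2.0, -3.0]), "b": jnp.array(0.5)}
x = jnp.array([[10.0, 0.0], [0.0, 10.0], [0.0, 0.0]])
probs = pred_fn(params, x)  # one probability per row
```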

Training a classifier with ReLax's built-in functions is straightforward. First, we specify the classifier, PredictiveTrainingModule, which defines both the model structure and the optimization procedure (e.g., the loss function used to optimize the model). Next, we use train_model to train the model on the TabularDataModule.

Define the Model

from relax.module import PredictiveTrainingModuleConfigs, PredictiveTrainingModule

Defining PredictiveTrainingModule is similar to defining TabularDataModule. We first specify the configuration with PredictiveTrainingModuleConfigs, then pass it to PredictiveTrainingModule.

model_config = PredictiveTrainingModuleConfigs(
    lr=0.01, # Learning rate
    sizes=[50, 10, 50], # The sizes of the hidden layers
    dropout_rate=0.3 # Dropout rate
)

# specify the predictive model
module = PredictiveTrainingModule(model_config)
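To make the configuration concrete, here is a vanilla-JAX sketch of a network with hidden layers of sizes [50, 10, 50], dropout at rate 0.3, and a single sigmoid output. It is not ReLax's Haiku implementation; the ReLU activations, the He-style initialization, and the function names init_mlp and mlp_forward are assumptions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, sizes):
    """Create (weight, bias) pairs for each hidden layer plus a 1-unit output layer."""
    dims = [in_dim] + list(sizes) + [1]
    params = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp_forward(params, x, key, dropout_rate=0.3, train=True):
    """ReLU hidden layers with inverted dropout during training; sigmoid output."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
        if train:
            key, sub = jax.random.split(key)
            keep = jax.random.bernoulli(sub, 1.0 - dropout_rate, x.shape)
            # Zero out dropped units and rescale the survivors.
            x = jnp.where(keep, x / (1.0 - dropout_rate), 0.0)
    w, b = params[-1]
    return jax.nn.sigmoid(x @ w + b).squeeze(-1)


params = init_mlp(jax.random.PRNGKey(0), in_dim=5, sizes=[50, 10, 50])
x = jnp.ones((4, 5))
probs = mlp_forward(params, x, jax.random.PRNGKey(1))  # training mode
probs_eval = mlp_forward(params, x, jax.random.PRNGKey(1), train=False)  # no dropout
```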

Note

The training step for each batch is specified in PredictiveTrainingModule. Essentially, it computes the binary cross-entropy loss for each batch and applies backpropagation, using the Adam optimizer, to update the model parameters.
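The training step described in the note can be sketched in plain JAX on a toy logistic-regression model: compute the binary cross-entropy, take its gradient, and apply an Adam update. The toy data, the hand-written adam_step (used here instead of an optimizer library), and all function names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    return jax.nn.sigmoid(x @ params["w"] + params["b"])


def bce_loss(params, x, y):
    # Binary cross-entropy; probabilities are clipped to avoid log(0).
    p = jnp.clip(predict(params, x), 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))


def adam_step(params, grads, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    tm = jax.tree_util.tree_map
    m = tm(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)       # 1st moment
    v = tm(lambda v_, g: b2 * v_ + (1 - b2) * g ** 2, v, grads)  # 2nd moment
    m_hat = tm(lambda m_: m_ / (1 - b1 ** t), m)                 # bias correction
    v_hat = tm(lambda v_: v_ / (1 - b2 ** t), v)
    params = tm(lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps),
                params, m_hat, v_hat)
    return params, m, v


@jax.jit
def train_step(params, m, v, t, x, y):
    loss, grads = jax.value_and_grad(bce_loss)(params, x, y)
    params, m, v = adam_step(params, grads, m, v, t)
    return params, m, v, loss


x = jax.random.normal(jax.random.PRNGKey(0), (64, 2))
y = (x[:, 0] + x[:, 1] > 0).astype(jnp.float32)  # toy binary labels
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
m = jax.tree_util.tree_map(jnp.zeros_like, params)
v = jax.tree_util.tree_map(jnp.zeros_like, params)

initial_loss = bce_loss(params, x, y)  # log(2) when every prediction is 0.5
for t in range(1, 101):
    params, m, v, loss = train_step(params, m, v, float(t), x, y)
```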

Train the Model

from relax.trainer import TrainingConfigs, train_model

To train PredictiveTrainingModule on the entire dataset (specified in TabularDataModule), we simply call train_model:

trainer_config = TrainingConfigs(
    n_epochs=10, # Number of epochs
    batch_size=256, # Batch size
    monitor_metrics='val/val_loss', # The metric to monitor
    logger_name='pred' # The name of the logger
)

# train the model
params, opt_state = train_model(
    module, datamodule, trainer_config
)
Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 218.13batch/s, train/train_loss_1=0.0634]
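train_model handles data iteration internally. Conceptually, each of the n_epochs passes shuffles the training rows and splits them into mini-batches of batch_size rows, as in the sketch below. The function iterate_minibatches and the toy 1,000-row dataset are hypothetical, and this sketch keeps the final partial batch; ReLax's behavior may differ.

```python
import math

import jax
import jax.numpy as jnp


def iterate_minibatches(key, x, y, batch_size):
    """Yield shuffled mini-batches that together cover every row once (one epoch)."""
    n = x.shape[0]
    perm = jax.random.permutation(key, n)  # random row order for this epoch
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]  # the last batch may be smaller
        yield x[idx], y[idx]


n_epochs, batch_size = 10, 256  # same values as the trainer_config above
x = jnp.zeros((1000, 5))  # hypothetical 1,000-row training set
y = jnp.zeros(1000)
n_batches = math.ceil(x.shape[0] / batch_size)  # 4 batches per epoch

key = jax.random.PRNGKey(0)
for epoch in range(n_epochs):
    key, subkey = jax.random.split(key)  # fresh shuffle each epoch
    batches = list(iterate_minibatches(subkey, x, y, batch_size))
    # A training step (e.g., a BCE loss plus an optimizer update) would run on each batch here.
```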

Make Predictions

We can use module.pred_fn directly to make predictions.

pred_fn = module.pred_fn
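Since pred_fn returns probability scores, hard class labels (and from them an accuracy) can be obtained by thresholding. The sketch below uses a 0.5 threshold and toy arrays; both are assumptions, and it does not call module.pred_fn itself.

```python
import jax.numpy as jnp


def to_labels(probs, threshold=0.5):
    """Convert probability scores in [0, 1] to hard 0/1 labels."""
    return (probs >= threshold).astype(jnp.int32)


probs = jnp.array([0.1, 0.5, 0.93])  # toy probability scores
labels = to_labels(probs)            # [0, 1, 1]

y_true = jnp.array([0, 0, 1])
accuracy = jnp.mean(labels == y_true)  # 2 of 3 predictions are correct
```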

Generate Counterfactual Explanations

Now, it is time to use ReLax to generate counterfactual explanations (or recourse).

from relax.methods import VanillaCF, VanillaCFConfig

We use VanillaCF (a popular gradient-based recourse generation algorithm) as an example in this tutorial. Defining VanillaCF follows the same pattern as defining TabularDataModule and PredictiveTrainingModule: we create a VanillaCFConfig and pass it to VanillaCF.

cf_config = VanillaCFConfig(
    n_steps=1000, # Number of steps
    lr=0.001 # Learning rate
)
cf_exp = VanillaCF(cf_config)
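For intuition about what n_steps and lr control, the sketch below runs a gradient-based counterfactual search in the style popularized by Wachter et al.: starting from the input, it takes n_steps gradient steps of size lr on a loss that rewards reaching the opposite class and penalizes the L1 distance from the input, then uses jax.vmap to run the search over a batch. The linear pred_fn, the cross-entropy validity term, the penalty weight lam, and the function names are assumptions; this is not ReLax's VanillaCF implementation.

```python
import jax
import jax.numpy as jnp


def pred_fn(x):
    # Hypothetical fixed logistic classifier standing in for the trained model.
    w = jnp.array([1.5, -2.0, 0.5])
    return jax.nn.sigmoid(x @ w - 0.25)


def cf_loss(cf, x, target, lam):
    # Validity term: cross-entropy pushing the prediction toward the target class.
    p = jnp.clip(pred_fn(cf), 1e-7, 1 - 1e-7)
    validity = -(target * jnp.log(p) + (1 - target) * jnp.log(1 - p))
    # Proximity term: L1 distance from the original input.
    proximity = jnp.sum(jnp.abs(cf - x))
    return validity + lam * proximity


def vanilla_cf(x, n_steps=1000, lr=0.001, lam=0.1):
    target = 1.0 - jnp.round(pred_fn(x))  # the opposite of the current class
    grad_fn = jax.grad(cf_loss)

    def step(cf, _):
        # One plain gradient-descent step on the counterfactual.
        return cf - lr * grad_fn(cf, x, target, lam), None

    cf, _ = jax.lax.scan(step, x, None, length=n_steps)
    return cf


x_batch = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
cfs = jax.vmap(vanilla_cf)(x_batch)  # one counterfactual per input row
```

A larger lam keeps counterfactuals closer to the inputs at the risk of not flipping the prediction within n_steps.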

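Next, we generate counterfactual examples with generate_cf_explanations, passing the trained params and a random key to pred_fn via pred_fn_args.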

from relax.evaluate import generate_cf_explanations
cf_results = generate_cf_explanations(
    cf_exp, datamodule, pred_fn, 
    pred_fn_args={
        'params': params, 'rng_key': jax.random.PRNGKey(0)
    }
)
100%|██████████| 1000/1000 [00:10<00:00, 96.49it/s]

Benchmark the Counterfactual Method

After obtaining the counterfactual results, we can use benchmark_cfs to evaluate the accuracy of the classifier as well as the validity and proximity of the generated counterfactual examples.

from relax.evaluate import benchmark_cfs
benchmark_cfs([cf_results])
                  acc       validity  proximity
adult  VanillaCF  0.829751  0.999754  7.896814
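The sketch below shows common definitions of the two counterfactual metrics: validity as the fraction of counterfactuals whose predicted label differs from that of the original input, and proximity as the mean L1 distance between inputs and their counterfactuals. These definitions and the toy arrays are assumptions; ReLax's benchmark_cfs may compute them differently (for example, on a different feature scale).

```python
import jax.numpy as jnp


def validity(y_pred_x, y_pred_cf):
    """Fraction of counterfactuals whose predicted label differs from the original's."""
    return jnp.mean(y_pred_x != y_pred_cf)


def proximity(x, cf):
    """Mean L1 distance between each input row and its counterfactual."""
    return jnp.mean(jnp.sum(jnp.abs(x - cf), axis=1))


x = jnp.array([[0.0, 1.0], [1.0, 1.0]])
cf = jnp.array([[0.5, 1.0], [1.0, 0.0]])
y_pred_x = jnp.array([0, 1])   # labels predicted for the inputs
y_pred_cf = jnp.array([1, 1])  # labels predicted for the counterfactuals

print(validity(y_pred_x, y_pred_cf))  # 0.5: only the first counterfactual flips
print(proximity(x, cf))               # 0.75: mean of distances 0.5 and 1.0
```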