Getting started

This tutorial introduces the basics of `ReLax`, and shows how to use `ReLax` to generate counterfactual (or recourse) explanations for jax-based implementations of ML models.
In particular, we will cover the following in this tutorial:

- Loading datasets with `TabularDataModule`;
- Training machine learning classifiers;
- Generating counterfactual (or recourse) explanations;
- Benchmarking different recourse methods.
Preparation

We assume that you have already installed `ReLax`. If not, follow the steps in this installation tutorial, or simply run `pip install jax-relax`.
We also want to import some libraries for this tutorial.
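The only top-level import needed here, besides the `relax` modules imported in each section below, is `jax` itself (used later for `jax.random.PRNGKey`):

```python
import jax
```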
Load Dataset with TabularDataModule
`TabularDataModule` is a Python class which modularizes tabular dataset loading. `TabularDataModule` loads a `.csv` file from the directory by specifying the following attributes:

- `data_name` is the name of your dataset.
- `data_dir` should contain the relative path of the directory where your dataset is located.
- `continous_cols` specifies a list of feature names representing all the continuous/numeric features in our dataset.
- `discret_cols` specifies a list of feature names representing all discrete features in our dataset. By default, all discrete features are converted via one-hot encoding for training purposes.
- `imutable_cols` specifies a list of feature names representing immutable features that we do not wish to change in the generated recourse.
```python
from relax.data import TabularDataModuleConfigs, TabularDataModule, load_data
```
For example, to load the adult dataset, we can specify the `TabularDataModuleConfigs` as:
```python
data_config = TabularDataModuleConfigs(
    # The name of this dataset is "adult"
    data_name="adult",
    # The data file is located in `../assets/data/s_adult.csv`
    data_dir="../assets/data/s_adult.csv",
    # Contains 2 features with continuous variables
    continous_cols=["age", "hours_per_week"],
    # Contains 6 features with categorical (discrete) variables
    discret_cols=["workclass", "education", "marital_status", "occupation", "race", "gender"],
    # Contains 2 features that we do not wish to change
    imutable_cols=["race", "gender"]
)
```
We can then pass `data_config` to the `TabularDataModule`:
```python
datamodule = TabularDataModule(data_config)
```
Alternatively, we can also specify this config via a dictionary.
```python
# This approach is equivalent to using `TabularDataModuleConfigs`
data_config_dict = {
    "data_name": "adult",
    "data_dir": "../assets/data/s_adult.csv",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status", "occupation", "race", "gender"],
    "imutable_cols": ["race", "gender"]
}
datamodule = TabularDataModule(data_config_dict)
```
For datasets supported by `ReLax`, we can simply call `load_data`:
```python
# This is equivalent to specifying configs for `TabularDataModule`
datamodule = load_data('adult')
```
For more on loading datasets in `ReLax`, check out the data module documentation.
Train the Classifier
To showcase the full functionality of the framework, we will train the model using the built-in functions in `ReLax`, which use haiku for building neural network blocks. However, the recourse algorithms in `ReLax` can generate explanations for any jax-based framework (e.g., flax, haiku, vanilla jax).

The recourse algorithms in `ReLax` currently support only binary classification, and the output of the classifier must be a probability score (bounded by [0, 1]). Future support for multi-class classification is planned.
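To make this requirement concrete, any pure jax function that maps a feature batch (plus its parameters) to a probability in [0, 1] is, in principle, a valid target for the recourse algorithms. The sketch below is purely illustrative (the single linear layer, parameter layout, and `simple_pred_fn` name are assumptions, not part of `ReLax`):

```python
import jax
import jax.numpy as jnp

def simple_pred_fn(x, params, rng_key=None):
    # A single linear layer followed by a sigmoid, so the output is a
    # probability score bounded by [0, 1].
    logits = x @ params["w"] + params["b"]
    return jax.nn.sigmoid(logits)

# Illustrative parameters for inputs with 10 (encoded) features.
key = jax.random.PRNGKey(0)
toy_params = {"w": jax.random.normal(key, (10, 1)), "b": jnp.zeros((1,))}
probs = simple_pred_fn(jnp.ones((4, 10)), toy_params)  # shape (4, 1), values in [0, 1]
```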
Training a classifier with the built-in functions in `ReLax` is very simple. We first specify the classifier, `PredictiveTrainingModule`, which defines the model structure and the optimization procedure (e.g., the loss function used to optimize the model). Next, we use `train_model` to train the model on the `TabularDataModule`.
Define the Model
```python
from relax.module import PredictiveTrainingModuleConfigs, PredictiveTrainingModule
```
Defining `PredictiveTrainingModule` is similar to defining `TabularDataModule`. We first specify the configuration as `PredictiveTrainingModuleConfigs`, and pass it to `PredictiveTrainingModule`.
```python
model_config = PredictiveTrainingModuleConfigs(
    lr=0.01,             # Learning rate
    sizes=[50, 10, 50],  # The sizes of the hidden layers
    dropout_rate=0.3     # Dropout rate
)

# Specify the predictive model
module = PredictiveTrainingModule(model_config)
```
The training step for each batch is specified in `PredictiveTrainingModule`. Essentially, it computes the binary cross-entropy loss for each batch and applies backpropagation (via adam) to update the parameters of the model.
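`ReLax` handles this training step for us, but to make the mechanics concrete, here is a rough, self-contained sketch of such a step written with optax's adam optimizer. The `apply_fn(params, x)` model function and the exact loss bookkeeping are assumptions of the sketch, not the actual `PredictiveTrainingModule` implementation:

```python
from functools import partial

import jax
import jax.numpy as jnp
import optax

def bce_loss(params, apply_fn, x, y):
    # Binary cross-entropy between predicted probabilities and 0/1 labels.
    probs = jnp.clip(apply_fn(params, x), 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(probs) + (1 - y) * jnp.log(1 - probs))

optimizer = optax.adam(learning_rate=0.01)

@partial(jax.jit, static_argnums=(2,))
def train_step(params, opt_state, apply_fn, x, y):
    # One optimization step: compute the loss and its gradients w.r.t.
    # `params`, then apply the adam update.
    loss, grads = jax.value_and_grad(bce_loss)(params, apply_fn, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```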
Train the Model
```python
from relax.trainer import TrainingConfigs, train_model
```
To train the `PredictiveTrainingModule` on the entire dataset (specified in the `TabularDataModule`), we can simply call `train_model`:
```python
trainer_config = TrainingConfigs(
    n_epochs=10,                     # Number of epochs
    batch_size=256,                  # Batch size
    monitor_metrics='val/val_loss',  # The metric to monitor
    logger_name='pred'               # The name of the logger
)

# Train the model
params, opt_state = train_model(
    module, datamodule, trainer_config
)
```
```
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2023-04-24 10:46:46.661007: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:
2023-04-24 10:46:46.661163: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:
2023-04-24 10:46:46.661171: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 218.13batch/s, train/train_loss_1=0.0634]
```
Make Predictions
We can directly use `module.pred_fn` for making predictions.
```python
pred_fn = module.pred_fn
```
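As a quick sanity check, we can call this function on a batch of preprocessed inputs. The call below assumes `pred_fn` accepts the feature batch together with the trained `params` and a `rng_key` (the same arguments supplied via `pred_fn_args` in the next section); the input shape is illustrative only, and the exact signature may differ:

```python
import jax
import jax.numpy as jnp

# Illustrative batch of 5 preprocessed (scaled / one-hot encoded) inputs;
# the feature dimension depends on the encoded dataset.
x_batch = jnp.zeros((5, 29))

# Assumed call pattern: probabilities in [0, 1] for each row.
probs = pred_fn(x_batch, params=params, rng_key=jax.random.PRNGKey(0))
```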
Generate Counterfactual Explanations
Now, it is time to use `ReLax` to generate counterfactual explanations (or recourse).
```python
from relax.methods import VanillaCF, VanillaCFConfig
```
We use `VanillaCF` (a very popular recourse generation algorithm) as an example for this tutorial. Defining `VanillaCF` is similar to defining `TabularDataModule` and `PredictiveTrainingModule`.
```python
cf_config = VanillaCFConfig(
    n_steps=1000,  # Number of steps
    lr=0.001       # Learning rate
)
cf_exp = VanillaCF(cf_config)
```
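Conceptually, `VanillaCF` performs a gradient-based search: starting from the original input, it takes `n_steps` gradient steps on a candidate counterfactual so that the classifier's prediction flips to the target class while the candidate stays close to the input. The following is a minimal conceptual sketch of that idea, not `ReLax`'s implementation; the loss weighting, the L1 distance term, and the bound `predict` function are assumptions:

```python
import jax
import jax.numpy as jnp

def vanilla_cf_sketch(x, predict, n_steps=1000, lr=0.001, lam=0.1):
    """Conceptual sketch of gradient-based counterfactual search.

    `predict` is assumed to map a single encoded input to a probability
    in [0, 1] (e.g., the trained classifier with its params already bound).
    """
    y_target = 1.0 - jnp.round(predict(x))  # aim for the opposite class

    def loss(cf):
        validity = jnp.squeeze((predict(cf) - y_target) ** 2)  # reach the target class
        proximity = jnp.sum(jnp.abs(cf - x))                   # stay close to x
        return validity + lam * proximity

    grad_fn = jax.grad(loss)
    cf = x
    for _ in range(n_steps):
        cf = cf - lr * grad_fn(cf)  # plain gradient descent on the counterfactual
    return cf
```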
Generate counterfactual examples.
```python
from relax.evaluate import generate_cf_explanations
```
```python
cf_results = generate_cf_explanations(
    cf_exp, datamodule, pred_fn,
    pred_fn_args={
        'params': params, 'rng_key': jax.random.PRNGKey(0)
    }
)
```
```
100%|██████████| 1000/1000 [00:10<00:00, 96.49it/s]
```
Benchmark the Counterfactual Method
After we obtain the counterfactual results, we can use `benchmark_cfs` to evaluate the accuracy, validity, and proximity of the counterfactual examples. Roughly speaking, accuracy is the classifier's predictive accuracy, validity is the fraction of counterfactuals that actually flip the model's prediction to the desired class, and proximity measures how far each counterfactual moves from its original input.
```python
from relax.evaluate import benchmark_cfs
```
```python
benchmark_cfs([cf_results])
```
|       |           | acc      | validity | proximity |
|-------|-----------|----------|----------|-----------|
| adult | VanillaCF | 0.829751 | 0.999754 | 7.896814  |