Getting started
ReLax
This tutorial introduces the basics of ReLax and shows how to use it to generate counterfactual (or recourse) explanations for JAX-based implementations of ML models.
In particular, we will cover the following things in this tutorial:
- Loading datasets with TabularDataModule;
- Training machine learning classifiers;
- Generating counterfactual (or recourse) explanations;
- Benchmarking different recourse methods.
Preparation
We assume that you have already installed ReLax. If not, follow the steps in this installation tutorial, or simply run `pip install jax-relax`.
We also want to import some libraries for this tutorial.

import jax
Load Dataset with TabularDataModule
TabularDataModule is a Python class that modularizes tabular dataset loading. TabularDataModule loads a .csv file from a directory based on the following attributes:
- data_name is the name of your dataset.
- data_dir should contain the relative path of the directory where your dataset is located.
- continous_cols specifies a list of feature names representing all the continuous/numeric features in our dataset.
- discret_cols specifies a list of feature names representing all discrete features in our dataset. By default, all discrete features are converted via one-hot encoding for training purposes (see the brief illustration below).
- imutable_cols specifies a list of feature names that represent immutable features that we do not wish to change in the generated recourse.
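As a quick aside, the snippet below uses plain pandas (not ReLax internals) to illustrate what one-hot encoding a discrete column means: each category becomes its own indicator column.

```python
# Illustration only: one-hot encoding with pandas, not ReLax's internal encoder.
import pandas as pd

df = pd.DataFrame({"workclass": ["Private", "Self-emp", "Private"]})
# Each category of `workclass` becomes its own indicator column.
print(pd.get_dummies(df, columns=["workclass"]))
```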
from relax.data import TabularDataModuleConfigs, TabularDataModule, load_data

For example, to load the adult dataset, we can specify the TabularDataModuleConfigs as
data_config = TabularDataModuleConfigs(
# The name of this dataset is "adult"
data_name="adult",
# The data file is located in `../assets/data/s_adult.csv`.
data_dir="../assets/data/s_adult.csv",
# Contains 2 features with continuous variables
continous_cols=["age","hours_per_week"],
# Contains 6 features with categorical (discrete) variables
discret_cols=["workclass","education","marital_status","occupation","race","gender"],
# Contains 2 features that we do not wish to change
imutable_cols=["race", "gender"]
)

We can then pass data_config to the TabularDataModule.
datamodule = TabularDataModule(data_config)

Alternatively, we can also specify this config via a dictionary.
# This approach is equivalent to using `TabularDataModuleConfigs`
data_config_dict = {
"data_name": "adult",
"data_dir": "../assets/data/s_adult.csv",
"continous_cols": ["age","hours_per_week"],
"discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
"imutable_cols": ["race","gender"]
}
datamodule = TabularDataModule(data_config_dict)

For datasets supported by ReLax, we can simply call load_data:
# This is equivalent to specifying configs for `TabularDataModule`
datamodule = load_data('adult')

For more usage of loading datasets in ReLax, check out the data module documentation.
Train the Classifier
To showcase the full functionality of the framework, we will train the model using the built-in functions in ReLax, which use haiku for building neural network blocks. However, the recourse algorithms in ReLax can generate explanations for models from any JAX-based framework (e.g., flax, haiku, vanilla JAX).
The recourse algorithms in ReLax currently only support binary classification. The output of the classifier must be a probability score bounded by [0, 1]. Support for multi-class classification is planned.
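To make this requirement concrete, here is a minimal sketch of the kind of JAX prediction function the recourse algorithms expect: a function mapping a batch of inputs to probabilities in [0, 1]. The toy linear model below is purely illustrative, not the classifier trained in this tutorial.

```python
# A toy example (not the model trained below) of a JAX prediction function
# whose outputs are probabilities bounded by [0, 1].
import jax
import jax.numpy as jnp

def toy_pred_fn(x: jnp.ndarray, w: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Return P(y = 1 | x) for each row of x."""
    return jax.nn.sigmoid(x @ w + b)

x = jnp.ones((5, 8))  # a dummy batch of 5 samples with 8 features
probs = toy_pred_fn(x, jnp.zeros((8,)), jnp.array(0.0))
print(probs.shape, float(probs.min()), float(probs.max()))  # all values in [0, 1]
```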
Training a classifier using the built-in functions in ReLax is very simple. We first define the classifier with PredictiveTrainingModule, which specifies the model structure and the optimization procedure (e.g., the loss function used to optimize the model). Next, we use train_model to train the model on the TabularDataModule.
Define the Model
from relax.module import PredictiveTrainingModuleConfigs, PredictiveTrainingModule

Defining PredictiveTrainingModule is similar to defining TabularDataModule. We first specify the configurator as PredictiveTrainingModuleConfigs, and pass this configurator to PredictiveTrainingModule.
model_config = PredictiveTrainingModuleConfigs(
lr=0.01, # Learning rate
sizes=[50, 10, 50], # The sizes of the hidden layers
dropout_rate=0.3 # Dropout rate
)
# specify the predictive model
module = PredictiveTrainingModule(model_config)

The training step for each batch is specified in PredictiveTrainingModule. Essentially, it computes the binary cross-entropy loss for each batch and applies backpropagation (via Adam) to update the parameters of the model.
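For intuition, the following is a minimal, self-contained sketch of such a training step using optax and a stand-in linear model; it is not ReLax's actual implementation, and the parameter shapes and dummy batch are placeholders.

```python
# A minimal sketch (not ReLax's implementation) of the training step described
# above: binary cross-entropy loss on a batch, updated with Adam.
import jax
import jax.numpy as jnp
import optax

def binary_cross_entropy(probs, labels, eps=1e-7):
    probs = jnp.clip(probs, eps, 1 - eps)
    return -jnp.mean(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))

def forward(params, x):
    # Stand-in for the MLP defined by `PredictiveTrainingModule`.
    return jax.nn.sigmoid(x @ params["w"] + params["b"]).squeeze(-1)

optimizer = optax.adam(learning_rate=0.01)

@jax.jit
def training_step(params, opt_state, x, y):
    def loss_fn(p):
        return binary_cross_entropy(forward(p, x), y)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Dummy parameters and batch, just to show the call signature.
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
opt_state = optimizer.init(params)
x, y = jnp.ones((4, 8)), jnp.array([0.0, 1.0, 0.0, 1.0])
params, opt_state, loss = training_step(params, opt_state, x, y)
```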
Train the Model
from relax.trainer import TrainingConfigs, train_model

To train PredictiveTrainingModule on the entire dataset (specified in TabularDataModule), we can simply call train_model:
trainer_config = TrainingConfigs(
n_epochs=10, # Number of epochs
batch_size=256, # Batch size
monitor_metrics='val/val_loss', # The metric to monitor
logger_name='pred' # The name of the logger
)
# train the model
params, opt_state = train_model(
module, datamodule, trainer_config
)
Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 218.13batch/s, train/train_loss_1=0.0634]
Make Predictions
We can directly use module.pred_fn for making the predictions.
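As a rough usage sketch, and assuming pred_fn accepts a batch of inputs together with the params and rng_key that are later passed via pred_fn_args (this signature is an assumption, not documented behavior), a call might look like the following.

```python
# Hypothetical call; the exact signature of `module.pred_fn` is an assumption
# inferred from the `pred_fn_args` used with `generate_cf_explanations` below.
import jax
import jax.numpy as jnp

x_batch = jnp.zeros((4, 29))  # dummy batch; 29 is a placeholder feature width
probs = module.pred_fn(x_batch, params=params, rng_key=jax.random.PRNGKey(0))
print(probs.shape)  # one probability in [0, 1] per row
```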
pred_fn = module.pred_fn

Generate Counterfactual Explanations
Now, it is time to use ReLax to generate counterfactual explanations (or recourse).
from relax.methods import VanillaCF, VanillaCFConfig

We use VanillaCF (a very popular recourse generation algorithm) as an example for this tutorial. Defining VanillaCF is similar to defining TabularDataModule and PredictiveTrainingModule.
cf_config = VanillaCFConfig(
n_steps=1000, # Number of steps
lr=0.001 # Learning rate
)
cf_exp = VanillaCF(cf_config)

Generate counterfactual examples.
from relax.evaluate import generate_cf_explanations

cf_results = generate_cf_explanations(
cf_exp, datamodule, pred_fn,
pred_fn_args={
'params': params, 'rng_key': jax.random.PRNGKey(0)
}
)

100%|██████████| 1000/1000 [00:10<00:00, 96.49it/s]
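For intuition about what n_steps and lr control, the sketch below illustrates the general idea behind VanillaCF-style methods (in the spirit of Wachter et al.): repeatedly take gradient steps on a candidate counterfactual so that the prediction flips while the candidate stays close to the original input. This is a simplified illustration, not ReLax's implementation; the toy classifier and inputs are placeholders.

```python
# A rough sketch of the general idea behind VanillaCF-style methods (gradient
# descent on the input), for intuition only -- not ReLax's implementation.
import jax
import jax.numpy as jnp

def vanilla_cf_sketch(x, pred_prob, n_steps=1000, lr=0.001, lam=0.1):
    """Nudge x toward the opposite prediction while staying close to x."""
    y_target = 1.0 - (pred_prob(x) > 0.5)  # flip the currently predicted label
    cf = x

    def loss_fn(cf):
        validity_loss = (pred_prob(cf) - y_target) ** 2  # push toward the target label
        proximity_loss = jnp.abs(cf - x).sum()           # stay close to the input
        return validity_loss + lam * proximity_loss

    grad_fn = jax.grad(loss_fn)
    for _ in range(n_steps):
        cf = cf - lr * grad_fn(cf)  # plain gradient descent on the input
    return cf

# Tiny usage example with a toy linear classifier.
w = jnp.array([1.0, -2.0, 0.5])
toy_pred = lambda x: jax.nn.sigmoid(x @ w)
x = jnp.array([0.2, 0.8, -0.1])
cf = vanilla_cf_sketch(x, toy_pred, n_steps=200)
```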
Benchmark the Counterfactual Method
After we obtain the counterfactual results, we can use benchmark_cfs to evaluate the accuracy, validity, and proximity of the counterfactual examples.
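For reference, a rough sketch of how validity and proximity are commonly defined for counterfactual explanations is shown below; ReLax's exact implementation of benchmark_cfs may differ in details such as normalization, and the arrays used here are hypothetical.

```python
# Common textbook-style definitions of the metrics, for reference only;
# `y_cf_prob`, `y_target`, `x`, and `cf` are hypothetical arrays.
import jax.numpy as jnp

def validity(y_cf_prob: jnp.ndarray, y_target: jnp.ndarray) -> jnp.ndarray:
    """Fraction of counterfactuals whose predicted label matches the target label."""
    return jnp.mean((y_cf_prob > 0.5).astype(jnp.float32) == y_target)

def proximity(x: jnp.ndarray, cf: jnp.ndarray) -> jnp.ndarray:
    """Average L1 distance between the inputs and their counterfactuals."""
    return jnp.mean(jnp.abs(x - cf).sum(axis=-1))

x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
cf = jnp.array([[0.5, 1.0], [1.0, 0.5]])
print(validity(jnp.array([0.9, 0.2]), jnp.array([1.0, 1.0])))  # 0.5
print(proximity(x, cf))                                        # 0.5
```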
from relax.evaluate import benchmark_cfs

benchmark_cfs([cf_results])

|  |  | acc | validity | proximity |
|---|---|---|---|---|
| adult | VanillaCF | 0.829751 | 0.999754 | 7.896814 |