import jax
Getting started
This tutorial introduces the basics of ReLax and shows how to use ReLax to generate counterfactual (or recourse) explanations for jax-based implementations of ML models.
In particular, we will cover the following things in this tutorial:
- Loading datasets with DataModule;
- Training machine learning classifiers;
- Generating counterfactual (or recourse) explanations;
- Benchmarking different recourse methods.
Preparation
We assume that you have already installed ReLax. If not, follow the steps in the installation tutorial, or just run pip install jax-relax.
We also want to import some libraries for this tutorial.
Load Dataset with DataModule
DataModule is a Python class that modularizes tabular dataset loading. DataModule loads a .csv file, and is configured through the following attributes:
- data_name is the name of your dataset.
- data_dir should contain the relative path of the directory where your dataset is located.
- continous_cols specifies a list of feature names representing all the continuous/numeric features in our dataset.
- discret_cols specifies a list of feature names representing all discrete features in our dataset. By default, all discrete features are converted via one-hot encoding for training purposes.
- imutable_cols specifies a list of feature names that represent immutable features that we do not wish to change in the generated recourse.
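To make the default one-hot encoding of discrete features concrete, here is a minimal plain-Python sketch. It is not ReLax's implementation: the helper `one_hot_encode` and the toy rows are hypothetical, and they only illustrate how each discrete column expands into one binary column per category while continuous columns pass through unchanged.

```python
# Illustrative only: a toy one-hot encoder, not ReLax's internal transformation.

def one_hot_encode(rows, continuous_cols, discrete_cols):
    """Encode a list of dict rows into flat numeric feature vectors."""
    # Collect the sorted category vocabulary of each discrete column.
    categories = {
        col: sorted({row[col] for row in rows}) for col in discrete_cols
    }
    encoded = []
    for row in rows:
        # Continuous features are kept as floats.
        vector = [float(row[col]) for col in continuous_cols]
        # Each discrete feature becomes one 0/1 entry per category.
        for col in discrete_cols:
            vector.extend(
                1.0 if row[col] == cat else 0.0 for cat in categories[col]
            )
        encoded.append(vector)
    return encoded, categories


# Hypothetical toy rows that reuse two of the adult feature names.
toy_rows = [
    {"age": 39, "gender": "Male"},
    {"age": 28, "gender": "Female"},
]
encoded_rows, vocab = one_hot_encode(toy_rows, ["age"], ["gender"])
print(vocab)
print(encoded_rows)
```

Each encoded row here has three entries: the continuous `age` value followed by one indicator per `gender` category.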
from relax.data_module import DataModuleConfig, DataModule, load_data
For example, to load the adult dataset, we can specify the DataModuleConfig as
data_config = DataModuleConfig(
    # The name of this dataset is "adult"
    data_name="adult",
    # The data file is located at `../assets/adult/data/data.csv`.
    data_dir="../assets/adult/data/data.csv",
    # Contains 2 features with continuous variables
    continous_cols=["age", "hours_per_week"],
    # Contains 6 features with categorical (discrete) variables
    discret_cols=["workclass", "education", "marital_status", "occupation", "race", "gender"],
    # Contains 2 features that we do not wish to change
    imutable_cols=["race", "gender"]
)
We can then pass data_config to the DataModule.
datamodule = DataModule.from_config(data_config)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Alternatively, we can also specify this config via a dictionary.
# This approach is equivalent to using `DataModuleConfig`
data_config_dict = {
    "data_name": "adult",
    "data_dir": "../assets/adult/data/data.csv",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status", "occupation", "race", "gender"],
    "imutable_cols": ["race", "gender"]
}
datamodule = DataModule.from_config(data_config_dict)
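Because a dictionary config is free-form, a small sanity check before building the DataModule can catch misspelled keys or column names. The helper `check_data_config` below is hypothetical (it is not part of ReLax), and it assumes that every immutable column must also be listed as a continuous or discrete feature and that no column is listed as both.

```python
# Illustrative only: a hypothetical validator for a DataModule-style config dict.

def check_data_config(config):
    """Return a list of human-readable problems found in the config."""
    problems = []
    for key in ("data_name", "data_dir", "continous_cols", "discret_cols"):
        if key not in config:
            problems.append(f"missing key: {key}")

    continuous = set(config.get("continous_cols", []))
    discrete = set(config.get("discret_cols", []))

    # Assumption: a feature is either continuous or discrete, never both.
    for col in sorted(continuous & discrete):
        problems.append(f"column listed as continuous and discrete: {col}")

    # Assumption: immutable columns must be known features.
    for col in config.get("imutable_cols", []):
        if col not in continuous | discrete:
            problems.append(f"immutable column is not a feature: {col}")
    return problems


adult_config = {
    "data_name": "adult",
    "data_dir": "../assets/adult/data/data.csv",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status",
                     "occupation", "race", "gender"],
    "imutable_cols": ["race", "gender"],
}
print(check_data_config(adult_config))

# A config with a misspelled immutable column is flagged.
bad_config = dict(adult_config, imutable_cols=["race", "sex"])
print(check_data_config(bad_config))
```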
For datasets supported by ReLax, we can simply call load_data:
# This is equivalent to specifying configs for `DataModule`
datamodule = load_data('adult')
For more on loading datasets in ReLax, check out the data module documentation.
Train the Classifier
To show the full functionality of the framework, we will train the model using the built-in functions in ReLax, which use haiku for building neural network blocks. However, the recourse algorithms in ReLax can generate explanations for models built with any jax-based framework (e.g., flax, haiku, vanilla jax).
The recourse algorithms in ReLax currently support only binary classification. The output of the classifier must be a probability score (bounded by [0, 1]). Support for multi-class classification is planned.
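If a custom model returns raw logits rather than probabilities, it needs a sigmoid on top before it satisfies the [0, 1] requirement. The plain-Python sketch below illustrates that wrapping; `toy_logit_fn` and `as_probability_fn` are hypothetical names used only for illustration, and in a real jax model the same role would typically be played by `jax.nn.sigmoid`.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function mapping any real to [0, 1]."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def as_probability_fn(logit_fn):
    """Wrap a logit-valued function so that it returns a probability."""
    def pred_fn(x):
        return sigmoid(logit_fn(x))
    return pred_fn


def toy_logit_fn(x):
    # Hypothetical linear scorer over two features.
    return 2.0 * x[0] - 1.0 * x[1]


toy_pred_fn = as_probability_fn(toy_logit_fn)
p = toy_pred_fn([0.5, 1.0])  # the logit is exactly 0 for this input
print(p)
```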
Training a classifier using the built-in functions in ReLax is simple. We first specify the classifier, MLModule, which defines the model structure and the optimization procedure (e.g., the loss function used to optimize the model). Next, we call MLModule.train to train the model on the DataModule.
Define the Model
from relax.ml_model import MLModuleConfig, MLModule
Defining MLModule is similar to defining DataModule. We first specify the configurator as MLModuleConfig, and then pass this configurator to MLModule.
model_config = MLModuleConfig(
    lr=0.01, # Learning rate
    sizes=[50, 10, 50], # The sizes of the hidden layers
    dropout_rate=0.3 # Dropout rate
)
# specify the predictive model
module = MLModule(model_config)
Train the Model
To train MLModule on the entire dataset (specified in DataModule), we can simply call MLModule.train:
module.train(datamodule, batch_size=128, epochs=5)
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.7807 - loss: 0.4597
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8135 - loss: 0.3945
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8202 - loss: 0.3769
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8205 - loss: 0.3784
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8167 - loss: 0.3817
<relax.ml_model.MLModule>
Make Predictions
We can directly use module.pred_fn to make predictions.
pred_fn = module.pred_fn
Generate Counterfactual Explanations
Now, it is time to use ReLax to generate counterfactual explanations (or recourse).
from relax.methods import VanillaCF, VanillaCFConfig
We use VanillaCF (a very popular recourse generation algorithm) as an example for this tutorial. Defining VanillaCF is similar to defining DataModule and MLModule.
cf_config = VanillaCFConfig(
    n_steps=1000, # Number of steps
    lr=0.001 # Learning rate
)
cf_exp = VanillaCF(cf_config)
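To give some intuition for what the number of steps and the learning rate control, here is a conceptual sketch of gradient-based counterfactual search on a toy logistic model. It is not ReLax's VanillaCF code: the weights, the squared-distance penalty, its weight `lam`, and the function `toy_counterfactual` are all assumptions chosen for illustration. The idea is to start from the input and take gradient steps that push the prediction toward the opposite class while penalizing distance from the original input.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(x, w, b):
    """Toy logistic classifier returning a probability in [0, 1]."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def toy_counterfactual(x, w, b, n_steps=1000, lr=0.1, lam=0.01):
    """Gradient descent on BCE(predict(cf), target) + lam * ||cf - x||^2."""
    # The target is the opposite of the currently predicted class.
    target = 1.0 - round(predict(x, w, b))
    cf = list(x)
    for _ in range(n_steps):
        p = predict(cf, w, b)
        # For a logistic model, d BCE / d cf_i = (p - target) * w_i.
        # The distance penalty contributes 2 * lam * (cf_i - x_i).
        cf = [
            c - lr * ((p - target) * wi + 2.0 * lam * (c - xi))
            for c, wi, xi in zip(cf, w, x)
        ]
    return cf


# Hypothetical model parameters and input.
w, b = [1.5, -2.0], -0.5
x = [0.2, 0.8]
cf = toy_counterfactual(x, w, b)
print("original prediction:", predict(x, w, b))
print("counterfactual prediction:", predict(cf, w, b))
```

A larger `lam` keeps the counterfactual closer to the input at the cost of a less confident flip; more steps or a larger learning rate move it further along the same path.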
Generate counterfactual examples.
from relax.explain import generate_cf_explanations
cf_results = generate_cf_explanations(
    cf_exp, datamodule, pred_fn,
)
Benchmark the Counterfactual Method
After we obtain the counterfactual results, we can use benchmark_cfs to evaluate the accuracy, validity, and proximity of the counterfactual examples.
from relax.evaluate import benchmark_cfs
benchmark_cfs([cf_results])
dataset | method | acc | validity | proximity |
---|---|---|---|---|
adult | VanillaCF | 0.828261 | 0.814963 | 4.79361 |
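As a rough guide to reading the validity and proximity columns, the sketch below computes both metrics under commonly used definitions: validity as the fraction of counterfactuals whose predicted class differs from the original prediction, and proximity as the mean L1 distance between inputs and their counterfactuals. These definitions are assumptions for illustration and may differ in detail from what benchmark_cfs computes; `toy_pred_fn` and the data are hypothetical.

```python
def validity(pred_fn, xs, cfs):
    """Fraction of counterfactuals whose predicted class is flipped."""
    flipped = sum(
        1 for x, cf in zip(xs, cfs)
        if (pred_fn(x) >= 0.5) != (pred_fn(cf) >= 0.5)
    )
    return flipped / len(xs)


def proximity(xs, cfs):
    """Mean L1 distance between each input and its counterfactual."""
    total = sum(
        sum(abs(a - b) for a, b in zip(x, cf)) for x, cf in zip(xs, cfs)
    )
    return total / len(xs)


def toy_pred_fn(x):
    # Hypothetical classifier: the first feature, clipped to [0, 1].
    return min(max(x[0], 0.0), 1.0)


xs = [[0.25, 1.0], [1.0, 0.0]]
cfs = [[0.75, 1.0], [0.75, 0.5]]
print("validity:", validity(toy_pred_fn, xs, cfs))
print("proximity:", proximity(xs, cfs))
```

In this toy case only the first counterfactual crosses the 0.5 threshold, so half of the counterfactuals are valid; lower proximity means the counterfactuals stay closer to the original inputs.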