from relax.data import load_data
CounterNet
CounterNet Model
COUNTERNETMODEL
CLASS relax.methods.counternet.CounterNetModel (m_config, name=None)
CounterNet Model
CounterNet Training Module
Define the CounterNetTrainingModule for training CounterNetModel.
PARTITION_TRAINABLE_PARAMS
relax.methods.counternet.partition_trainable_params (params, trainable_name)
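A minimal sketch of what this partitioning might look like, assuming params is a flat dict mapping module names to arrays (illustrative only, not the library's implementation):

def partition_trainable_params(params, trainable_name):
    # Split the parameter dict into the sub-tree that receives gradient
    # updates (module names containing `trainable_name`) and the
    # sub-tree that stays frozen during this update.
    trainable = {k: v for k, v in params.items() if trainable_name in k}
    frozen = {k: v for k, v in params.items() if trainable_name not in k}
    return trainable, frozen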
PROJECT_IMMUTABLE_FEATURES
relax.methods.counternet.project_immutable_features (x, cf, imutable_idx_list)
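A minimal sketch of what this projection does, assuming x and cf are 1-D jax.numpy feature arrays and imutable_idx_list holds the indices of immutable features (illustrative only, not the library's implementation):

import jax.numpy as jnp

def project_immutable_features(x, cf, imutable_idx_list):
    # Overwrite the immutable entries of the counterfactual with the
    # original feature values so only mutable features can change.
    idx = jnp.asarray(imutable_idx_list)
    return cf.at[idx].set(x[idx])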
COUNTERNETTRAININGMODULECONFIGS
CLASS relax.methods.counternet.CounterNetTrainingModuleConfigs (lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
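For example, to override the default learning rate and the validity-loss weight (values here are purely illustrative):

t_module_configs = CounterNetTrainingModuleConfigs(lr=0.001, lambda_2=0.5)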
COUNTERNETTRAININGMODULE
CLASS relax.methods.counternet.CounterNetTrainingModule (m_configs)
Helper class that provides a standard way to create an ABC using inheritance.
CounterNet Explanation Module
CounterNet consists of three objectives:
- predictive accuracy: the predictor network should output accurate predictions $\hat{y}_x$;
- counterfactual validity: CF examples $x'$ produced by the CF generator network should be valid (e.g. $\hat{y}_{x} + \hat{y}_{x'} = 1$);
- minimizing cost of change: minimal modifications should be required to change input instance $x$ to CF example $x'$.
The objective function of CounterNet:
$$\operatorname*{argmin}_{\theta} \frac{1}{N}\sum_{i=1}^{N} \bigg[ \lambda_1 \cdot \underbrace{\left(y_i - \hat{y}_{x_i}\right)^2}_{\text{Prediction Loss}\ (\mathcal{L}_1)} + \lambda_2 \cdot \underbrace{\left(\hat{y}_{x_i} - \left(1 - \hat{y}_{x_i'}\right)\right)^2}_{\text{Validity Loss}\ (\mathcal{L}_2)} + \lambda_3 \cdot \underbrace{\left(x_i - x'_i\right)^2}_{\text{Cost of change Loss}\ (\mathcal{L}_3)} \bigg]$$
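A minimal sketch of this objective as code, assuming y, y_hat (predictions for x), cf_y_hat (predictions for the CF examples x'), x, and cf are jax.numpy arrays; the function and argument names are illustrative, not part of the relax API:

import jax.numpy as jnp

def counternet_objective(y, y_hat, cf_y_hat, x, cf,
                         lambda_1=1.0, lambda_2=0.2, lambda_3=0.1):
    l_1 = jnp.mean((y - y_hat) ** 2)                  # prediction loss
    l_2 = jnp.mean((y_hat - (1.0 - cf_y_hat)) ** 2)   # validity loss
    l_3 = jnp.mean((x - cf) ** 2)                     # cost of change loss
    return lambda_1 * l_1 + lambda_2 * l_2 + lambda_3 * l_3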
CounterNet applies two-stage gradient updates to CounterNetModel for each training_step (see CounterNetTrainingModule).
- The first gradient update optimizes for predictive accuracy: $\theta^{(1)} = \theta^{(0)} - \nabla_{\theta^{(0)}} (\lambda_1 \cdot \mathcal{L}_1)$.
- The second gradient update optimizes for generating CF explanations: $\theta^{(2)}_g = \theta^{(1)}_g - \nabla_{\theta^{(1)}_g} (\lambda_2 \cdot \mathcal{L}_2 + \lambda_3 \cdot \mathcal{L}_3)$.
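A minimal sketch of this two-stage procedure with jax.grad, assuming a params dict with 'predictor' and 'cf_generator' sub-trees; the two stand-in losses and all names below are illustrative, not the relax internals:

import jax
import jax.numpy as jnp

def pred_loss(params, batch):   # stand-in for lambda_1 * L_1
    x, y = batch
    return jnp.mean((y - x @ params['predictor']) ** 2)

def cf_loss(params, batch):     # stand-in for lambda_2 * L_2 + lambda_3 * L_3
    x, _ = batch
    return jnp.mean((x - x @ params['cf_generator']) ** 2)

def two_stage_step(params, batch, lr=0.003):
    # First update: all parameters, prediction loss only.
    grads = jax.grad(pred_loss)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    # Second update: only the CF generator's parameters, CF losses only.
    g_grads = jax.grad(cf_loss)(params, batch)
    params['cf_generator'] = params['cf_generator'] - lr * g_grads['cf_generator']
    return params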
This two-stage optimization procedure is chosen because it improves both the convergence of the model and the adversarial robustness of the predictor network; the CounterNet paper elaborates on these design choices.
COUNTERNETCONFIGS
CLASS relax.methods.counternet.CounterNetConfigs (enc_sizes=[50, 10], dec_sizes=[10], exp_sizes=[50, 50], dropout_rate=0.3, lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)
Configurator of CounterNet.
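For example, to use a larger encoder and a stronger validity weight (values here are purely illustrative):

m_configs = CounterNetConfigs(enc_sizes=[100, 50], lambda_2=0.5)

The resulting object can be passed as m_configs when constructing CounterNet (documented below).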
COUNTERNET
CLASS relax.methods.counternet.CounterNet (m_configs=None)
API for CounterNet Explanation Module.
Basic usage of CounterNet
Prepare data:
dm = load_data("adult", data_configs=dict(sample_frac=0.1))
Define CounterNet:
counternet = CounterNet()
assert isinstance(counternet, BaseParametricCFModule)
assert isinstance(counternet, BaseCFModule)
assert isinstance(counternet, BasePredFnCFModule)
assert hasattr(counternet, 'pred_fn')
Train the model:
t_configs = dict(n_epochs=1, batch_size=128)
counternet.train(dm, t_configs=t_configs)
Predict labels:
X, y = dm.test_dataset[:]
y_pred = counternet.pred_fn(X)
assert y_pred.shape == (len(y), 1)
Generate a CF explanation for a given x:
x, _ = dm.test_dataset[0]
cf = counternet.generate_cf(x)
assert x.shape == cf.shape
assert cf.shape == (29,)
Generate CF explanations for a given batch X:
X, _ = dm.test_dataset[:]
cfs = counternet.generate_cfs(X)
assert X.shape == cfs.shape