CounterNet

A prediction-aware recourse model

CounterNet Model



COUNTERNETMODEL

CLASS relax.methods.counternet.CounterNetModel (m_config, name=None)

CounterNet Model

Parameters:
  • m_config (Dict | CounterNetModelConfigs) – Model configuration, given either as a dict with the fields of CounterNetModelConfigs or as a CounterNetModelConfigs instance.
  • name (str, default=None) – Name of the module.

CounterNet Training Module

Define the CounterNetTrainingModule for training CounterNetModel.



PARTITION_TRAINABLE_PARAMS

relax.methods.counternet.partition_trainable_params (params, trainable_name)
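The signature suggests that parameters are split into a trainable group and a frozen group by matching a name. The following is a minimal sketch of that idea on plain nested dicts; the function name `partition_by_module_name`, the substring-matching rule, and the example module names are all assumptions for illustration, not the library's implementation.

```python
from typing import Dict, Tuple

# Hypothetical parameter layout: module name -> {parameter name -> value}.
Params = Dict[str, Dict[str, float]]


def partition_by_module_name(params: Params, trainable_name: str) -> Tuple[Params, Params]:
    """Split params into (trainable, frozen) by substring match on the module name."""
    trainable = {m: p for m, p in params.items() if trainable_name in m}
    frozen = {m: p for m, p in params.items() if trainable_name not in m}
    return trainable, frozen


# Illustrative module names (assumed, not taken from the library).
params = {
    "counter_net_model/encoder": {"w": 0.1},
    "counter_net_model/predictor": {"w": 0.2},
    "counter_net_model/explainer": {"w": 0.3},
}
trainable, frozen = partition_by_module_name(params, "explainer")
```

A split like this lets an optimizer update one sub-network while the remaining parameters stay fixed.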


PROJECT_IMMUTABLE_FEATURES

relax.methods.counternet.project_immutable_features (x, cf, imutable_idx_list)
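Judging by its name and arguments, this function likely resets the immutable features of a counterfactual to their values in the original input. A minimal NumPy sketch of that behavior follows; `project_immutable` is a hypothetical stand-in written for illustration, and the library's own (JAX-based) implementation may differ.

```python
import numpy as np


def project_immutable(x, cf, immutable_idx):
    """Return a copy of cf whose immutable columns are reset to the values in x."""
    cf = np.array(cf, dtype=float, copy=True)  # copy so the caller's array is untouched
    if immutable_idx:  # an empty list means there is nothing to reset
        cf[..., immutable_idx] = np.asarray(x)[..., immutable_idx]
    return cf


x = np.array([0.2, 1.0, 0.0, 0.7])
cf_raw = np.array([0.9, 0.0, 1.0, 0.1])
# Features 1 and 2 are treated as immutable in this example.
cf = project_immutable(x, cf_raw, [1, 2])
```

Indexing the last axis means the same sketch works for a single instance or a batch of instances.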



COUNTERNETTRAININGMODULECONFIGS

CLASS relax.methods.counternet.CounterNetTrainingModuleConfigs (lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.



COUNTERNETTRAININGMODULE

CLASS relax.methods.counternet.CounterNetTrainingModule (m_configs)

Training module that defines the training step for CounterNetModel.

CounterNet Explanation Module

CounterNet architecture

CounterNet is trained to satisfy three objectives:

  1. predictive accuracy: the predictor network should output accurate predictions \hat{y}_x;
  2. counterfactual validity: CF examples x' produced by the CF generator network should be valid (e.g. \hat{y}_{x} + \hat{y}_{x'}=1);
  3. minimizing cost of change: minimal modifications should be required to change input instance x to CF example x'.

The objective function of CounterNet:

\operatorname*{argmin}_{\mathbf{\theta}} \frac{1}{N}\sum\nolimits_{i=1}^{N} \bigg[ \lambda_1 \cdot \! \underbrace{\left(y_i- \hat{y}_{x_i}\right)^2}_{\text{Prediction Loss}\ (\mathcal{L}_1)} + \;\lambda_2 \cdot \;\; \underbrace{\left(\hat{y}_{x_i}- \left(1 - \hat{y}_{x_i'}\right)\right)^2}_{\text{Validity Loss}\ (\mathcal{L}_2)} \,+ \;\lambda_3 \cdot \!\! \underbrace{\left(x_i- x'_i\right)^2}_{\text{Cost of change Loss}\ (\mathcal{L}_3)} \bigg]
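The three terms of the objective can be computed directly from a batch of labels, predictions, and counterfactuals. The sketch below is a NumPy illustration, not the library's code; `counternet_loss` is a hypothetical helper, and it assumes that the cost term is the squared Euclidean distance summed over features.

```python
import numpy as np


def counternet_loss(y, y_hat, y_hat_cf, x, x_cf,
                    lambda_1=1.0, lambda_2=0.2, lambda_3=0.1):
    """Weighted sum of the prediction, validity, and cost-of-change losses."""
    l1 = np.mean((y - y_hat) ** 2)                   # prediction loss L1
    l2 = np.mean((y_hat - (1.0 - y_hat_cf)) ** 2)    # validity loss L2
    l3 = np.mean(np.sum((x - x_cf) ** 2, axis=-1))   # cost of change L3 (assumed: sum over features)
    total = lambda_1 * l1 + lambda_2 * l2 + lambda_3 * l3
    return total, (l1, l2, l3)


y = np.array([1.0, 0.0])
y_hat = np.array([0.8, 0.3])        # predictions on x
y_hat_cf = np.array([0.2, 0.7])     # predictions on the CF examples x'
x = np.array([[0.0, 1.0], [1.0, 0.0]])
x_cf = np.array([[1.0, 1.0], [1.0, 1.0]])
total, (l1, l2, l3) = counternet_loss(y, y_hat, y_hat_cf, x, x_cf)
```

In this example the validity loss is (numerically) zero because each pair of predictions sums to 1, so only the prediction and cost terms contribute.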

CounterNet applies two-stage gradient updates to CounterNetModel for each training_step (see CounterNetTrainingModule).

  1. The first gradient update optimizes for predictive accuracy: \theta^{(1)} = \theta^{(0)} - \nabla_{\theta^{(0)}} (\lambda_1 \cdot \mathcal{L}_1).
  2. The second gradient update optimizes for generating CF explanation: \theta^{(2)}_g = \theta^{(1)}_g - \nabla_{\theta^{(1)}_g} (\lambda_2 \cdot \mathcal{L}_2 + \lambda_3 \cdot \mathcal{L}_3).
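The two updates above can be illustrated on a toy model with hand-derived gradients. This is a sketch under strong assumptions, not the CounterNet architecture: the predictor is a logistic model with weights `w`, the CF generator is a single learned shift `d` (so x' = x + d), and a step size `lr` is added even though the update rules above omit one. All names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def two_stage_step(w, d, X, y, lr=0.1, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1):
    """One training step: update the predictor first, then the generator only."""
    n = len(y)

    # Stage 1: gradient step on lambda_1 * L1 with respect to the predictor weights w.
    p = sigmoid(X @ w)
    grad_w = lambda_1 * (-2.0 / n) * X.T @ ((y - p) * p * (1.0 - p))
    w = w - lr * grad_w

    # Stage 2: gradient step on lambda_2 * L2 + lambda_3 * L3 with respect to d only;
    # the updated predictor weights are held fixed.
    p = sigmoid(X @ w)
    p_cf = sigmoid((X + d) @ w)
    grad_l2 = np.mean(2.0 * (p + p_cf - 1.0) * p_cf * (1.0 - p_cf)) * w
    grad_l3 = 2.0 * d  # L3 = mean_i ||x_i - x'_i||^2 = ||d||^2 in this toy model
    d = d - lr * (lambda_2 * grad_l2 + lambda_3 * grad_l3)
    return w, d


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = (X[:, 0] > 0).astype(float)   # label depends on the first feature only
w, d = np.zeros(3), np.zeros(3)
for _ in range(200):
    w, d = two_stage_step(w, d, X, y)
```

Separating the steps means the validity and cost losses never push on the predictor's weights, which is the property the two-stage scheme relies on.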

This two-stage optimization procedure was chosen because it improves the convergence of the model and the adversarial robustness of the predictor network. The CounterNet paper elaborates on these design choices.



COUNTERNETCONFIGS

CLASS relax.methods.counternet.CounterNetConfigs (enc_sizes=[50, 10], dec_sizes=[10], exp_sizes=[50, 50], dropout_rate=0.3, lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)

Configurator of CounterNet.

Parameters:
  • enc_sizes (List[int], default=[50, 10]) – Sequence of layer sizes for encoder network.
  • dec_sizes (List[int], default=[10]) – Sequence of layer sizes for predictor.
  • exp_sizes (List[int], default=[50, 50]) – Sequence of layer sizes for CF generator.
  • dropout_rate (float, default=0.3) – Dropout rate.
  • lr (float, default=0.003) – Learning rate for training CounterNet.
  • lambda_1 (float, default=1.0) – \lambda_1 for balancing the prediction loss \mathcal{L}_1.
  • lambda_2 (float, default=0.2) – \lambda_2 for balancing the validity loss \mathcal{L}_2.
  • lambda_3 (float, default=0.1) – \lambda_3 for balancing the cost of change loss \mathcal{L}_3.
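As an illustration of the parameters listed above, a configuration can be written as a plain dict. The values below are the documented defaults except for `lambda_2`, which is changed purely as an example; the commented lines assume relax is installed and that the dict is passed through the `m_configs` argument of `CounterNet`.

```python
# Defaults copied from the CounterNetConfigs signature, with one illustrative override.
m_configs = dict(
    enc_sizes=[50, 10],   # encoder layer sizes
    dec_sizes=[10],       # predictor layer sizes
    exp_sizes=[50, 50],   # CF generator layer sizes
    dropout_rate=0.3,
    lr=0.003,
    lambda_1=1.0,
    lambda_2=0.5,         # illustrative override: weight the validity loss more heavily
    lambda_3=0.1,
)

# With relax installed, the dict would be used as follows (not executed here):
# from relax.methods.counternet import CounterNet
# counternet = CounterNet(m_configs=m_configs)
```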


COUNTERNET

CLASS relax.methods.counternet.CounterNet (m_configs=None)

API for CounterNet Explanation Module.

Parameters:
  • m_configs (dict | CounterNetConfigs, default=None) – Configurator of hyperparameters; see CounterNetConfigs.

Basic usage of CounterNet

Prepare data:

from relax.data import load_data
dm = load_data("adult", data_configs=dict(sample_frac=0.1))

Define CounterNet:

from relax.methods.counternet import CounterNet

counternet = CounterNet()
assert isinstance(counternet, BaseParametricCFModule)
assert isinstance(counternet, BaseCFModule)
assert isinstance(counternet, BasePredFnCFModule)
assert hasattr(counternet, 'pred_fn')

Train the model:

t_configs = dict(n_epochs=1, batch_size=128)
counternet.train(dm, t_configs=t_configs)

Predict labels:

X, y = dm.test_dataset[:]
y_pred = counternet.pred_fn(X)
assert y_pred.shape == (len(y), 1)

Generate a CF explanation for a single input x:

x, _ = dm.test_dataset[0]
cf = counternet.generate_cf(x)
assert x.shape == cf.shape
assert cf.shape == (29,)

Generate CF explanations for a batch of inputs X:

X, _ = dm.test_dataset[:]
cfs = counternet.generate_cfs(X)
assert X.shape == cfs.shape
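Once CF explanations are available, a natural follow-up is to check how many of them actually flip the predicted label. The sketch below computes that fraction with a stand-in predictor so that it runs without a trained model; `validity` and `toy_pred_fn` are hypothetical helpers, not part of relax.

```python
import numpy as np


def validity(pred_fn, X, cfs):
    """Fraction of CF examples whose rounded prediction differs from the input's."""
    y_x = np.round(pred_fn(X))
    y_cf = np.round(pred_fn(cfs))
    return float(np.mean(y_x != y_cf))


def toy_pred_fn(X):
    # Stand-in predictor (assumed): sigmoid of the feature sum, shape (n, 1).
    return 1.0 / (1.0 + np.exp(-X.sum(axis=1, keepdims=True)))


X_toy = np.array([[1.0, 1.0], [-1.0, -1.0], [2.0, 0.0]])
cfs_toy = np.array([[-1.0, -1.0], [1.0, 1.0], [1.0, 0.0]])
score = validity(toy_pred_fn, X_toy, cfs_toy)

# With the trained model from the cells above, the call would be:
# score = validity(counternet.pred_fn, X, cfs)
```

Here two of the three toy counterfactuals flip the prediction, so the score is 2/3.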