from relax.data import load_data
CounterNet Model
COUNTERNETMODEL
CLASS relax.methods.counternet.CounterNetModel (m_config, name=None)
CounterNet Model
CounterNet Training Module
Define the CounterNetTrainingModule for training CounterNetModel.
PARTITION_TRAINABLE_PARAMS
relax.methods.counternet.partition_trainable_params (params, trainable_name)
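The library's implementation is the source of truth; as a rough sketch, assuming Haiku-style parameter mappings keyed by module path, the partition might look like this (the function name below is illustrative):

```python
import haiku as hk

def partition_trainable_params_sketch(params, trainable_name):
    # Split a Haiku params mapping into two partitions, based on whether the
    # module path contains `trainable_name` (e.g. the CF generator's module
    # prefix). Everything else is treated as non-trainable.
    trainable, non_trainable = hk.data_structures.partition(
        lambda module_name, name, value: trainable_name in module_name, params
    )
    return trainable, non_trainable
```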
PROJECT_IMMUTABLE_FEATURES
relax.methods.counternet.project_immutable_features (x, cf, imutable_idx_list)
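A minimal sketch of what this projection does, assuming 1-D feature arrays and that immutable features are simply copied back from the input (an assumption for illustration):

```python
import jax.numpy as jnp

def project_immutable_features_sketch(x, cf, imutable_idx_list):
    # Overwrite the immutable feature positions in the counterfactual `cf`
    # with the original values from `x`, so those features never change.
    idx = jnp.asarray(imutable_idx_list)
    return cf.at[idx].set(x[idx])

x = jnp.array([0.3, 1.0, 0.0, 0.7])
cf = jnp.array([0.9, 0.0, 1.0, 0.2])
projected = project_immutable_features_sketch(x, cf, [1, 2])
assert (projected[jnp.array([1, 2])] == x[jnp.array([1, 2])]).all()
```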
COUNTERNETTRAININGMODULECONFIGS
CLASS relax.methods.counternet.CounterNetTrainingModuleConfigs (lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
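Since the defaults are documented above, constructing a custom configuration is straightforward; for example, lowering the learning rate while keeping the default loss weights:

```python
from relax.methods.counternet import CounterNetTrainingModuleConfigs

t_module_configs = CounterNetTrainingModuleConfigs(lr=1e-3)  # lambda_1/2/3 keep their defaults
```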
COUNTERNETTRAININGMODULE
CLASS relax.methods.counternet.CounterNetTrainingModule (m_configs)
Helper class that provides a standard way to create an ABC using inheritance.
CounterNet Explanation Module
CounterNet consists of three objectives:
- predictive accuracy: the predictor network should output accurate predictions $\hat{y}_x$;
- counterfactual validity: CF examples $x'$ produced by the CF generator network should be valid (e.g., $\hat{y}_{x} + \hat{y}_{x'} = 1$);
- minimizing cost of change: minimal modifications should be required to change the input instance $x$ into the CF example $x'$.
The objective function of CounterNet:
$$\operatorname*{argmin}_{\theta} \frac{1}{N}\sum_{i=1}^{N} \bigg[ \lambda_1 \cdot \underbrace{\left(y_i - \hat{y}_{x_i}\right)^2}_{\text{Prediction Loss}\ (\mathcal{L}_1)} + \lambda_2 \cdot \underbrace{\left(\hat{y}_{x_i} - \left(1 - \hat{y}_{x_i'}\right)\right)^2}_{\text{Validity Loss}\ (\mathcal{L}_2)} + \lambda_3 \cdot \underbrace{\left(x_i - x'_i\right)^2}_{\text{Cost of Change Loss}\ (\mathcal{L}_3)} \bigg]$$
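In code, this weighted sum of the three loss terms might be computed as follows (a sketch with MSE-style terms matching the equation above; the function and argument names are illustrative, not the library's API):

```python
import jax.numpy as jnp

def counternet_loss_sketch(y, y_pred, y_cf_pred, x, cf,
                           lambda_1=1.0, lambda_2=0.2, lambda_3=0.1):
    l_1 = jnp.mean((y - y_pred) ** 2)                  # prediction loss (L_1)
    l_2 = jnp.mean((y_pred - (1.0 - y_cf_pred)) ** 2)  # validity loss (L_2)
    l_3 = jnp.mean((x - cf) ** 2)                      # cost-of-change loss (L_3)
    return lambda_1 * l_1 + lambda_2 * l_2 + lambda_3 * l_3
```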
CounterNet applies two-stage gradient updates to CounterNetModel for each training_step (see CounterNetTrainingModule).
- The first gradient update optimizes for predictive accuracy: $\theta^{(1)} = \theta^{(0)} - \nabla_{\theta^{(0)}} (\lambda_1 \cdot \mathcal{L}_1)$.
- The second gradient update optimizes for generating CF explanations: $\theta^{(2)}_g = \theta^{(1)}_g - \nabla_{\theta^{(1)}_g} (\lambda_2 \cdot \mathcal{L}_2 + \lambda_3 \cdot \mathcal{L}_3)$.
This two-stage optimization procedure is chosen because it improves the convergence of the model and the adversarial robustness of the predictor network. The CounterNet paper elaborates on these design choices.
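Schematically, one training step could look like the sketch below. Assumptions: partition_trainable_params returns a (trainable, fixed) pair, plain SGD updates stand in for the actual optimizer, and 'cf_generator' is a placeholder for the CF generator's module name:

```python
import jax
import haiku as hk
from relax.methods.counternet import partition_trainable_params

def two_stage_step_sketch(params, batch, pred_loss_fn, cf_loss_fn, lr=3e-3):
    # Stage 1: update all parameters on the prediction loss (lambda_1 * L_1).
    grads = jax.grad(pred_loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    # Stage 2: update only the CF generator's parameters on
    # lambda_2 * L_2 + lambda_3 * L_3; predictor parameters stay fixed.
    gen_params, fixed_params = partition_trainable_params(params, 'cf_generator')
    grads = jax.grad(cf_loss_fn)(gen_params, fixed_params, batch)
    gen_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, gen_params, grads)
    return hk.data_structures.merge(fixed_params, gen_params)
```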
COUNTERNETCONFIGS
CLASS relax.methods.counternet.CounterNetConfigs (enc_sizes=[50, 10], dec_sizes=[10], exp_sizes=[50, 50], dropout_rate=0.3, lr=0.003, lambda_1=1.0, lambda_2=0.2, lambda_3=0.1)
Configurator of CounterNet.
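For example, to widen the encoder and explainer networks and lower the learning rate (the values here are arbitrary):

```python
from relax.methods.counternet import CounterNet, CounterNetConfigs

m_configs = CounterNetConfigs(enc_sizes=[100, 50], exp_sizes=[100, 100], lr=1e-3)
counternet = CounterNet(m_configs)
```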
COUNTERNET
CLASS relax.methods.counternet.CounterNet (m_configs=None)
API for CounterNet Explanation Module.
Basic usage of CounterNet
Prepare data:
dm = load_data("adult", data_configs=dict(sample_frac=0.1))
Define CounterNet:
counternet = CounterNet()
assert isinstance(counternet, BaseParametricCFModule)
assert isinstance(counternet, BaseCFModule)
assert isinstance(counternet, BasePredFnCFModule)
assert hasattr(counternet, 'pred_fn')

Train the model:
t_configs = dict(n_epochs=1, batch_size=128)
counternet.train(dm, t_configs=t_configs)

Predict labels:
X, y = dm.test_dataset[:]
y_pred = counternet.pred_fn(X)
assert y_pred.shape == (len(y), 1)

Generate a CF explanation for a given x:
x, _ = dm.test_dataset[0]
cf = counternet.generate_cf(x)
assert x.shape == cf.shape
assert cf.shape == (29,)

Generate CF explanations for a batch of inputs X:
X, _ = dm.test_dataset[:]
cfs = counternet.generate_cfs(X)
assert X.shape == cfs.shape