Proto CF

relax.methods.proto.ProtoCFConfig

class relax.methods.proto.ProtoCFConfig (n_steps=100, lr=0.01, c=1, beta=0.1, gamma=0.1, theta=0.1, n_samples=128, validity_fn='KLDivergence', enc_sizes=[64, 32, 16], dec_sizes=[16, 32, 64], opt_name='adam', ae_lr=0.001, ae_loss='mse')

Configuration for ProtoCF.

Parameters:

  • n_steps (int, default=100) – Number of optimization steps for generating counterfactuals.
  • lr (float, default=0.01) – Learning rate for the counterfactual optimization.
  • c (float, default=1) – The weight for the validity loss.
  • beta (float, default=0.1) – The weight for l1_norm in the cost function, where cost = beta * l1_norm + l2_norm.
  • gamma (float, default=0.1) – The weight for the autoencoder loss.
  • theta (float, default=0.1) – The weight for the prototype loss.
  • n_samples (int, default=128) – Number of samples used to compute the prototype.
  • validity_fn (str, default=KLDivergence) – Name of the validity loss function.
  • enc_sizes (List[int], default=[64, 32, 16]) – Sizes of the encoder's hidden layers.
  • dec_sizes (List[int], default=[16, 32, 64]) – Sizes of the decoder's hidden layers.
  • opt_name (str, default=adam) – Name of the autoencoder's optimizer.
  • ae_lr (float, default=0.001) – Learning rate of the autoencoder.
  • ae_loss (str, default=mse) – Name of the autoencoder's loss function.
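Taken together, the weights above suggest an objective of roughly the form c * validity + (beta * l1_norm + l2_norm) + gamma * ae_loss + theta * proto_loss. The NumPy sketch below spells this out for a single instance. It is not relax's implementation: the binary KL divergence standing in for validity_fn='KLDivergence', the use of the squared L2 distance, the prototype as the mean latent code of the first n_samples target-class instances, and the helper names kl_divergence, prototype and proto_cf_loss are all assumptions made for illustration.

```python
import numpy as np

def kl_divergence(p_target, p_pred, eps=1e-7):
    """Binary KL(p_target || p_pred); a stand-in for validity_fn='KLDivergence'."""
    p_t = np.clip(p_target, eps, 1 - eps)
    p_p = np.clip(p_pred, eps, 1 - eps)
    return p_t * np.log(p_t / p_p) + (1 - p_t) * np.log((1 - p_t) / (1 - p_p))

def prototype(encode, xs_target, n_samples=128):
    """Hypothetical: mean latent code of up to n_samples target-class instances."""
    return encode(xs_target[:n_samples]).mean(axis=0)

def proto_cf_loss(x, cf, pred_fn, target, encode, decode, proto,
                  c=1.0, beta=0.1, gamma=0.1, theta=0.1):
    """Hypothetical weighted ProtoCF-style objective for one instance."""
    diff = cf - x
    validity = kl_divergence(target, pred_fn(cf))
    # beta * l1_norm + l2_norm (the L2 term is assumed to be squared here)
    cost = beta * np.abs(diff).sum() + np.square(diff).sum()
    # The counterfactual should be reconstructable by the autoencoder
    ae_loss = np.square(cf - decode(encode(cf))).sum()
    # The counterfactual's latent code should be close to the prototype
    proto_loss = np.square(encode(cf) - proto).sum()
    return c * validity + cost + gamma * ae_loss + theta * proto_loss

# Toy usage: identity "autoencoder" and a logistic "model" on two features
identity = lambda z: z
pred_fn = lambda z: 1.0 / (1.0 + np.exp(-(z.sum() - 1.0)))
x = np.array([0.2, 0.4])
cf = np.array([0.5, 0.4])
proto = prototype(identity, np.array([[0.6, 0.6], [0.8, 0.4]]))  # -> [0.7, 0.5]
loss = proto_cf_loss(x, cf, pred_fn, 1.0, identity, identity, proto)
print(loss)
```

With the identity autoencoder the reconstruction term vanishes, so the toy loss is just the validity term plus 0.1 * 0.3 + 0.09 for the cost and 0.1 * 0.05 for the prototype term.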

relax.methods.proto.ProtoCF

class relax.methods.proto.ProtoCF (config=None, ae=None, name=None)

Counterfactual explanation module guided by class prototypes, built on the base class for parametric counterfactual modules.

Methods

set_apply_constraints_fn (apply_constraints_fn)

set_compute_reg_loss_fn (compute_reg_loss_fn)

apply_constraints (*args, **kwargs)

compute_reg_loss (*args, **kwargs)

save (path)

load_from_path (path)

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)
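Because ProtoCFConfig exposes n_steps and lr, counterfactual generation is presumably an iterative, gradient-based search. The toy sketch below is not the library's code; it only shows how n_steps, lr, c and beta interact. A two-feature logistic function stands in for model.pred_fn, a squared error replaces the KLDivergence validity loss, and the autoencoder and prototype terms are left out. The helper name toy_generate_cf and its defaults (n_steps=500, lr=0.05, c=10.0) are hypothetical and tuned for this toy problem, not taken from ProtoCFConfig.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def toy_generate_cf(x, w, b, target=1.0, n_steps=500, lr=0.05, c=10.0, beta=0.1):
    """Gradient descent on c * (p - target)**2 + beta * |cf - x|_1 + |cf - x|_2^2,
    where p = sigmoid(w @ cf + b) is a toy logistic 'model'."""
    cf = x.astype(float).copy()
    for _ in range(n_steps):
        p = sigmoid(w @ cf + b)
        d = cf - x
        # Chain rule: d/dcf [c * (p - target)^2] = 2c (p - target) p (1 - p) w
        grad_validity = 2.0 * c * (p - target) * p * (1.0 - p) * w
        # Subgradient of the L1 term plus gradient of the squared L2 term
        grad_cost = beta * np.sign(d) + 2.0 * d
        cf = cf - lr * (grad_validity + grad_cost)
    return cf

x = np.array([0.0, 0.0])
w, b = np.array([1.0, 1.0]), -1.0
cf = toy_generate_cf(x, w, b)
print(sigmoid(w @ x + b), sigmoid(w @ cf + b))  # original below 0.5, counterfactual above
```

Raising c pushes the counterfactual further across the decision boundary, while beta and the L2 term pull it back toward x; with too small a c the proximity terms win and the prediction never flips.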

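from functools import partial
import keras
from jax import vmap
from relax.data_module import load_data
from relax.ml_model import load_ml_module
from relax.methods.proto import ProtoCF

dm = load_data('oulad')
model = load_ml_module('oulad')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']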
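pcf = ProtoCF()  # uses the default ProtoCFConfig
pcf.set_apply_constraints_fn(dm.apply_constraints)  # enforce the dataset's feature constraints
pcf.train(dm, epochs=5)  # train the parametric module (its autoencoder) for 5 epochs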
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 4s 10ms/step - loss: 0.1207   
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0418      
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0373      
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0341      
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0324    
<__main__.ProtoCF>
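The decreasing loss reported by pcf.train is presumably the autoencoder's reconstruction error, given ae_loss='mse' and ae_lr in the config. The NumPy sketch below only illustrates how enc_sizes and dec_sizes could translate into layer shapes. The input width, the assumption that the last encoder width is the latent size, the ReLU activations, and the helper names init_mlp and mlp are all hypothetical, not relax's implementation, and no training step is shown.

```python
import numpy as np

def init_mlp(sizes, rng):
    """One (weights, bias) pair per consecutive pair of layer widths."""
    return [(rng.normal(0.0, 0.1, (m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    """ReLU on hidden layers, linear output layer."""
    for i, (W, bias) in enumerate(params):
        x = x @ W + bias
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x

n_features = 10                      # hypothetical input width
enc_sizes, dec_sizes = [64, 32, 16], [16, 32, 64]
rng = np.random.default_rng(0)
# Assumption: the last encoder width (16) is the latent size
encoder = init_mlp([n_features] + enc_sizes, rng)
decoder = init_mlp([enc_sizes[-1]] + dec_sizes + [n_features], rng)

xs = rng.uniform(size=(5, n_features))
z = mlp(encoder, xs)                 # latent codes, shape (5, 16)
recon = mlp(decoder, z)              # reconstructions, shape (5, 10)
mse = np.mean((xs - recon) ** 2)     # reconstruction error, as with ae_loss='mse'
print(z.shape, recon.shape, mse)
```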
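# Bind the model's prediction function and vectorize CF generation over the test set
partial_gen = partial(pcf.generate_cf, pred_fn=model.pred_fn)
cfs = vmap(partial_gen)(xs_test)

# Validity: share of counterfactuals predicted as the opposite class of the original input
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),  # flipped original class as the target
    model.pred_fn(cfs)                      # predicted probabilities for the counterfactuals
).mean())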
Validity:  0.95471835
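The validity score counts how often a counterfactual's predicted class is the opposite of the original prediction. keras.metrics.binary_accuracy thresholds its second argument at 0.5 and compares the result with its first argument, so the same computation can be sketched in plain NumPy. The helper name validity and the toy probabilities are made up for illustration.

```python
import numpy as np

def validity(pred_x, pred_cf, threshold=0.5):
    """Share of counterfactuals whose predicted class is the flipped original class."""
    target = (1 - pred_x).round()                     # flipped original class
    cf_class = (pred_cf > threshold).astype(target.dtype)
    return float((cf_class == target).mean())

pred_x = np.array([0.9, 0.2, 0.7, 0.4])   # original predicted probabilities
pred_cf = np.array([0.3, 0.8, 0.6, 0.1])  # counterfactual predicted probabilities
print(validity(pred_x, pred_cf))  # 0.5
```

Here the first two counterfactuals flip the prediction and the last two do not, giving a validity of 0.5.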