relax.methods.proto.ProtoCFConfig
[source]
class relax.methods.proto.ProtoCFConfig (n_steps=100, lr=0.01, c=1, beta=0.1, gamma=0.1, theta=0.1, n_samples=128, validity_fn='KLDivergence', enc_sizes=[64, 32, 16], dec_sizes=[16, 32, 64], opt_name='adam', ae_lr=0.001, ae_loss='mse')
Configurator of ProtoCF.
Parameters:
n_steps (int, default=100) – Number of optimization steps for generating a counterfactual.
lr (float, default=0.01) – Learning rate of the counterfactual optimization.
c (float, default=1) – The weight for validity loss.
beta (float, default=0.1) – The weight for l1_norm in the cost function, where cost = beta * l1_norm + l2_norm.
gamma (float, default=0.1) – The weight for Autoencoder loss.
theta (float, default=0.1) – The weight for prototype loss.
n_samples (int, default=128) – Number of samples used to compute the prototype.
validity_fn (str, default='KLDivergence') – Name of the validity loss function.
enc_sizes (List[int], default=[64, 32, 16]) – List of hidden layers of Encoder.
dec_sizes (List[int], default=[16, 32, 64]) – List of hidden layers of Decoder.
opt_name (str, default='adam') – Optimizer name of AutoEncoder.
ae_lr (float, default=0.001) – Learning rate of AutoEncoder.
ae_loss (str, default='mse') – Loss function name of AutoEncoder.
relax.methods.proto.ProtoCF
[source]
class relax.methods.proto.ProtoCF (config=None, ae=None, name=None)
Base class for parametric counterfactual modules.
Methods
[source]
set_apply_constraints_fn (apply_constraints_fn)
[source]
set_compute_reg_loss_fn (compute_reg_loss_fn)
[source]
apply_constraints (*args, **kwargs)
[source]
compute_reg_loss (*args, **kwargs)
[source]
before_generate_cf (*args, **kwargs)
[source]
generate_cf (*args, **kwargs)
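The usage example below assumes imports along these lines. The module paths for load_data and load_ml_module are assumed from the library layout and may differ in your version; ProtoCF comes from relax.methods.proto as documented above:

from functools import partial

import keras
from jax import vmap

from relax.data_module import load_data
from relax.ml_model import load_ml_module
from relax.methods.proto import ProtoCF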
dm = load_data('oulad')          # load the preprocessed OULAD data module
model = load_ml_module('oulad')  # load the pretrained classifier for OULAD
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']
/home/birk/code/jax-relax/relax/data_module.py:234: UserWarning: Passing `config` will have no effect.
warnings.warn("Passing `config` will have no effect.")
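dm['train'] and dm['test'] each return a (features, labels) pair of arrays; a quick sanity check (output omitted here, and the exact dimensions depend on the OULAD preprocessing):

print(xs_train.shape, ys_train.shape)  # training features and labels
print(xs_test.shape, ys_test.shape)    # test features and labels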
pcf = ProtoCF()
pcf.set_apply_constraints_fn(dm.apply_constraints)  # keep generated CFs inside the data constraints
pcf.train(dm, epochs=5)  # fit the autoencoder used for the prototype and reconstruction losses
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 4s 10ms/step - loss: 0.1207
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0418
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0373
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0341
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0324
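A regularization loss can be registered the same way as the constraint function, before calling train. This is a sketch under the assumption that the data module exposes a compute_reg_loss method; skip it if your DataModule does not provide one:

pcf.set_compute_reg_loss_fn(dm.compute_reg_loss)  # assumed DataModule method; optional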
# generate a counterfactual for every test instance (vmapped over rows)
partial_gen = partial(pcf.generate_cf, pred_fn=model.pred_fn)
cfs = vmap(partial_gen)(xs_test)

# validity: fraction of counterfactuals classified as the flipped class
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())
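Validity here is the fraction of counterfactuals that the model assigns to the flipped class. A complementary check is proximity, i.e. how far the counterfactuals move away from the originals; a minimal sketch using jax.numpy (the metric choice is illustrative, and it assumes cfs has the same shape as xs_test):

import jax.numpy as jnp

# mean L1 distance between each input and its counterfactual
proximity = jnp.abs(xs_test - cfs).mean(axis=-1).mean()
print("Proximity (mean L1): ", proximity)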