relax.methods.proto.ProtoCFConfig 
 
class relax.methods.proto.ProtoCFConfig(n_steps=100, lr=0.01, c=1, beta=0.1, gamma=0.1, theta=0.1, n_samples=128, validity_fn='KLDivergence', enc_sizes=[64, 32, 16], dec_sizes=[16, 32, 64], opt_name='adam', ae_lr=0.001, ae_loss='mse')
 
Configuration for ProtoCF.
Parameters:
n_steps (int, default=100) – Number of optimization steps used to generate each counterfactual.
lr (float, default=0.01) – Learning rate of the counterfactual optimization.
c (float, default=1) – The weight for the validity loss.
beta (float, default=0.1) – The weight for l1_norm in the cost function, where cost = beta * l1_norm + l2_norm.
gamma (float, default=0.1) – The weight for the autoencoder loss.
theta (float, default=0.1) – The weight for the prototype loss.
n_samples (int, default=128) – Number of samples used to compute the prototype.
validity_fn (str, default='KLDivergence') – Name of the validity loss function.
enc_sizes (List[int], default=[64, 32, 16]) – Sizes of the encoder's hidden layers.
dec_sizes (List[int], default=[16, 32, 64]) – Sizes of the decoder's hidden layers.
opt_name (str, default='adam') – Name of the optimizer used to train the autoencoder.
ae_lr (float, default=0.001) – Learning rate for training the autoencoder.
ae_loss (str, default='mse') – Name of the loss function used to train the autoencoder.
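To see how c, beta, gamma and theta interact, the sketch below combines them into one objective in the style of prototype-guided counterfactuals (Van Looveren and Klaise): total = c * validity + (beta * l1_norm + l2_norm) + gamma * autoencoder loss + theta * prototype loss. It is a NumPy illustration under stated assumptions, not relax's implementation: the helper names (kl_divergence, elastic_net_cost, proto_cf_loss), the Bernoulli form of the KL validity loss, the squared l2 term, and the toy callables are all hypothetical.

```python
import numpy as np


def kl_divergence(p, q, eps=1e-7):
    """KL(p || q) between Bernoulli distributions with success probabilities p and q (assumed validity loss)."""
    p = np.clip(p, eps, 1 - eps)
    q = np.clip(q, eps, 1 - eps)
    return np.sum(p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q)))


def elastic_net_cost(cf, x, beta):
    """cost = beta * l1_norm + l2_norm of the change cf - x (l2 taken squared here, an assumption)."""
    delta = cf - x
    return beta * np.sum(np.abs(delta)) + np.sum(delta ** 2)


def proto_cf_loss(cf, x, y_target, pred_fn, autoencoder, encoder, prototype,
                  c=1.0, beta=0.1, gamma=0.1, theta=0.1):
    """Hypothetical weighted objective using the ProtoCFConfig weights."""
    validity = kl_divergence(y_target, pred_fn(cf))       # scaled by c
    cost = elastic_net_cost(cf, x, beta)                   # beta is applied inside
    ae_loss = np.sum((cf - autoencoder(cf)) ** 2)         # scaled by gamma
    proto_loss = np.sum((encoder(cf) - prototype) ** 2)   # scaled by theta
    return c * validity + cost + gamma * ae_loss + theta * proto_loss


# Toy check: identity autoencoder/encoder, prototype placed at cf, prediction already
# on target, so only the cost term is non-zero: 0.1 * (1 + 2) + (1 + 4) = 5.3.
identity = lambda z: z
x = np.zeros(2)
cf = np.array([1.0, -2.0])
loss = proto_cf_loss(cf, x, y_target=np.array([0.9]), pred_fn=lambda z: np.array([0.9]),
                     autoencoder=identity, encoder=identity, prototype=cf)
print(loss)  # 5.3
```

Each weight scales one term, so raising beta favours sparser changes, while raising theta pulls the counterfactual's encoding harder toward the prototype.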
 
 
relax.methods.proto.ProtoCF 
 
class relax.methods.proto.ProtoCF(config=None, ae=None, name=None)
 
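Prototype-guided counterfactual module. ProtoCF is a parametric counterfactual module: it is trained (via train, as in the example below) before counterfactuals are generated.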
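ProtoCF's n_samples parameter refers to the samples used to build a prototype. A common construction, assumed here rather than taken from relax's source, encodes a random subset of training points from the target class and averages their latent codes; the prototype loss then pulls the counterfactual's encoding toward that mean. class_prototype and its encoder argument are hypothetical.

```python
import numpy as np


def class_prototype(xs, ys, target_label, encoder, n_samples=128, seed=0):
    """Mean latent code of up to n_samples training rows labelled target_label.

    ys is assumed to be a 1-D array of class labels aligned with the rows of xs.
    """
    rng = np.random.default_rng(seed)
    candidates = xs[ys == target_label]                   # rows of the target class
    n = min(n_samples, len(candidates))                   # cannot sample more rows than exist
    idx = rng.choice(len(candidates), size=n, replace=False)
    return encoder(candidates[idx]).mean(axis=0)          # average in latent space


# Toy usage with an identity "encoder": the prototype of class 1 is the mean of its rows.
xs = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 4.0], [10.0, 10.0]])
ys = np.array([1, 1, 1, 0])
proto = class_prototype(xs, ys, target_label=1, encoder=lambda z: z)
print(proto)  # [2. 2.]
```

Under this assumption, the theta-weighted prototype loss would be the squared distance between encoder(cf) and this vector.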
Methods
set_apply_constraints_fn(apply_constraints_fn)
set_compute_reg_loss_fn(compute_reg_loss_fn)
apply_constraints(*args, **kwargs)
compute_reg_loss(*args, **kwargs)
before_generate_cf(*args, **kwargs)
generate_cf(*args, **kwargs)
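The method names suggest a simple hook pattern: set_apply_constraints_fn and set_compute_reg_loss_fn store callables, and apply_constraints and compute_reg_loss forward their arguments to them, which is how the example registers dm.apply_constraints. The stand-in class below sketches that pattern under this assumption; HookedCFModule, its error on unset hooks, and the toy hooks are hypothetical, not relax's code.

```python
from typing import Any, Callable, Optional


class HookedCFModule:
    """Hypothetical stand-in illustrating setter-registered hooks."""

    def __init__(self) -> None:
        self._apply_constraints_fn: Optional[Callable[..., Any]] = None
        self._compute_reg_loss_fn: Optional[Callable[..., Any]] = None

    def set_apply_constraints_fn(self, apply_constraints_fn: Callable[..., Any]) -> None:
        self._apply_constraints_fn = apply_constraints_fn

    def set_compute_reg_loss_fn(self, compute_reg_loss_fn: Callable[..., Any]) -> None:
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def apply_constraints(self, *args, **kwargs):
        # Forward all arguments to the registered hook (assumed behaviour).
        if self._apply_constraints_fn is None:
            raise RuntimeError("apply_constraints_fn has not been set")
        return self._apply_constraints_fn(*args, **kwargs)

    def compute_reg_loss(self, *args, **kwargs):
        if self._compute_reg_loss_fn is None:
            raise RuntimeError("compute_reg_loss_fn has not been set")
        return self._compute_reg_loss_fn(*args, **kwargs)


# Toy hooks: clip counterfactual features to [0, 1]; L1 distance as a regularizer.
module = HookedCFModule()
module.set_apply_constraints_fn(lambda x, cf: [min(max(v, 0.0), 1.0) for v in cf])
module.set_compute_reg_loss_fn(lambda x, cf: sum(abs(a - b) for a, b in zip(x, cf)))

x = [0.5, 0.5, 0.5]
cf = module.apply_constraints(x, [1.4, -0.2, 0.3])
print(cf)                              # [1.0, 0.0, 0.3]
print(module.compute_reg_loss(x, cf))  # 1.2
```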
 
 
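from relax.data_module import load_data
from relax.ml_model import load_ml_module

# Load the OULAD data module and its pretrained classifier
dm = load_data('oulad')
model = load_ml_module('oulad')
xs_train, ys_train = dm['train']
xs_test, ys_test = dm['test']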
 
 
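from relax.methods.proto import ProtoCF

# Use the data module's feature constraints and train ProtoCF for 5 epochs
pcf = ProtoCF()
pcf.set_apply_constraints_fn(dm.apply_constraints)
pcf.train(dm, epochs=5)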
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 4s 10ms/step - loss: 0.1207   
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0418      
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0373      
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0341      
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0324     
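The training run above presumably fits ProtoCF's autoencoder, whose shape is set by enc_sizes and dec_sizes. The NumPy sketch below shows one plausible wiring of those sizes (input → 64 → 32 → 16 latent, then 16 → 16 → 32 → 64 → input), with ReLU hidden layers, a linear output layer, and an MSE reconstruction loss standing in for ae_loss='mse'. The wiring, activations, initialization and helper names are assumptions; relax's actual model may be built differently.

```python
import numpy as np


def init_dense_layers(sizes, rng):
    """Random (weight, bias) pairs for consecutive layer sizes."""
    return [(rng.normal(scale=0.1, size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def forward(layers, h, linear_output=False):
    """Apply dense layers with ReLU; optionally keep the last layer linear."""
    for i, (w, b) in enumerate(layers):
        h = h @ w + b
        if not (linear_output and i == len(layers) - 1):
            h = np.maximum(h, 0.0)
    return h


def build_autoencoder(input_dim, enc_sizes=(64, 32, 16), dec_sizes=(16, 32, 64), seed=0):
    rng = np.random.default_rng(seed)
    enc = init_dense_layers([input_dim, *enc_sizes], rng)                 # input -> 64 -> 32 -> 16
    dec = init_dense_layers([enc_sizes[-1], *dec_sizes, input_dim], rng)  # 16 -> 16 -> 32 -> 64 -> input
    encode = lambda x: forward(enc, x)
    decode = lambda z: forward(dec, z, linear_output=True)
    return encode, decode


# Shape check on random data with 10 features (a placeholder, not the OULAD width).
encode, decode = build_autoencoder(input_dim=10)
x = np.random.default_rng(1).normal(size=(5, 10))
z = encode(x)
recon = decode(z)
mse = np.mean((x - recon) ** 2)  # analogue of ae_loss='mse'
print(z.shape, recon.shape)      # (5, 16) (5, 10)
```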
 
 
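from functools import partial
import keras
from jax import vmap

# Generate one counterfactual per test instance
partial_gen = partial(pcf.generate_cf, pred_fn=model.pred_fn)
cfs = vmap(partial_gen)(xs_test)

# Validity: fraction of counterfactuals whose predicted label flips the original one
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - model.pred_fn(xs_test)).round(),
    model.pred_fn(cfs)
).mean())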
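The validity score computed above is the fraction of counterfactuals whose predicted label (model.pred_fn(cfs) thresholded at 0.5, the default threshold of keras.metrics.binary_accuracy) equals the flipped label of the original prediction. The NumPy sketch below reproduces that computation on toy probabilities; validity is a hypothetical helper, and pred_fn is assumed to return positive-class probabilities of shape (n, 1).

```python
import numpy as np


def validity(pred_x, pred_cf, threshold=0.5):
    """Fraction of counterfactuals whose predicted label is the flip of the original prediction."""
    target = np.round(1.0 - pred_x)                         # flipped label of each original input
    cf_label = (pred_cf > threshold).astype(target.dtype)   # thresholded counterfactual prediction
    return float(np.mean(cf_label == target))


pred_x = np.array([[0.9], [0.2], [0.7], [0.1]])   # toy stand-in for model.pred_fn(xs_test)
pred_cf = np.array([[0.3], [0.8], [0.6], [0.4]])  # toy stand-in for model.pred_fn(cfs)
print(validity(pred_x, pred_cf))  # 0.5
```

Here the first two counterfactuals flip the label and the last two do not, giving a validity of 0.5.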