CCHVAE

relax.methods.cchvae.CHVAE

class relax.methods.cchvae.CHVAE (layers, dropout_rate=0.0, **kwargs)

A model grouping layers into an object with training/inference features.

There are three ways to instantiate a Model:

With the “Functional API”

You start from Input, you chain layer calls to specify the model’s forward pass, and finally you create your model from inputs and outputs:

inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
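To make the forward pass concrete, here is the same computation written out by hand in NumPy. Each `Dense` layer computes `activation(x @ kernel + bias)`. This is only a sketch: the weights are random illustrative values, not the ones Keras would initialize or learn.

```python
import numpy as np

# Illustrative weights only; Keras creates and trains its own kernels and biases.
rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

W1, b1 = rng.normal(size=(37, 32)), np.zeros(32)
W2, b2 = rng.normal(size=(32, 5)), np.zeros(5)

def forward(inputs):
    x = relu(inputs @ W1 + b1)     # Dense(32, activation="relu")
    return softmax(x @ W2 + b2)    # Dense(5, activation="softmax")

batch = rng.normal(size=(8, 37))   # 8 samples with 37 features each
outputs = forward(batch)
print(outputs.shape)  # (8, 5)
```

Each row of `outputs` is a probability distribution over the 5 classes.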

Note: Only dicts, lists, and tuples of input tensors are supported. Nested inputs are not supported (e.g. lists of lists or dicts of dicts).

A new Functional API model can also be created by using the intermediate tensors. This enables you to quickly extract sub-components of the model.

Example:

inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)

full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)

Note that the backbone and activations models are not created with keras.Input objects, but with the tensors that originate from keras.Input objects. Under the hood, the layers and weights will be shared across these models, so that the user can train the full_model, and use backbone or activations to do feature extraction. The inputs and outputs of the model can be nested structures of tensors as well, and the created models are standard Functional API models that support all the existing APIs.

By subclassing the Model class

In that case, you should define your layers in __init__() and you should implement the model’s forward pass in call().

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

If you subclass Model, you can optionally have a training argument (boolean) in call(), which you can use to specify a different behavior in training and inference:

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()
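Keras's `Dropout` layer uses inverted dropout: during training it zeroes each unit with probability `rate` and scales the survivors by `1 / (1 - rate)`; at inference it passes inputs through unchanged. The NumPy sketch below shows that difference, assuming a hypothetical helper `dropout` that stands in for the layer.

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Hypothetical stand-in for keras.layers.Dropout (inverted dropout)."""
    if not training or rate == 0.0:
        return x  # inference: identity
    keep = rng.random(x.shape) >= rate            # keep each unit with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)  # rescale so the expected value is unchanged

rng = np.random.default_rng(0)
x = np.ones((1000, 64))

inference_out = dropout(x, 0.5, training=False, rng=rng)
training_out = dropout(x, 0.5, training=True, rng=rng)
print(training_out.mean())  # close to 1.0 because of the rescaling
```

The rescaling is why the same weights can be used unchanged when `training=False`.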

Once the model is created, you can configure its losses and metrics with model.compile(), train it with model.fit(), or use it for prediction with model.predict().

With the Sequential class

In addition, keras.Sequential is a special case of model that is purely a stack of single-input, single-output layers.

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])

relax.methods.cchvae.CCHVAEConfig

class relax.methods.cchvae.CCHVAEConfig (vae_layers=[20, 16, 14, 12], opt_name='adam', vae_lr=0.001, max_steps=100, n_search_samples=100, step_size=0.1)

Configuration class for CCHVAE.

Parameters:

  • vae_layers (List[int], default=[20, 16, 14, 12]) – List of hidden layer sizes for the VAE.
  • opt_name (str, default=adam) – Name of the optimizer used to train the VAE.
  • vae_lr (float, default=0.001) – Learning rate for training the VAE.
  • max_steps (int, default=100) – Maximum number of search steps.
  • n_search_samples (int, default=100) – Number of candidate counterfactuals generated per search step.
  • step_size (float, default=0.1) – Step size for expanding the search radius in latent space.
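The search parameters above (`max_steps`, `n_search_samples`, `step_size`) govern how CCHVAE looks for counterfactuals in the VAE's latent space. The NumPy sketch below illustrates a growing-sphere search in the spirit of C-CHVAE, assuming candidates are sampled in progressively wider rings around the encoded input until one flips the prediction. It is not the relax implementation: `latent_search`, the identity encoder/decoder, and the toy classifier are hypothetical stand-ins.

```python
import numpy as np

def latent_search(x, encode, decode, pred_fn, max_steps=100,
                  n_search_samples=100, step_size=0.1, seed=0):
    """Hypothetical sketch of a growing-sphere latent search, not relax's code."""
    rng = np.random.default_rng(seed)
    z = encode(x)
    y = np.round(pred_fn(x[None, :]))[0]          # original predicted label (0 or 1)
    low, high = 0.0, step_size
    for _ in range(max_steps):
        # Random unit directions and radii in [low, high) around the latent code.
        directions = rng.normal(size=(n_search_samples, z.shape[-1]))
        directions /= np.linalg.norm(directions, axis=1, keepdims=True)
        radii = rng.uniform(low, high, size=(n_search_samples, 1))
        candidates = decode(z + directions * radii)
        flipped = np.round(pred_fn(candidates)) != y
        if flipped.any():
            valid = candidates[flipped]
            # Return the valid candidate closest to x in input space.
            return valid[np.argmin(np.linalg.norm(valid - x, axis=1))]
        low, high = high, high + step_size        # widen the search ring
    return x  # no counterfactual found within max_steps

# Toy classifier (hypothetical): predicts class 1 when the features sum to more than 1.
def pred_fn(xs):
    return 1.0 / (1.0 + np.exp(-(xs.sum(axis=1) - 1.0)))

x = np.zeros(2)  # predicted class 0
cf = latent_search(x, encode=lambda v: v, decode=lambda zs: zs, pred_fn=pred_fn)
```

Sampling radii uniformly within each ring keeps the sketch short; a more faithful implementation might sample uniformly by volume instead.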

relax.methods.cchvae.CCHVAE

class relax.methods.cchvae.CCHVAE (config=None, chvae=None, name='cchvae')

Parametric counterfactual module implementing C-CHVAE.

Methods

set_apply_constraints_fn (apply_constraints_fn)

set_compute_reg_loss_fn (compute_reg_loss_fn)

apply_constraints (*args, **kwargs)

compute_reg_loss (*args, **kwargs)

save (path)

load_from_path (path)

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)

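import jax
import jax.random as jrand
import keras
from functools import partial
from relax.data_module import load_data
from relax.ml_model import load_ml_module
from relax.methods.cchvae import CCHVAE

data = load_data('adult')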
pred_fn = load_ml_module('adult').pred_fn
xs_train, ys_train = data['train']
xs_test, ys_test = data['test']
/home/birk/code/jax-relax/relax/data_module.py:234: UserWarning: Passing `config` will have no effect.
  warnings.warn("Passing `config` will have no effect.")
cchvae = CCHVAE()
cchvae.train(data, epochs=5)
cchvae.set_apply_constraints_fn(data.apply_constraints)
Epoch 1/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - loss: 103.6776     
Epoch 2/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 808us/step - loss: 3.1196     
Epoch 3/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 788us/step - loss: 1.3849    
Epoch 4/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 811us/step - loss: 0.8786    
Epoch 5/5
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 785us/step - loss: 0.6225    
cf = cchvae.generate_cf(xs_train[0], pred_fn, rng_key=jrand.PRNGKey(0))
n_tests = 100
partial_gen = partial(cchvae.generate_cf, pred_fn=pred_fn)
cfs = jax.vmap(partial_gen)(xs_test[:n_tests], rng_key=jrand.split(jrand.PRNGKey(0), n_tests))

assert cfs.shape == xs_test[:n_tests].shape

# Validity: fraction of counterfactuals whose predicted class flips
print("Validity: ", keras.metrics.binary_accuracy(
    (1 - pred_fn(xs_test[:n_tests])).round(),
    pred_fn(cfs)
).mean())
Validity:  1.0
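The validity score printed above is the fraction of test points whose counterfactual receives a different predicted class than the original input. As a sketch, the same quantity can be computed directly in NumPy, assuming `pred_fn` returns positive-class probabilities and a 0.5 decision threshold; the helper `validity` and the probabilities below are made up for illustration.

```python
import numpy as np

def validity(orig_probs, cf_probs, threshold=0.5):
    """Hypothetical helper: fraction of counterfactuals whose predicted class differs."""
    orig_labels = (orig_probs > threshold).astype(int)
    cf_labels = (cf_probs > threshold).astype(int)
    return float(np.mean(orig_labels != cf_labels))

orig = np.array([0.9, 0.2, 0.7, 0.4])  # made-up original predictions
cfs = np.array([0.3, 0.8, 0.6, 0.1])   # made-up counterfactual predictions
print(validity(orig, cfs))  # 0.5
```

Here the first two counterfactuals flip the prediction and the last two do not, giving a validity of 0.5.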