L2C

import jax
import jax.numpy as jnp
import jax.random as jrand

start_end = jnp.array([[0, 1], [1, 2], [2, 3], [3, 5]])
xs = jrand.normal(jrand.PRNGKey(0), (4, 5))
cfs = jrand.normal(jrand.PRNGKey(1), (4, 5))
prob = jrand.uniform(jrand.PRNGKey(2), (4, 4))

# split xs into 4 parts according to start_end
xs_split = jnp.split(xs, start_end[:-1, 1], axis=1)
cfs_split = jnp.split(cfs, start_end[:-1, 1], axis=1)
prob_split = jnp.split(prob, start_end.shape[0], axis=1)

def perturb(x, cf, prob):
    return x * (1 - prob) + cf * prob

perturbed = jax.tree_util.tree_map(
    perturb, xs_split, cfs_split, prob_split
)

L2C Model

relax.utils.gumbel_softmax

[source]

relax.utils.gumbel_softmax (key, logits, tau)

The Gumbel softmax function.

Parameters:

  • key (jax.random.PRNGKey) – Random key
  • logits (jax.Array) – Logits for each class. Shape (batch_size, num_classes)
  • tau (float) – Temperature for the Gumbel softmax
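
A minimal usage sketch, assuming gumbel_softmax has been imported from relax.utils and relying only on the standard Gumbel-softmax property that each output row lies on the probability simplex:

key = jrand.PRNGKey(0)
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])
soft = gumbel_softmax(key, logits, tau=0.5)
assert soft.shape == logits.shape
# each row is a relaxed one-hot sample, so it sums to 1
assert jnp.allclose(soft.sum(axis=-1), 1.0)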

relax.methods.sphere.sample_categorical

[source]

relax.methods.sphere.sample_categorical (key, logits, tau, training=True)

Sample from a categorical distribution.

Parameters:

  • key (jax.random.PRNGKey) – Random key
  • logits (jax.Array) – Logits for each class. Shape (batch_size, num_classes)
  • tau (float) – Temperature for the Gumbel softmax
  • training (bool, default=True) – Apply Gumbel softmax if training

logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])
key = jrand.PRNGKey(0)
output = sample_categorical(key, logits, tau=0.5, training=True)
assert output.shape == logits.shape
assert jnp.allclose(output.sum(axis=-1), 1.0)
# low temperature -> one-hot
output = sample_categorical(key, logits, tau=0.01, training=True)
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)
# high temperature -> uniform
output = sample_categorical(key, logits, tau=100, training=True)
assert jnp.max(output) - jnp.min(output) < 0.5

output = sample_categorical(key, logits, tau=0.5, training=False)
assert output.shape == logits.shape
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)

relax.methods.l2c.sample_bernouli

[source]

relax.methods.l2c.sample_bernouli (key, prob, tau, training=True)

Sample from a Bernoulli distribution.

Parameters:

  • key (jax.random.PRNGKey) – Random key
  • prob (jax.Array) – Probability of the Bernoulli distribution. Shape (batch_size, 1)
  • tau (float) – Temperature for the Gumbel softmax
  • training (bool, default=True) – Apply Gumbel softmax if training

Returns:

    (jax.Array)
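
A shape-only sketch of sample_bernouli; nothing beyond the documented signature is assumed, and the comments flag the assumptions:

key = jrand.PRNGKey(0)
prob = jnp.array([[0.9], [0.1], [0.5]])  # per-row probabilities, shape (batch_size, 1)
# training=True applies the Gumbel-softmax relaxation, keeping the samples differentiable
relaxed = sample_bernouli(key, prob, tau=0.5, training=True)
# training=False is assumed to skip the relaxation for inference
deterministic = sample_bernouli(key, prob, tau=0.5, training=False)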

relax.methods.l2c.split_fn

[source]

relax.methods.l2c.split_fn (feature_indices)

start_end = [(0, 1), (1, 2), (2, 3), (3, 5)]
split_xs, split_prob = split_fn(start_end)
assert len(split_xs(xs)) == len(start_end)
assert len(split_prob(prob)) == len(start_end)

relax.methods.l2c.L2CModel

[source]

class relax.methods.l2c.L2CModel (generator_layers, selector_layers, feature_indices=None, immutable_mask=None, pred_fn=None, alpha=0.0001, tau=0.7, seed=None, **kwargs)

A model grouping layers into an object with training/inference features.

There are three ways to instantiate a Model:

With the “Functional API”

You start from Input, you chain layer calls to specify the model’s forward pass, and finally you create your model from inputs and outputs:

inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

Note: Only dicts, lists, and tuples of input tensors are supported. Nested inputs are not supported (e.g. lists of list or dicts of dict).

A new Functional API model can also be created by using the intermediate tensors. This enables you to quickly extract sub-components of the model.

Example:

inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)

full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)

Note that the backbone and activations models are not created with keras.Input objects, but with the tensors that originate from keras.Input objects. Under the hood, the layers and weights will be shared across these models, so that the user can train the full_model, and use backbone or activations to do feature extraction. The inputs and outputs of the model can be nested structures of tensors as well, and the created models are standard Functional API models that support all the existing APIs.

By subclassing the Model class

In that case, you should define your layers in __init__() and you should implement the model’s forward pass in call().

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

If you subclass Model, you can optionally have a training argument (boolean) in call(), which you can use to specify a different behavior in training and inference:

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()

Once the model is created, you can config the model with losses and metrics with model.compile(), train the model with model.fit(), or use the model to do prediction with model.predict().

With the Sequential class

In addition, keras.Sequential is a special case of model where the model is purely a stack of single-input, single-output layers.

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])

Parameters:

  • generator_layers (list[int])
  • selector_layers (list[int])
  • feature_indices (list[tuple[int, int]], default=None)
  • immutable_mask (jax.Array, default=None)
  • pred_fn (Callable, default=None)
  • alpha (<class 'float'>, default=0.0001) – Sparsity regularization
  • tau (<class 'float'>, default=0.7)
  • seed (int, default=None)
  • kwargs
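
A hedged construction sketch: every value below is illustrative (the feature ranges reuse the toy start_end from the top of this page), and a real classifier's prediction function would normally be passed as pred_fn:

l2c_model = L2CModel(
    generator_layers=[64, 64, 64],   # widths of the counterfactual generator MLP
    selector_layers=[64],            # widths of the feature-selector MLP
    feature_indices=[(0, 1), (1, 2), (2, 3), (3, 5)],
    immutable_mask=None,             # no features are frozen in this toy setup
    pred_fn=None,                    # plug a trained predictor in here
    alpha=1e-4,                      # sparsity regularization strength
    tau=0.7,                         # Gumbel-softmax temperature
)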

Discretizer

relax.methods.l2c.qcut

[source]

relax.methods.l2c.qcut (x, q, axis=0)

Quantile binning.

Parameters:

  • x (jax.Array) – Input array
  • q (int) – Number of quantiles
  • axis (int, default=0) – Axis along which to compute quantiles

Returns:

    (tuple[jax.Array, jax.Array]) – (digitized array, quantiles)

digitized, quantiles = qcut(jnp.arange(10), 4)
assert digitized.shape == (10,)
assert quantiles.shape == (3,)
assert jnp.allclose(
    digitized, jnp.array([0,0,0,1,1,2,2,3,3,3])
)

quantiles_true = jnp.array([0, 2.25, 4.5, 6.75, 9])
assert jnp.allclose(
    quantiles, quantiles_true[1:-1]
)
x_empty = jnp.array([])
q = 2
digitized_empty, quantiles_empty = qcut(x_empty, q)
assert digitized_empty.size == 0 and quantiles_empty.size == 0
# Test with single element array
x_single = jnp.array([1])
digitized_single, quantiles_single = qcut(x_single, q)
assert digitized_single.size == 1 and quantiles_single.size == 0

# Test with large q value
xs = jnp.array([1, 2, 3, 4, 5, 6])
q_large = 10
_, quantiles_large = qcut(xs, q_large)
assert len(quantiles_large) == q_large - 1

relax.methods.l2c.qcut_inverse

[source]

relax.methods.l2c.qcut_inverse (digitized, quantiles)

Inverse of qcut.

Parameters:

  • digitized (jax.Array) – One-hot encoded digitized array
  • quantiles (jax.Array) – Quantiles

Returns:

    (jax.Array)

digitized, quantiles = qcut(jnp.arange(10), 4)
ohe_digitized = jax.nn.one_hot(digitized, 4)
# continuous feats
quantiles_inv = qcut_inverse(ohe_digitized, jnp.arange(4))
assert quantiles_inv.shape == (10, 1)
# discrete feats
quantiles_inv = qcut_inverse(ohe_digitized, jnp.identity(4))
assert jnp.array_equal(quantiles_inv, ohe_digitized)

relax.methods.l2c.cut_quantiles

[source]

relax.methods.l2c.cut_quantiles (quantiles, xs)

Parameters:

  • quantiles (jax.Array) – Quantiles
  • xs (jax.Array) – Input array

relax.methods.l2c.discretize_xs

[source]

relax.methods.l2c.discretize_xs (xs, is_categorical_and_indices, q=4)

Discretize continuous features.

Parameters:

  • xs (jax.Array) – Input array
  • is_categorical_and_indices (list[tuple[bool, tuple[int, int]]]) – Per-feature (is_categorical, (start, end)) pairs
  • q (int, default=4) – Number of quantiles

Returns:

    (tuple[list[jax.Array], list[jax.Array], list[jax.Array], list[list[int, int]]]) – (discretized arrays, per-feature quantiles, per-feature mid-quantiles, feature indices)

dm = relax.load_data("dummy")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices) for feat, indices in zip(dm.features, dm.features.feature_indices)
]
discretized_xs, quantiles_feats, mid_quantiles, feature_indices = discretize_xs(xs, is_categorical_and_indices)
assert len(discretized_xs) == len(is_categorical_and_indices)
assert all(discretized_xs[i].shape[1] == 4 for i in range(len(discretized_xs)))

assert len(quantiles_feats) == len(is_categorical_and_indices)
assert all(len(quantiles_feats[i]) == 3 for i in range(len(quantiles_feats)))
assert len(mid_quantiles) == len(is_categorical_and_indices)
assert all(len(mid_quantiles[i]) == 4 for i in range(len(mid_quantiles)))

relax.methods.l2c.Discretizer

[source]

class relax.methods.l2c.Discretizer (is_cat_and_indices, q=4)

Discretize continuous features.

Parameters:

  • is_cat_and_indices (list[tuple[bool, tuple[int, int]]]) – Per-feature (is_categorical, (start, end)) pairs
  • q (int, default=4) – Number of quantiles

dm = relax.load_data("adult")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices) for feat, indices in zip(dm.features, dm.features.feature_indices)
]

dis = Discretizer(is_categorical_and_indices)
dis.fit(xs)
digitized_xs_1 = dis.transform(xs)
assert digitized_xs_1.shape == (xs.shape[0], 35)
# assert jnp.array_equal(jnp.concatenate(discretized_xs, axis=-1), digitized_xs_1)
inversed_xs = dis.inverse_transform(digitized_xs_1)
assert xs.shape == inversed_xs.shape
# assert jnp.unique(inversed_xs).size == xs.shape[1] * 4

ml_module = relax.load_ml_module("adult")
pred_fn = dis.get_pred_fn(ml_module.pred_fn)
# digitized_xs_1 = split_xs(xs)
y = pred_fn(digitized_xs_1)
assert y.shape == (xs.shape[0], 2)

def f(x, y):
    y_pred = pred_fn(x)
    return jnp.mean((y_pred - y) ** 2)

grad = jax.grad(f)(digitized_xs_1, ys)
assert grad.shape == digitized_xs_1.shape

L2C Module

relax.methods.l2c.L2CConfig

[source]

class relax.methods.l2c.L2CConfig (generator_layers=[64, 64, 64], selector_layers=[64], lr=0.001, opt_name='adam', alpha=0.0001, tau=0.7, q=4)

Base class for all config classes.

Parameters:

  • generator_layers (list[int], default=[64, 64, 64]) – Generator MLP layers.
  • selector_layers (list[int], default=[64]) – Selector MLP layers.
  • lr (float, default=0.001) – Model learning rate.
  • opt_name (str, default='adam') – Optimizer name for training L2C.
  • alpha (float, default=0.0001) – Sparsity regularization.
  • tau (float, default=0.7) – Temperature for the Gumbel softmax.
  • q (int, default=4) – Number of quantiles.
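
A configuration sketch that simply spells out the documented defaults; any of them can be overridden, and the resulting config can be handed to the L2C module documented next:

config = L2CConfig(
    generator_layers=[64, 64, 64],  # generator MLP layers
    selector_layers=[64],           # selector MLP layers
    lr=0.001,                       # learning rate
    opt_name='adam',                # optimizer for training L2C
    alpha=0.0001,                   # sparsity regularization
    tau=0.7,                        # Gumbel-softmax temperature
    q=4,                            # number of quantiles
)
l2c = L2C(config=config)            # pass the config to the L2C module below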

relax.methods.l2c.L2C

[source]

class relax.methods.l2c.L2C (config=None, l2c_model=None, name='l2c')

Base class for parametric counterfactual modules.

Methods

[source]

set_apply_constraints_fn (apply_constraints_fn)

[source]

set_compute_reg_loss_fn (compute_reg_loss_fn)

[source]

apply_constraints (*args, **kwargs)

[source]

compute_reg_loss (*args, **kwargs)

[source]

save (path)

[source]

load_from_path (path)

[source]

before_generate_cf (*args, **kwargs)

generate_cf (*args, **kwargs)

dm = relax.load_data('adult')
ml_module = relax.load_ml_module('adult')
l2c = L2C()
exp = relax.generate_cf_explanations(
    l2c, dm, ml_module.pred_fn,
)
Epoch 1/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 5s 14ms/step - loss: 0.8767   
Epoch 2/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 777us/step - loss: 0.1725     
Epoch 3/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 769us/step - loss: 0.1539    
Epoch 4/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 790us/step - loss: 0.1462    
Epoch 5/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 799us/step - loss: 0.1434    
Epoch 6/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 780us/step - loss: 0.1389    
Epoch 7/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 772us/step - loss: 0.1383    
Epoch 8/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 790us/step - loss: 0.1372    
Epoch 9/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 776us/step - loss: 0.1360    
Epoch 10/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 787us/step - loss: 0.1345    
relax.benchmark_cfs([exp])
                acc  validity  proximity
adult l2c  0.827124   0.98099   6.412683

import functools as ft

partial_gen = ft.partial(l2c.generate_cf, pred_fn=ml_module.pred_fn)
cfs = jax.vmap(partial_gen)(dm.xs, rng_key=jrand.split(jrand.PRNGKey(0), dm.xs.shape[0]))
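
Trained modules can be persisted with the save and load_from_path methods listed above. The sketch below is a hedged illustration: the directory name is made up, and it assumes load_from_path restores the saved parameters into an existing L2C instance.

l2c.save("l2c_adult")                     # illustrative path
l2c_restored = L2C()
l2c_restored.load_from_path("l2c_adult")  # assumed to restore the saved module in place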