import jax
import jax.numpy as jnp
import jax.random as jrand

start_end = jnp.array([[0, 1], [1, 2], [2, 3], [3, 5]])
xs = jrand.normal(jrand.PRNGKey(0), (4, 5))
cfs = jrand.normal(jrand.PRNGKey(1), (4, 5))
prob = jrand.uniform(jrand.PRNGKey(2), (4, 4))
# split xs into 4 parts according to start_end
xs_split = jnp.split(xs, start_end[:-1, 1], axis=1)
cfs_split = jnp.split(cfs, start_end[:-1, 1], axis=1)
prob_split = jnp.split(prob, start_end.shape[0], axis=1)

def perturb(x, cf, prob):
    return x * (1 - prob) + cf * prob

perturbed = jax.tree_util.tree_map(
    perturb, xs_split, cfs_split, prob_split
)
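As a quick sanity check, each perturbed feature group keeps the shape of its original slice, with the per-group probability gating how much of the counterfactual replaces the original (the asserts below follow from the shapes constructed above):

assert len(perturbed) == len(xs_split)
assert all(p.shape == x.shape for p, x in zip(perturbed, xs_split))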
L2C
L2C Model
relax.utils.gumbel_softmax
relax.utils.gumbel_softmax (key, logits, tau)
The Gumbel softmax function.
Parameters:
- key (jax.random.PRNGKey) – Random key
- logits (jax.Array) – Logits for each class. Shape (batch_size, num_classes)
- tau (float) – Temperature for the Gumbel softmax
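For intuition, a reference sketch of the standard Gumbel-softmax trick (not necessarily ReLax's exact implementation): perturb the logits with i.i.d. Gumbel noise, then apply a temperature-scaled softmax.

import jax
import jax.numpy as jnp
import jax.random as jrand

def gumbel_softmax_sketch(key, logits, tau):
    # Gumbel(0, 1) noise via the inverse CDF of uniform samples.
    u = jrand.uniform(key, logits.shape, minval=1e-20, maxval=1.0)
    gumbel = -jnp.log(-jnp.log(u))
    # Lower tau pushes the output toward a hard one-hot sample;
    # higher tau pushes it toward the uniform distribution.
    return jax.nn.softmax((logits + gumbel) / tau, axis=-1)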
relax.methods.sphere.sample_categorical
relax.methods.sphere.sample_categorical (key, logits, tau, training=True)
Sample from a categorical distribution.
Parameters:
- key (jax.random.PRNGKey) – Random key
- logits (jax.Array) – Logits for each class. Shape (batch_size, num_classes)
- tau (float) – Temperature for the Gumbel softmax
- training (bool, default=True) – Apply the Gumbel softmax if training
logits = jnp.array([[2.0, 1.0, 0.1], [1.0, 2.0, 3.0]])
key = jrand.PRNGKey(0)
output = sample_categorical(key, logits, tau=0.5, training=True)
assert output.shape == logits.shape
assert jnp.allclose(output.sum(axis=-1), 1.0)
# low temperature -> one-hot
output = sample_categorical(key, logits, tau=0.01, training=True)
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)
# high temperature -> uniform
output = sample_categorical(key, logits, tau=100, training=True)
assert jnp.max(output) - jnp.min(output) < 0.5
output = sample_categorical(key, logits, tau=0.5, training=False)
assert output.shape == logits.shape
assert jnp.array_equal(
    output.argmax(axis=-1), logits.argmax(axis=-1)
)
relax.methods.l2c.sample_bernouli
relax.methods.l2c.sample_bernouli (key, prob, tau, training=True)
Sample from a Bernoulli distribution.
Parameters:
- key (jax.random.PRNGKey) – Random key
- prob (jax.Array) – Probability of the Bernoulli distribution. Shape (batch_size, 1)
- tau (float) – Temperature for the Gumbel softmax
- training (bool, default=True) – Apply the Gumbel softmax if training
Returns:
(jax.Array)
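For intuition, a reference sketch of the standard relaxed-Bernoulli (binary concrete) sampler (not necessarily ReLax's exact implementation; the inference-time branch is an assumption):

import jax
import jax.numpy as jnp
import jax.random as jrand

def sample_bernoulli_sketch(key, prob, tau, training=True):
    if not training:
        # Assumed inference behavior: hard-threshold the probability.
        return (prob > 0.5).astype(jnp.float32)
    # Logistic noise from a uniform sample (equivalent to the
    # difference of two Gumbel samples).
    u = jrand.uniform(key, prob.shape, minval=1e-7, maxval=1.0 - 1e-7)
    logistic = jnp.log(u) - jnp.log1p(-u)
    # Relaxed sample in (0, 1); lower tau pushes it toward {0, 1}.
    log_odds = jnp.log(prob) - jnp.log1p(-prob)
    return jax.nn.sigmoid((log_odds + logistic) / tau)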
relax.methods.l2c.split_fn
relax.methods.l2c.split_fn (feature_indices)
start_end = [(0, 1), (1, 2), (2, 3), (3, 5)]
split_xs, split_prob = split_fn(start_end)
assert len(split_xs(xs)) == len(start_end)
assert len(split_prob(prob)) == len(start_end)
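A plausible sketch of what split_fn might return, assuming it builds jnp.split closures from the (start, end) feature index pairs (hypothetical; the actual implementation may differ):

import jax.numpy as jnp

def split_fn_sketch(feature_indices):
    # Interior split points: the end index of every group but the last.
    cuts = [end for _, end in feature_indices[:-1]]
    split_xs = lambda xs: jnp.split(xs, cuts, axis=-1)
    # One selector-probability column per feature group.
    split_prob = lambda prob: jnp.split(prob, len(feature_indices), axis=-1)
    return split_xs, split_prob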
relax.methods.l2c.L2CModel
class relax.methods.l2c.L2CModel (generator_layers, selector_layers, feature_indices=None, immutable_mask=None, pred_fn=None, alpha=0.0001, tau=0.7, seed=None, **kwargs)
A model grouping layers into an object with training/inference features.
There are three ways to instantiate a Model:
With the “Functional API”
You start from Input, you chain layer calls to specify the model’s forward pass, and finally you create your model from inputs and outputs:
import keras

inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Note: Only dicts, lists, and tuples of input tensors are supported. Nested inputs are not supported (e.g. lists of lists or dicts of dicts).
A new Functional API model can also be created by using the intermediate tensors. This enables you to quickly extract sub-components of the model.
Example:
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)

full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
Note that the backbone and activations models are not created with keras.Input objects, but with the tensors that originate from keras.Input objects. Under the hood, the layers and weights are shared across these models, so that the user can train the full_model and use backbone or activations for feature extraction. The inputs and outputs of the model can be nested structures of tensors as well, and the created models are standard Functional API models that support all the existing APIs.
By subclassing the Model class
In that case, you should define your layers in __init__() and implement the model’s forward pass in call().
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
If you subclass Model, you can optionally have a training argument (boolean) in call(), which you can use to specify different behavior in training and inference:
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(5, activation="softmax")
        self.dropout = keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = MyModel()
Once the model is created, you can configure it with losses and metrics via model.compile(), train it with model.fit(), or use it for prediction with model.predict().
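For completeness, the standard workflow looks like this (x_train, y_train, and x_test are placeholder arrays, not defined above):

# Configure training with an optimizer, a loss, and metrics.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(x_train, y_train, epochs=3)  # training loop
preds = model.predict(x_test)          # inference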
With the Sequential class
In addition, keras.Sequential is a special case of model where the model is purely a stack of single-input, single-output layers.

model = keras.Sequential([
    keras.Input(shape=(None, None, 3)),
    keras.layers.Conv2D(filters=32, kernel_size=3),
])
Parameters:
- generator_layers (list[int])
- selector_layers (list[int])
- feature_indices (list[tuple[int, int]], default=None)
- immutable_mask (jax.Array, default=None)
- pred_fn (typing.Callable, default=None)
- alpha (float, default=0.0001) – Sparsity regularization
- tau (float, default=0.7)
- seed (int, default=None)
- kwargs
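A hedged construction sketch using the documented signature; all argument values are illustrative and the lambda is a stand-in predictor, not a real model:

model = L2CModel(
    generator_layers=[64, 64, 64],
    selector_layers=[64],
    # (start, end) column ranges, as in the split_fn example above.
    feature_indices=[(0, 1), (1, 2), (2, 3), (3, 5)],
    pred_fn=lambda x: x,  # placeholder predictor
    alpha=1e-4,
    tau=0.7,
)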
Discretizer
relax.methods.l2c.qcut
relax.methods.l2c.qcut (x, q, axis=0)
Quantile binning.
Parameters:
- x (jax.Array) – Input array
- q (int) – Number of quantiles
- axis (int, default=0) – Axis along which to compute quantiles
Returns:
(tuple[jax.Array, jax.Array]) – (digitized array, quantiles)
digitized, quantiles = qcut(jnp.arange(10), 4)
assert digitized.shape == (10,)
assert quantiles.shape == (3,)
assert jnp.allclose(
    digitized, jnp.array([0, 0, 0, 1, 1, 2, 2, 3, 3, 3])
)
quantiles_true = jnp.array([0, 2.25, 4.5, 6.75, 9])
assert jnp.allclose(
    quantiles, quantiles_true[1:-1]
)
x_empty = jnp.array([])
q = 2
digitized_empty, quantiles_empty = qcut(x_empty, q)
assert digitized_empty.size == 0 and quantiles_empty.size == 0
# Test with single element array
x_single = jnp.array([1])
digitized_single, quantiles_single = qcut(x_single, q)
assert digitized_single.size == 1 and quantiles_single.size == 0
# Test with large q value
xs = jnp.array([1, 2, 3, 4, 5, 6])
q_large = 10
_, quantiles_large = qcut(xs, q_large)
assert len(quantiles_large) == q_large - 1
relax.methods.l2c.qcut_inverse
relax.methods.l2c.qcut_inverse (digitized, quantiles)
Inverse of qcut.
Parameters:
- digitized (jax.Array) – One-hot encoded digitized array
- quantiles (jax.Array) – Quantiles
Returns:
(jax.Array)
digitized, quantiles = qcut(jnp.arange(10), 4)
ohe_digitized = jax.nn.one_hot(digitized, 4)
# continuous feats
quantiles_inv = qcut_inverse(ohe_digitized, jnp.arange(4))
assert quantiles_inv.shape == (10, 1)
# discrete feats
quantiles_inv = qcut_inverse(ohe_digitized, jnp.identity(4))
assert jnp.array_equal(quantiles_inv, ohe_digitized)
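Conceptually, the inverse maps each one-hot row back to a representative value via a matrix product with the bin representatives. A minimal sketch under that assumption (not ReLax's actual code):

import jax.numpy as jnp

def qcut_inverse_sketch(digitized_ohe, quantiles):
    # (n, q) @ (q, d) -> (n, d): a column of bin representatives
    # recovers one continuous column, while an identity matrix passes
    # one-hot (categorical) features through unchanged.
    return digitized_ohe @ quantiles.reshape(quantiles.shape[0], -1)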
relax.methods.l2c.cut_quantiles
relax.methods.l2c.cut_quantiles (quantiles, xs)
Parameters:
- quantiles (jax.Array) – Quantiles
- xs (jax.Array) – Input array
relax.methods.l2c.discretize_xs
relax.methods.l2c.discretize_xs (xs, is_categorical_and_indices, q=4)
Discretize continuous features.
Parameters:
- xs (jax.Array) – Input array
- is_categorical_and_indices (list[tuple[bool, tuple[int, int]]]) – Features list
- q (int, default=4) – Number of quantiles
Returns:
(tuple[list[jax.Array], list[jax.Array], list[jax.Array], list[list[int, int]]]) – (discretized array, indices_and_quantiles_and_mid)
dm = relax.load_data("dummy")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices)
    for feat, indices in zip(dm.features, dm.features.feature_indices)
]
discretized_xs, quantiles_feats, mid_quantiles, feature_indices = discretize_xs(
    xs, is_categorical_and_indices
)
assert len(discretized_xs) == len(is_categorical_and_indices)
assert all(discretized_xs[i].shape[1] == 4 for i in range(len(discretized_xs)))
assert len(quantiles_feats) == len(is_categorical_and_indices)
assert all(len(quantiles_feats[i]) == 3 for i in range(len(quantiles_feats)))
assert len(mid_quantiles) == len(is_categorical_and_indices)
assert all(len(mid_quantiles[i]) == 4 for i in range(len(mid_quantiles)))
relax.methods.l2c.Discretizer
class relax.methods.l2c.Discretizer (is_cat_and_indices, q=4)
Discretize continuous features.
Parameters:
- is_cat_and_indices (list[tuple[bool, tuple[int, int]]]) – Features list
- q (int, default=4) – Number of quantiles
dm = relax.load_data("adult")
xs, ys = dm['train']
is_categorical_and_indices = [
    (feat.is_categorical, indices)
    for feat, indices in zip(dm.features, dm.features.feature_indices)
]

dis = Discretizer(is_categorical_and_indices)
dis.fit(xs)
digitized_xs_1 = dis.transform(xs)
assert digitized_xs_1.shape == (xs.shape[0], 35)
# assert jnp.array_equal(jnp.concatenate(discretized_xs, axis=-1), digitized_xs_1)
inversed_xs = dis.inverse_transform(digitized_xs_1)
assert xs.shape == inversed_xs.shape
# assert jnp.unique(inversed_xs).size == xs.shape[1] * 4
ml_module = relax.load_ml_module("adult")
pred_fn = dis.get_pred_fn(ml_module.pred_fn)
# digitized_xs_1 = split_xs(xs)
y = pred_fn(digitized_xs_1)
assert y.shape == (xs.shape[0], 2)

def f(x, y):
    y_pred = pred_fn(x)
    return jnp.mean((y_pred - y) ** 2)

grad = jax.grad(f)(digitized_xs_1, ys)
assert grad.shape == digitized_xs_1.shape
L2C Module
relax.methods.l2c.L2CConfig
class relax.methods.l2c.L2CConfig (generator_layers=[64, 64, 64], selector_layers=[64], lr=0.001, opt_name='adam', alpha=0.0001, tau=0.7, q=4)
Base class for all config classes.
Parameters:
- generator_layers (list[int], default=[64, 64, 64]) – Generator MLP layers.
- selector_layers (list[int], default=[64]) – Selector MLP layers.
- lr (float, default=0.001) – Model learning rate.
- opt_name (str, default='adam') – Optimizer name for training L2C.
- alpha (float, default=0.0001) – Sparsity regularization.
- tau (float, default=0.7) – Temperature for the Gumbel softmax.
- q (int, default=4) – Number of quantiles.
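A configuration sketch that spells out the documented defaults (values are illustrative); the resulting config can presumably be passed to L2C below as L2C(config=config):

config = L2CConfig(
    generator_layers=[64, 64, 64],  # generator MLP widths
    selector_layers=[64],           # selector MLP widths
    lr=0.001,                       # learning rate
    opt_name='adam',                # optimizer
    alpha=1e-4,                     # sparsity regularization strength
    tau=0.7,                        # Gumbel-softmax temperature
    q=4,                            # quantile bins per continuous feature
)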
relax.methods.l2c.L2C
class relax.methods.l2c.L2C (config=None, l2c_model=None, name='l2c')
Base class for parametric counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (*args, **kwargs)
dm = relax.load_data('adult')
ml_module = relax.load_ml_module('adult')

l2c = L2C()
exp = relax.generate_cf_explanations(
    l2c, dm, ml_module.pred_fn,
)
Epoch 1/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 5s 14ms/step - loss: 0.8767
Epoch 2/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 777us/step - loss: 0.1725
Epoch 3/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 769us/step - loss: 0.1539
Epoch 4/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 790us/step - loss: 0.1462
Epoch 5/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 799us/step - loss: 0.1434
Epoch 6/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 780us/step - loss: 0.1389
Epoch 7/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 772us/step - loss: 0.1383
Epoch 8/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 790us/step - loss: 0.1372
Epoch 9/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 776us/step - loss: 0.1360
Epoch 10/10
191/191 ━━━━━━━━━━━━━━━━━━━━ 0s 787us/step - loss: 0.1345
relax.benchmark_cfs([exp])
| data | method | acc | validity | proximity |
|---|---|---|---|---|
| adult | l2c | 0.827124 | 0.98099 | 6.412683 |
import functools as ft

partial_gen = ft.partial(l2c.generate_cf, pred_fn=ml_module.pred_fn)
cfs = jax.vmap(partial_gen)(dm.xs, rng_key=jrand.split(jrand.PRNGKey(0), dm.xs.shape[0]))