LIME


source

pairwise_distances

 pairwise_distances (x:jax.Array, y:jax.Array, metric:str='euclidean')
Type Default Details
x Array [n, k]
y Array [m, k]
metric str euclidean Supports “euclidean” and “cosine”
Returns Array [n, m]

This function is similar to sklearn.metrics.pairwise_distances.

from sklearn.metrics import pairwise_distances as sk_pairwise_distances

pairwise_distances is faster than sklearn’s implementation: in the benchmark below (two 1000 × 784 inputs) it runs roughly 4–5× faster for both metrics.

X = np.random.normal(size=(1000, 28 * 28))
Y = np.random.normal(size=(1000, 28 * 28))

def benchmark_pairwise_distances(metric, number=10, repeat=7):
    """Time sklearn's and the JAX `pairwise_distances` on the global `X`, `Y`,
    then assert that both produce the same distance matrix.

    Args:
        metric: distance metric passed to both implementations
            ("euclidean" or "cosine").
        number: calls per timing run.
        repeat: number of timing runs.
    """
    import statistics
    import timeit

    def _report(fn):
        # Per-call time of each run, summarised in the same format as %timeit.
        per_call = [
            t / number for t in timeit.repeat(fn, number=number, repeat=repeat)
        ]
        mean, std = statistics.mean(per_call), statistics.pstdev(per_call)
        print(
            f"{mean * 1e3:.3g} ms ± {std * 1e3:.3g} ms per loop "
            f"(mean ± std. dev. of {repeat} runs, {number} loops each)"
        )

    expected = sk_pairwise_distances(X, Y, metric=metric)
    # JAX dispatches asynchronously and (if jitted) compiles on the first call:
    # warm up once here and block on every result so timings cover real work.
    actual = pairwise_distances(X, Y, metric=metric).block_until_ready()

    print(f"[{metric}] Sklearn pairwise_distances:")
    _report(lambda: sk_pairwise_distances(X, Y, metric=metric))

    print(f"[{metric}] JAX pairwise_distances:")
    _report(lambda: pairwise_distances(X, Y, metric=metric).block_until_ready())

    assert jnp.allclose(expected, actual)

benchmark_pairwise_distances("euclidean")
benchmark_pairwise_distances("cosine")
[euclidean] Sklearn pairwise_distances:
28.6 ms ± 6.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[euclidean] JAX pairwise_distances:
6.27 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] Sklearn pairwise_distances:
29.2 ms ± 6.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] JAX pairwise_distances:
6.4 ms ± 2.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

source

gaussian_perturb_func

 gaussian_perturb_func (x:jax.Array, prng_key:jax.random.PRNGKey,
                        **kwargs)

Gaussian perturbation function for LIME


source

bernoulli_perturb_func

 bernoulli_perturb_func (x:jax.Array, prng_key:jax.random.PRNGKey,
                         **kwargs)

Bernoulli perturbation function for LIME

# One flattened 28x28 instance, perturbed with a sample count of 100 by each
# perturbation function (same seed for both).
X = np.random.normal(size=(1, 28 * 28))
b_perturbed, g_perturbed = (
    _perturb_data(X, 100, perturb, jrand.PRNGKey(42))
    for perturb in (bernoulli_perturb_func, gaussian_perturb_func)
)
# 100 perturbations yield 101 rows (presumably the original row is kept too).
assert b_perturbed.shape == g_perturbed.shape == (101, 28 * 28)

source

exp_kernel_func

 exp_kernel_func (dists:jax.Array, kernel_width:float)

Exponential kernel function for LIME

# Distances from each Gaussian-perturbed sample to the instance `X`; per the
# [n, k] x [m, k] -> [n, m] contract of `pairwise_distances`, shape [101, 1].
distances = pairwise_distances(g_perturbed, X)
The slowest run took 389.12 times longer than the fastest. This could mean that an intermediate result is being cached.
273 µs ± 657 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

source

LimeBase

 LimeBase (func:Callable, additional_func_args:Dict=None,
           model_regressor=None, kernal_func:Callable=None,
           kernel_width:float=None, perturb_func:Callable=None,
           input_paramter_name:str='x',
           pairwise_distances_metric:str='euclidean')

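Explain a black-box function with LIME: generate perturbed instances, weight each by a kernel on its distance to the original instance, and fit a local linear regressor to the black-box outputs.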

Type Default Details
func typing.Callable A black-box function to be explained
additional_func_args typing.Dict None Additional arguments for the black-box function
model_regressor NoneType None Linear regressor to use in explanation
kernal_func typing.Callable None Kernel function for computing similarity
kernel_width float None Kernel width for computing similarity. Defaults to (n_features * 0.75)
perturb_func typing.Callable None Perturbation function for generating perturbed instances
input_paramter_name str x Name of the input parameter for the black-box function
pairwise_distances_metric str euclidean Metric for computing pairwise distances

Tests

from sklearn.datasets import make_regression
import haiku as hk
# Synthetic regression data (500 rows, 20 features) used by the tests below.
xs, ys = make_regression(n_features=20, n_samples=500)
linear_model = LinearModel()
linear_model.fit(xs, ys)
# Input is passed by keyword "X" (presumably the name `predict` uses).
lime = LimeBase(func=linear_model.predict, input_paramter_name="X")
# fit a simple haiku model
def model(x):
    """MLP with two ReLU hidden layers of width 10 and a single-output head."""
    layers = []
    for width in (10, 10):
        layers += [hk.Linear(width), jax.nn.relu]
    layers.append(hk.Linear(1))
    return hk.Sequential(layers)(x)


def init(x):
    """Transform `model`, initialise its parameters and an SGD optimizer.

    Parameters are initialised from `PRNGKey(42)` with `x` as example input.
    Returns `(net, opt, params, opt_state)`.
    """
    net = hk.without_apply_rng(hk.transform(model))
    params = net.init(jrand.PRNGKey(42), x)
    opt = optax.sgd(1e-1)
    return net, opt, params, opt.init(params)

def loss(params, net, x, y):
    """Mean squared error of `net.apply(params, x)` against targets `y`.

    `y` is reshaped to the prediction shape first: with predictions of shape
    (batch, 1) and targets of shape (batch,), the bare difference would
    broadcast to (batch, batch) and average over every prediction/target pair,
    which drives the model towards predicting the batch mean everywhere.
    A size mismatch now raises instead of silently broadcasting.
    """
    pred = net.apply(params, x)
    y = jnp.reshape(y, pred.shape)
    return jnp.mean((pred - y) ** 2)

@partial(jax.jit, static_argnums=(2, 3))
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    x: jnp.ndarray,
    y: jnp.ndarray
):
    """Run one jitted optimizer step on the batch (`x`, `y`).

    `net` and `opt` (argnums 2 and 3) are static, so jit treats them as
    compile-time constants rather than traced inputs.
    Returns the updated `(params, opt_state)`.
    """
    step, next_opt_state = opt.update(jax.grad(loss)(params, net, x, y), opt_state)
    return optax.apply_updates(params, step), next_opt_state

def train(
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    params: hk.Params,
    opt_state: optax.OptState,
    x: jnp.ndarray,
    y: jnp.ndarray,
    n_epochs: int = 100,
    batch_size: int = 32,
):
    """Train for `n_epochs` passes over (`x`, `y`) in consecutive mini-batches.

    Batches are taken in order without shuffling; the last one may be
    smaller than `batch_size`. Returns only the final parameters.
    """
    batch_starts = range(0, x.shape[0], batch_size)
    for _epoch in range(n_epochs):
        for start in batch_starts:
            stop = start + batch_size
            params, opt_state = update(
                params, opt_state, net, opt, x[start:stop], y[start:stop]
            )
    return params

def fit_a_model(
    X: Array,
    y: Array,
):
    """Initialise the network on `X`, train it on (`X`, `y`) with the default
    `train` settings, and return `(net, trained_params)`."""
    components = init(X)  # (net, opt, params, opt_state)
    trained_params = train(*components, X, y)
    return components[0], trained_params


net, params = fit_a_model(xs, ys)
/home/birk/mambaforge-pypy3/envs/nbdev2/lib/python3.7/site-packages/haiku/_src/base.py:515: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  param = init(shape, dtype)
# Sanity check: predictions of the fitted network on the first 10 training rows.
net.apply(params, xs[:10])
DeviceArray([[0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635],
             [0.24004635]], dtype=float32)
# Kernel with a fixed width for this test.
kernel_func = partial(exp_kernel_func, kernel_width=2 * 0.75)

# Explain one row of `xs`, the 20-feature data `net` was fitted on. (`X` at
# this point is the (1, 784) array from the perturbation test and does not
# match the network's input size.)
_lime_attribute_single_instance(
    xs[:1],
    1000,
    jrand.PRNGKey(42),
    net.apply,
    additional_func_args={"params": params},
    input_paramter_name="x",
    perturb_func=gaussian_perturb_func,
    kernel_func=kernel_func,
    model_regressor=Ridge(alpha=1),
    pairwise_distances_metric="euclidean",
)
(DeviceArray([-5.97378996e-04,  4.81305433e-05,  1.55211031e-03,
              -3.89112538e-04, -1.21736361e-04, -1.05082065e-04,
               1.35615395e-04, -4.70238097e-04,  5.06596552e-05,
               5.96614438e-04,  8.60058644e-04,  1.57593074e-03,
              -1.35515491e-03,  7.79002556e-04,  8.51082150e-04,
               1.22845857e-04,  6.43223408e-04,  6.08801842e-04,
              -3.02861667e-06,  1.18552020e-03], dtype=float32),
 DeviceArray([0.24033982], dtype=float32))
# LIME with default settings around the fitted haiku network; `params` is
# forwarded to `net.apply` on every call.
lime = LimeBase(
    func=net.apply,
    additional_func_args={"params": params},
)
# Attribute every row of `xs`, the 20-feature data the network was trained on.
lime.attribute(xs)
(DeviceArray([[ 0.00221086,  0.00581249,  0.00237516, ...,  0.0043735 ,
                0.00298327,  0.00597063],
              [ 0.00241475,  0.00598198,  0.00254859, ...,  0.00455689,
                0.00296534,  0.00607178],
              [-0.00225598, -0.00591312, -0.00256158, ..., -0.00448116,
               -0.00301292, -0.00606422],
              ...,
              [ 0.00223549,  0.00583388,  0.00241553, ...,  0.00438501,
                0.00298516,  0.00600691],
              [ 0.00221306,  0.00573398,  0.00225197, ...,  0.00415604,
                0.00285654,  0.00581799],
              [ 0.00242912,  0.00596596,  0.00257836, ...,  0.00450296,
                0.0030849 ,  0.00620169]], dtype=float32),
 DeviceArray([[0.24730496],
              [0.24756414],
              [0.23204625],
              [0.24731565],
              [0.24734785],
              [0.24736024],
              [0.24726894],
              [0.24740964],
              [0.24748607],
              [0.2472707 ],
              [0.24726589],
              [0.23272656],
              [0.2474309 ],
              [0.24768972],
              [0.24720442],
              [0.23258576],
              [0.24747995],
              [0.23265731],
              [0.23216027],
              [0.23245938],
              [0.23232539],
              [0.23272593],
              [0.24734417],
              [0.232578  ],
              [0.2313688 ],
              [0.232699  ],
              [0.23166694],
              [0.2476369 ],
              [0.24740493],
              [0.23260298],
              [0.2476171 ],
              [0.23259275],
              [0.2327499 ],
              [0.24747556],
              [0.23274088],
              [0.23278277],
              [0.23274417],
              [0.23259158],
              [0.23276031],
              [0.23244202],
              [0.23248667],
              [0.24736097],
              [0.2475133 ],
              [0.23159517],
              [0.23230171],
              [0.24782038],
              [0.23145361],
              [0.23132221],
              [0.23273085],
              [0.24778697],
              [0.2318807 ],
              [0.23234296],
              [0.2473141 ],
              [0.23186037],
              [0.24774271],
              [0.24740867],
              [0.2321608 ],
              [0.24721904],
              [0.24765676],
              [0.23248613],
              [0.2325826 ],
              [0.23253548],
              [0.24741523],
              [0.23102273],
              [0.2324856 ],
              [0.24752925],
              [0.24744649],
              [0.24753611],
              [0.24741869],
              [0.24742423],
              [0.23231731],
              [0.24737181],
              [0.23265643],
              [0.24769214],
              [0.23244457],
              [0.23207039],
              [0.23221391],
              [0.24714589],
              [0.24750586],
              [0.24760035],
              [0.23261759],
              [0.23214899],
              [0.23229751],
              [0.2317236 ],
              [0.24728425],
              [0.24737929],
              [0.24734789],
              [0.24735087],
              [0.24734165],
              [0.23252094],
              [0.24735183],
              [0.23268992],
              [0.2472564 ],
              [0.23250158],
              [0.24733323],
              [0.23173681],
              [0.24738492],
              [0.24749885],
              [0.23246227],
              [0.2326699 ],
              [0.24735057],
              [0.24739408],
              [0.23209122],
              [0.24731648],
              [0.23272413],
              [0.232523  ],
              [0.24749972],
              [0.24729979],
              [0.23263235],
              [0.23183963],
              [0.23184806],
              [0.23263596],
              [0.24731863],
              [0.23263787],
              [0.24730252],
              [0.2472181 ],
              [0.24739836],
              [0.24762246],
              [0.24867377],
              [0.24735266],
              [0.23101482],
              [0.24794899],
              [0.23271923],
              [0.23227747],
              [0.23191595],
              [0.24739477],
              [0.23250002],
              [0.24734001],
              [0.24730746],
              [0.24746673],
              [0.24722606],
              [0.24733111],
              [0.23255974],
              [0.24728492],
              [0.23227465],
              [0.24736558],
              [0.24749716],
              [0.23218371],
              [0.23240303],
              [0.24745671],
              [0.23237772],
              [0.24727368],
              [0.23259522],
              [0.23260601],
              [0.23210865],
              [0.23270965],
              [0.23201807],
              [0.2472552 ],
              [0.2324636 ],
              [0.24743754],
              [0.23217565],
              [0.24725908],
              [0.23177943],
              [0.23214754],
              [0.23217827],
              [0.24719004],
              [0.23194984],
              [0.24750523],
              [0.23241459],
              [0.23250034],
              [0.2324958 ],
              [0.24732801],
              [0.24729869],
              [0.23244049],
              [0.24737254],
              [0.23204154],
              [0.23235755],
              [0.24782768],
              [0.24725914],
              [0.24750242],
              [0.23217706],
              [0.23235546],
              [0.24741676],
              [0.23215216],
              [0.24742627],
              [0.24739274],
              [0.23202284],
              [0.23235677],
              [0.23258376],
              [0.23233262],
              [0.2325433 ],
              [0.2473211 ],
              [0.23198971],
              [0.24748974],
              [0.24733643],
              [0.23230669],
              [0.24724506],
              [0.23218584],
              [0.24735439],
              [0.24746102],
              [0.24732526],
              [0.2325736 ],
              [0.23248109],
              [0.24725024],
              [0.24732903],
              [0.24728669],
              [0.24792492],
              [0.24741988],
              [0.23169258],
              [0.23255627],
              [0.24753469],
              [0.24741082],
              [0.23244117],
              [0.23260221],
              [0.2316141 ],
              [0.24735719],
              [0.24743626],
              [0.24751961],
              [0.23250861],
              [0.23179816],
              [0.24762237],
              [0.24788071],
              [0.2325231 ],
              [0.2324555 ],
              [0.2325428 ],
              [0.23237003],
              [0.2323332 ],
              [0.24787727],
              [0.23265064],
              [0.24732669],
              [0.24749607],
              [0.24766587],
              [0.2480405 ],
              [0.23218127],
              [0.24730903],
              [0.23179471],
              [0.2473682 ],
              [0.24743044],
              [0.247412  ],
              [0.23198949],
              [0.2474296 ],
              [0.23019618],
              [0.24736585],
              [0.24736209],
              [0.24728309],
              [0.24746317],
              [0.23231001],
              [0.23207894],
              [0.24735352],
              [0.2322949 ],
              [0.24727553],
              [0.24729888],
              [0.23229258],
              [0.23158833],
              [0.24727146],
              [0.23261805],
              [0.24736366],
              [0.23235823],
              [0.24729802],
              [0.24752894],
              [0.24734056],
              [0.23284838],
              [0.24715085],
              [0.2323496 ],
              [0.24730666],
              [0.2320316 ],
              [0.23239292],
              [0.24736963],
              [0.24747422],
              [0.24727249],
              [0.23237279],
              [0.24724723],
              [0.24729113],
              [0.24737135],
              [0.24733087],
              [0.24724013],
              [0.24737415],
              [0.2327153 ],
              [0.24791467],
              [0.24732663],
              [0.24748352],
              [0.23251979],
              [0.24714312],
              [0.24755022],
              [0.2473973 ],
              [0.23273005],
              [0.23246035],
              [0.24732141],
              [0.23213631],
              [0.24721566],
              [0.23201378],
              [0.23221391],
              [0.24740039],
              [0.23237988],
              [0.23115295],
              [0.23251712],
              [0.23238027],
              [0.23267289],
              [0.24774548],
              [0.23251162],
              [0.23228587],
              [0.23244764],
              [0.24746987],
              [0.23209953],
              [0.24765822],
              [0.23273446],
              [0.23225085],
              [0.24744767],
              [0.23257345],
              [0.24717891],
              [0.23226202],
              [0.23245876],
              [0.24770094],
              [0.23234405],
              [0.24729179],
              [0.23253849],
              [0.23176861],
              [0.23238832],
              [0.24772736],
              [0.23276523],
              [0.24727891],
              [0.24722643],
              [0.23272316],
              [0.24743803],
              [0.2473214 ],
              [0.2324927 ],
              [0.24720925],
              [0.2319577 ],
              [0.24736662],
              [0.24720564],
              [0.2473594 ],
              [0.24746218],
              [0.2473249 ],
              [0.23269431],
              [0.23267029],
              [0.24746528],
              [0.24735796],
              [0.23275968],
              [0.24729164],
              [0.24727115],
              [0.24726492],
              [0.23244949],
              [0.23239763],
              [0.23260587],
              [0.2317266 ],
              [0.24721311],
              [0.24791732],
              [0.23277605],
              [0.24746867],
              [0.2326223 ],
              [0.24716686],
              [0.23254141],
              [0.24743882],
              [0.24725032],
              [0.23271337],
              [0.23271807],
              [0.24727352],
              [0.23150992],
              [0.24748631],
              [0.24734168],
              [0.24780883],
              [0.24741678],
              [0.24733034],
              [0.23234318],
              [0.24731077],
              [0.2472419 ],
              [0.2473333 ],
              [0.23262657],
              [0.24725671],
              [0.24735816],
              [0.23228347],
              [0.2473127 ],
              [0.24725196],
              [0.23249117],
              [0.23216985],
              [0.24734968],
              [0.23251817],
              [0.23218517],
              [0.23229903],
              [0.24736002],
              [0.23227802],
              [0.24734715],
              [0.24728952],
              [0.23249188],
              [0.23238972],
              [0.24732132],
              [0.23275092],
              [0.23262392],
              [0.23213783],
              [0.24741286],
              [0.24725187],
              [0.23273003],
              [0.23257877],
              [0.24740161],
              [0.24715275],
              [0.24721095],
              [0.23270902],
              [0.24722742],
              [0.2319201 ],
              [0.23252024],
              [0.23254202],
              [0.23239242],
              [0.23167515],
              [0.24744089],
              [0.23231809],
              [0.23276606],
              [0.23184034],
              [0.23192035],
              [0.23264523],
              [0.24742319],
              [0.23254554],
              [0.23217438],
              [0.24735396],
              [0.23255692],
              [0.23227966],
              [0.232184  ],
              [0.24756455],
              [0.23222198],
              [0.24758385],
              [0.24754198],
              [0.24729684],
              [0.23244324],
              [0.24749175],
              [0.23269522],
              [0.23192582],
              [0.2326539 ],
              [0.23266825],
              [0.24736997],
              [0.24730414],
              [0.2324066 ],
              [0.24725583],
              [0.24741372],
              [0.24747895],
              [0.23201811],
              [0.24729444],
              [0.247381  ],
              [0.2326003 ],
              [0.23236334],
              [0.24736139],
              [0.2321806 ],
              [0.23245579],
              [0.23235963],
              [0.23238862],
              [0.2472786 ],
              [0.23169187],
              [0.24740751],
              [0.23263946],
              [0.23267938],
              [0.24756868],
              [0.23228672],
              [0.24769364],
              [0.24736963],
              [0.23052931],
              [0.23218985],
              [0.23187762],
              [0.23254694],
              [0.23233788],
              [0.2312679 ],
              [0.23229599],
              [0.23235682],
              [0.23257495],
              [0.23236491],
              [0.24776962],
              [0.24718513],
              [0.24725026],
              [0.23261815],
              [0.2471972 ],
              [0.24723697],
              [0.2323229 ],
              [0.23140587],
              [0.2477279 ],
              [0.23209357],
              [0.24751735],
              [0.24730821],
              [0.24750674],
              [0.23272803],
              [0.24727444],
              [0.24727756],
              [0.24726333],
              [0.24735498],
              [0.23130915],
              [0.24735644],
              [0.2326549 ],
              [0.23219867],
              [0.23273338],
              [0.23268846],
              [0.24724345],
              [0.2473797 ],
              [0.24738654],
              [0.2314789 ],
              [0.24723507],
              [0.23221129],
              [0.24730465],
              [0.23243599],
              [0.232267  ],
              [0.23234174],
              [0.24730346],
              [0.2472916 ],
              [0.24730274],
              [0.2325149 ],
              [0.24734567],
              [0.24727803],
              [0.24726497],
              [0.23281087],
              [0.24736117],
              [0.23239669],
              [0.23156339],
              [0.2475236 ],
              [0.24708891],
              [0.2476187 ]], dtype=float32))