from sklearn.metrics import pairwise_distances as sk_pairwise_distances
LIME
pairwise_distances
pairwise_distances (x:jax.Array, y:jax.Array, metric:str='euclidean')
|  | Type | Default | Details |
|---|---|---|---|
| x | Array |  | [n, k] |
| y | Array |  | [m, k] |
| metric | str | euclidean | Supports “euclidean” and “cosine” |
| Returns | Array |  | [n, m] |
This function is similar to `sklearn.metrics.pairwise_distances`, but the JAX-based `pairwise_distances` is faster than sklearn's implementation (see the benchmark below).
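As a point of reference, the Euclidean branch of such a function can be written with a few lines of broadcasting in JAX. The sketch below is an assumed illustration of the idea, not the library's actual source, and the helper name `euclidean_pairwise_distances` is hypothetical:

```python
import jax.numpy as jnp

def euclidean_pairwise_distances(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Return the [n, m] matrix of Euclidean distances between rows of x and y."""
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    sq_x = jnp.sum(x ** 2, axis=1)[:, None]   # [n, 1]
    sq_y = jnp.sum(y ** 2, axis=1)[None, :]   # [1, m]
    sq_dists = sq_x + sq_y - 2.0 * x @ y.T
    # Clamp tiny negative values caused by floating-point round-off.
    return jnp.sqrt(jnp.maximum(sq_dists, 0.0))
```

JIT-compiling such a computation with `jax.jit` typically accounts for much of the speed-up seen in the benchmark below.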
X = np.random.normal(size=(1000, 28 * 28))
Y = np.random.normal(size=(1000, 28 * 28))

def benchmark_pairwise_distances(metric):
    print(f"[{metric}] Sklearn pairwise_distances:")
    %timeit sk_pairwise_distances(X, Y, metric=metric)
    print(f"[{metric}] JAX pairwise_distances:")
    %timeit pairwise_distances(X, Y, metric=metric)
    assert jnp.allclose(
        pairwise_distances(X, Y, metric=metric),
        sk_pairwise_distances(X, Y, metric=metric)
    )

benchmark_pairwise_distances("euclidean")
benchmark_pairwise_distances("cosine")
[euclidean] Sklearn pairwise_distances:
28.6 ms ± 6.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[euclidean] JAX pairwise_distances:
6.27 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] Sklearn pairwise_distances:
29.2 ms ± 6.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[cosine] JAX pairwise_distances:
6.4 ms ± 2.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
gaussian_perturb_func
gaussian_perturb_func (x:jax.Array, prng_key:<function PRNGKey>, **kwargs)
Gaussian perturbation function for LIME
bernoulli_perturb_func
bernoulli_perturb_func (x:jax.Array, prng_key:<function PRNGKey>, **kwargs)
Bernoulli perturbation function for LIME
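Both perturbation functions draw random neighbours of the instance being explained. The exact sampling scheme is defined by the library; as a rough, hypothetical illustration of what Gaussian and Bernoulli perturbations usually mean for tabular inputs (the names and defaults below are assumptions, not the library's code):

```python
import jax.numpy as jnp
import jax.random as jrand

def gaussian_perturb(x: jnp.ndarray, prng_key, sigma: float = 1.0) -> jnp.ndarray:
    # Hypothetical sketch: jitter the instance with isotropic Gaussian noise.
    return x + sigma * jrand.normal(prng_key, shape=x.shape)

def bernoulli_perturb(x: jnp.ndarray, prng_key, p: float = 0.5) -> jnp.ndarray:
    # Hypothetical sketch: randomly switch features on/off, as in the original
    # LIME interpretable (binary) representation.
    mask = jrand.bernoulli(prng_key, p=p, shape=x.shape)
    return x * mask
```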
X = np.random.normal(size=(1, 28 * 28))
b_perturbed = _perturb_data(X, 100, bernoulli_perturb_func, jrand.PRNGKey(42))
g_perturbed = _perturb_data(X, 100, gaussian_perturb_func, jrand.PRNGKey(42))
assert b_perturbed.shape == (101, 28 * 28)
assert g_perturbed.shape == (101, 28 * 28)
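The asserts above expect 101 rows because `_perturb_data` keeps the original instance and stacks `n` perturbed copies on top of it. A minimal sketch of that behaviour (assumed, not the library's exact code):

```python
import jax.numpy as jnp
import jax.random as jrand

def perturb_data(x, n_samples, perturb_func, prng_key):
    # Split the key so every perturbed copy uses independent randomness.
    keys = jrand.split(prng_key, n_samples)
    perturbed = jnp.concatenate([perturb_func(x, k) for k in keys], axis=0)
    # The original instance stays as the first row, hence n_samples + 1 rows.
    return jnp.concatenate([jnp.asarray(x), perturbed], axis=0)
```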
exp_kernel_func
exp_kernel_func (dists:jax.Array, kernel_width:float)
Exponential kernel function for LIME
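In the LIME paper the proximity weight is an exponential (RBF-style) kernel over the distance to the original instance, roughly exp(-d² / width²). A sketch of the assumed form (the library's exact expression may differ; the square root follows the reference LIME implementation):

```python
import jax.numpy as jnp

def exp_kernel(dists: jnp.ndarray, kernel_width: float) -> jnp.ndarray:
    # Weight perturbed samples by proximity: 1 at distance 0, decaying with distance.
    return jnp.sqrt(jnp.exp(-(dists ** 2) / kernel_width ** 2))
```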
%timeit distances = pairwise_distances(g_perturbed, X)
The slowest run took 389.12 times longer than the fastest. This could mean that an intermediate result is being cached.
273 µs ± 657 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
LimeBase
LimeBase (func:Callable, additional_func_args:Dict=None, model_regressor=None, kernal_func:Callable=None, kernel_width:float=None, perturb_func:Callable=None, input_paramter_name:str='x', pairwise_distances_metric:str='euclidean')
Initialize self. See help(type(self)) for accurate signature.
|  | Type | Default | Details |
|---|---|---|---|
| func | typing.Callable |  | A black-box function to be explained |
| additional_func_args | typing.Dict | None | Additional arguments for the black-box function |
| model_regressor | NoneType | None | Linear regressor to use in explanation |
| kernal_func | typing.Callable | None | Kernel function for computing similarity |
| kernel_width | float | None | Kernel width for computing similarity. Defaults to (n_features * 0.75) |
| perturb_func | typing.Callable | None | Perturbation function for generating perturbed instances |
| input_paramter_name | str | x | Name of the input parameter for the black-box function |
| pairwise_distances_metric | str | euclidean | Metric for computing pairwise distances |
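Under the hood, `attribute` presumably follows the standard LIME recipe for each instance: perturb it, query the black-box function, weight the perturbed samples by proximity, and fit a weighted linear surrogate whose coefficients and intercept form the explanation. A rough sketch of that loop under those assumptions (handling of `additional_func_args` and `input_paramter_name` is omitted; `lime_explain_instance` is a hypothetical name, not the library's API):

```python
import jax.numpy as jnp
import jax.random as jrand
from sklearn.linear_model import Ridge

def lime_explain_instance(x, func, n_samples, prng_key, perturb_func, kernel_func,
                          metric="euclidean"):
    # 1. Sample perturbed neighbours of x (the original instance stays as row 0).
    keys = jrand.split(prng_key, n_samples)
    samples = jnp.concatenate([jnp.asarray(x)] + [perturb_func(x, k) for k in keys], axis=0)
    # 2. Query the black-box model on the perturbed samples.
    preds = jnp.ravel(func(samples))
    # 3. Weight each sample by its proximity to x.
    dists = jnp.ravel(pairwise_distances(samples, jnp.asarray(x), metric=metric))
    weights = kernel_func(dists)
    # 4. Fit a weighted linear surrogate; its coefficients are the attributions.
    surrogate = Ridge(alpha=1.0).fit(samples, preds, sample_weight=weights)
    return surrogate.coef_, surrogate.intercept_
```

The tuples returned in the tests below (a coefficient array plus an intercept per instance) match this picture.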
Tests
from sklearn.datasets import make_regression
import haiku as hk
xs, ys = make_regression(n_samples=500, n_features=20)

linear_model = LinearModel()
linear_model.fit(xs, ys)
lime = LimeBase(linear_model.predict, input_paramter_name="X")
# fit a simple haiku model
def model(x):
    mlp = hk.Sequential([
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(10),
        jax.nn.relu,
        hk.Linear(1),
    ])
    return mlp(x)

def init(x):
    net = hk.without_apply_rng(hk.transform(model))
    opt = optax.sgd(1e-1)
    params = net.init(jrand.PRNGKey(42), x)
    opt_state = opt.init(params)
    return net, opt, params, opt_state

def loss(params, net, x, y):
    pred = net.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@partial(jax.jit, static_argnums=(2, 3))
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    x: jnp.ndarray,
    y: jnp.ndarray,
):
    grads = jax.grad(loss)(params, net, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

def train(
    net: hk.Transformed,
    opt: optax.GradientTransformation,
    params: hk.Params,
    opt_state: optax.OptState,
    x: jnp.ndarray,
    y: jnp.ndarray,
    n_epochs: int = 100,
    batch_size: int = 32,
):
    n_samples = x.shape[0]
    for _ in range(n_epochs):
        for i in range(0, n_samples, batch_size):
            x_batch = x[i:i+batch_size]
            y_batch = y[i:i+batch_size]
            params, opt_state = update(params, opt_state, net, opt, x_batch, y_batch)
    return params

def fit_a_model(
    X: Array,
    y: Array,
):
    net, opt, params, opt_state = init(X)
    params = train(net, opt, params, opt_state, X, y)
    return net, params

net, params = fit_a_model(xs, ys)
/home/birk/mambaforge-pypy3/envs/nbdev2/lib/python3.7/site-packages/haiku/_src/base.py:515: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
param = init(shape, dtype)
net.apply(params, xs[:10])
DeviceArray([[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635],
[0.24004635]], dtype=float32)
kernel_func = partial(exp_kernel_func, kernel_width=2 * 0.75)

_lime_attribute_single_instance(
    X[:1],
    1000,
    jrand.PRNGKey(42),
    net.apply,
    additional_func_args={"params": params},
    input_paramter_name="x",
    perturb_func=gaussian_perturb_func,
    kernel_func=kernel_func,
    model_regressor=Ridge(alpha=1),
    pairwise_distances_metric="euclidean",
)
(DeviceArray([-5.97378996e-04, 4.81305433e-05, 1.55211031e-03,
-3.89112538e-04, -1.21736361e-04, -1.05082065e-04,
1.35615395e-04, -4.70238097e-04, 5.06596552e-05,
5.96614438e-04, 8.60058644e-04, 1.57593074e-03,
-1.35515491e-03, 7.79002556e-04, 8.51082150e-04,
1.22845857e-04, 6.43223408e-04, 6.08801842e-04,
-3.02861667e-06, 1.18552020e-03], dtype=float32),
DeviceArray([0.24033982], dtype=float32))
lime = LimeBase(
    func=net.apply,
    additional_func_args={"params": params},
)
lime.attribute(X)
(DeviceArray([[ 0.00221086, 0.00581249, 0.00237516, ..., 0.0043735 ,
0.00298327, 0.00597063],
[ 0.00241475, 0.00598198, 0.00254859, ..., 0.00455689,
0.00296534, 0.00607178],
[-0.00225598, -0.00591312, -0.00256158, ..., -0.00448116,
-0.00301292, -0.00606422],
...,
[ 0.00223549, 0.00583388, 0.00241553, ..., 0.00438501,
0.00298516, 0.00600691],
[ 0.00221306, 0.00573398, 0.00225197, ..., 0.00415604,
0.00285654, 0.00581799],
[ 0.00242912, 0.00596596, 0.00257836, ..., 0.00450296,
0.0030849 , 0.00620169]], dtype=float32),
DeviceArray([[0.24730496],
[0.24756414],
[0.23204625],
[0.24731565],
[0.24734785],
[0.24736024],
[0.24726894],
[0.24740964],
[0.24748607],
[0.2472707 ],
[0.24726589],
[0.23272656],
[0.2474309 ],
[0.24768972],
[0.24720442],
[0.23258576],
[0.24747995],
[0.23265731],
[0.23216027],
[0.23245938],
[0.23232539],
[0.23272593],
[0.24734417],
[0.232578 ],
[0.2313688 ],
[0.232699 ],
[0.23166694],
[0.2476369 ],
[0.24740493],
[0.23260298],
[0.2476171 ],
[0.23259275],
[0.2327499 ],
[0.24747556],
[0.23274088],
[0.23278277],
[0.23274417],
[0.23259158],
[0.23276031],
[0.23244202],
[0.23248667],
[0.24736097],
[0.2475133 ],
[0.23159517],
[0.23230171],
[0.24782038],
[0.23145361],
[0.23132221],
[0.23273085],
[0.24778697],
[0.2318807 ],
[0.23234296],
[0.2473141 ],
[0.23186037],
[0.24774271],
[0.24740867],
[0.2321608 ],
[0.24721904],
[0.24765676],
[0.23248613],
[0.2325826 ],
[0.23253548],
[0.24741523],
[0.23102273],
[0.2324856 ],
[0.24752925],
[0.24744649],
[0.24753611],
[0.24741869],
[0.24742423],
[0.23231731],
[0.24737181],
[0.23265643],
[0.24769214],
[0.23244457],
[0.23207039],
[0.23221391],
[0.24714589],
[0.24750586],
[0.24760035],
[0.23261759],
[0.23214899],
[0.23229751],
[0.2317236 ],
[0.24728425],
[0.24737929],
[0.24734789],
[0.24735087],
[0.24734165],
[0.23252094],
[0.24735183],
[0.23268992],
[0.2472564 ],
[0.23250158],
[0.24733323],
[0.23173681],
[0.24738492],
[0.24749885],
[0.23246227],
[0.2326699 ],
[0.24735057],
[0.24739408],
[0.23209122],
[0.24731648],
[0.23272413],
[0.232523 ],
[0.24749972],
[0.24729979],
[0.23263235],
[0.23183963],
[0.23184806],
[0.23263596],
[0.24731863],
[0.23263787],
[0.24730252],
[0.2472181 ],
[0.24739836],
[0.24762246],
[0.24867377],
[0.24735266],
[0.23101482],
[0.24794899],
[0.23271923],
[0.23227747],
[0.23191595],
[0.24739477],
[0.23250002],
[0.24734001],
[0.24730746],
[0.24746673],
[0.24722606],
[0.24733111],
[0.23255974],
[0.24728492],
[0.23227465],
[0.24736558],
[0.24749716],
[0.23218371],
[0.23240303],
[0.24745671],
[0.23237772],
[0.24727368],
[0.23259522],
[0.23260601],
[0.23210865],
[0.23270965],
[0.23201807],
[0.2472552 ],
[0.2324636 ],
[0.24743754],
[0.23217565],
[0.24725908],
[0.23177943],
[0.23214754],
[0.23217827],
[0.24719004],
[0.23194984],
[0.24750523],
[0.23241459],
[0.23250034],
[0.2324958 ],
[0.24732801],
[0.24729869],
[0.23244049],
[0.24737254],
[0.23204154],
[0.23235755],
[0.24782768],
[0.24725914],
[0.24750242],
[0.23217706],
[0.23235546],
[0.24741676],
[0.23215216],
[0.24742627],
[0.24739274],
[0.23202284],
[0.23235677],
[0.23258376],
[0.23233262],
[0.2325433 ],
[0.2473211 ],
[0.23198971],
[0.24748974],
[0.24733643],
[0.23230669],
[0.24724506],
[0.23218584],
[0.24735439],
[0.24746102],
[0.24732526],
[0.2325736 ],
[0.23248109],
[0.24725024],
[0.24732903],
[0.24728669],
[0.24792492],
[0.24741988],
[0.23169258],
[0.23255627],
[0.24753469],
[0.24741082],
[0.23244117],
[0.23260221],
[0.2316141 ],
[0.24735719],
[0.24743626],
[0.24751961],
[0.23250861],
[0.23179816],
[0.24762237],
[0.24788071],
[0.2325231 ],
[0.2324555 ],
[0.2325428 ],
[0.23237003],
[0.2323332 ],
[0.24787727],
[0.23265064],
[0.24732669],
[0.24749607],
[0.24766587],
[0.2480405 ],
[0.23218127],
[0.24730903],
[0.23179471],
[0.2473682 ],
[0.24743044],
[0.247412 ],
[0.23198949],
[0.2474296 ],
[0.23019618],
[0.24736585],
[0.24736209],
[0.24728309],
[0.24746317],
[0.23231001],
[0.23207894],
[0.24735352],
[0.2322949 ],
[0.24727553],
[0.24729888],
[0.23229258],
[0.23158833],
[0.24727146],
[0.23261805],
[0.24736366],
[0.23235823],
[0.24729802],
[0.24752894],
[0.24734056],
[0.23284838],
[0.24715085],
[0.2323496 ],
[0.24730666],
[0.2320316 ],
[0.23239292],
[0.24736963],
[0.24747422],
[0.24727249],
[0.23237279],
[0.24724723],
[0.24729113],
[0.24737135],
[0.24733087],
[0.24724013],
[0.24737415],
[0.2327153 ],
[0.24791467],
[0.24732663],
[0.24748352],
[0.23251979],
[0.24714312],
[0.24755022],
[0.2473973 ],
[0.23273005],
[0.23246035],
[0.24732141],
[0.23213631],
[0.24721566],
[0.23201378],
[0.23221391],
[0.24740039],
[0.23237988],
[0.23115295],
[0.23251712],
[0.23238027],
[0.23267289],
[0.24774548],
[0.23251162],
[0.23228587],
[0.23244764],
[0.24746987],
[0.23209953],
[0.24765822],
[0.23273446],
[0.23225085],
[0.24744767],
[0.23257345],
[0.24717891],
[0.23226202],
[0.23245876],
[0.24770094],
[0.23234405],
[0.24729179],
[0.23253849],
[0.23176861],
[0.23238832],
[0.24772736],
[0.23276523],
[0.24727891],
[0.24722643],
[0.23272316],
[0.24743803],
[0.2473214 ],
[0.2324927 ],
[0.24720925],
[0.2319577 ],
[0.24736662],
[0.24720564],
[0.2473594 ],
[0.24746218],
[0.2473249 ],
[0.23269431],
[0.23267029],
[0.24746528],
[0.24735796],
[0.23275968],
[0.24729164],
[0.24727115],
[0.24726492],
[0.23244949],
[0.23239763],
[0.23260587],
[0.2317266 ],
[0.24721311],
[0.24791732],
[0.23277605],
[0.24746867],
[0.2326223 ],
[0.24716686],
[0.23254141],
[0.24743882],
[0.24725032],
[0.23271337],
[0.23271807],
[0.24727352],
[0.23150992],
[0.24748631],
[0.24734168],
[0.24780883],
[0.24741678],
[0.24733034],
[0.23234318],
[0.24731077],
[0.2472419 ],
[0.2473333 ],
[0.23262657],
[0.24725671],
[0.24735816],
[0.23228347],
[0.2473127 ],
[0.24725196],
[0.23249117],
[0.23216985],
[0.24734968],
[0.23251817],
[0.23218517],
[0.23229903],
[0.24736002],
[0.23227802],
[0.24734715],
[0.24728952],
[0.23249188],
[0.23238972],
[0.24732132],
[0.23275092],
[0.23262392],
[0.23213783],
[0.24741286],
[0.24725187],
[0.23273003],
[0.23257877],
[0.24740161],
[0.24715275],
[0.24721095],
[0.23270902],
[0.24722742],
[0.2319201 ],
[0.23252024],
[0.23254202],
[0.23239242],
[0.23167515],
[0.24744089],
[0.23231809],
[0.23276606],
[0.23184034],
[0.23192035],
[0.23264523],
[0.24742319],
[0.23254554],
[0.23217438],
[0.24735396],
[0.23255692],
[0.23227966],
[0.232184 ],
[0.24756455],
[0.23222198],
[0.24758385],
[0.24754198],
[0.24729684],
[0.23244324],
[0.24749175],
[0.23269522],
[0.23192582],
[0.2326539 ],
[0.23266825],
[0.24736997],
[0.24730414],
[0.2324066 ],
[0.24725583],
[0.24741372],
[0.24747895],
[0.23201811],
[0.24729444],
[0.247381 ],
[0.2326003 ],
[0.23236334],
[0.24736139],
[0.2321806 ],
[0.23245579],
[0.23235963],
[0.23238862],
[0.2472786 ],
[0.23169187],
[0.24740751],
[0.23263946],
[0.23267938],
[0.24756868],
[0.23228672],
[0.24769364],
[0.24736963],
[0.23052931],
[0.23218985],
[0.23187762],
[0.23254694],
[0.23233788],
[0.2312679 ],
[0.23229599],
[0.23235682],
[0.23257495],
[0.23236491],
[0.24776962],
[0.24718513],
[0.24725026],
[0.23261815],
[0.2471972 ],
[0.24723697],
[0.2323229 ],
[0.23140587],
[0.2477279 ],
[0.23209357],
[0.24751735],
[0.24730821],
[0.24750674],
[0.23272803],
[0.24727444],
[0.24727756],
[0.24726333],
[0.24735498],
[0.23130915],
[0.24735644],
[0.2326549 ],
[0.23219867],
[0.23273338],
[0.23268846],
[0.24724345],
[0.2473797 ],
[0.24738654],
[0.2314789 ],
[0.24723507],
[0.23221129],
[0.24730465],
[0.23243599],
[0.232267 ],
[0.23234174],
[0.24730346],
[0.2472916 ],
[0.24730274],
[0.2325149 ],
[0.24734567],
[0.24727803],
[0.24726497],
[0.23281087],
[0.24736117],
[0.23239669],
[0.23156339],
[0.2475236 ],
[0.24708891],
[0.2476187 ]], dtype=float32))