# Expose 8 virtual host (CPU) devices so the pmap-based strategies can be
# exercised without real accelerators.
# NOTE(review): presumably this must run before JAX initializes its backend —
# confirm it precedes the first JAX computation / import in the notebook.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Toy data: a (100, 100) weight matrix for the linear model and 1000 input
# instances with 100 features each.
w = jrand.normal(jrand.PRNGKey(0), (100, 100))
xs = jrand.normal(jrand.PRNGKey(0), (1000, 100))
@jit
def pred_fn(x):
    """Toy linear model used by the examples.

    Returns ``jnp.dot(x, w.T)``, where ``w`` is the module-level weight matrix.
    """
    weights_t = w.T
    return jnp.dot(x, weights_t)
def f(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    """Toy per-instance counterfactual generator used to exercise the strategies.

    Args:
        x: A single input instance.
        pred_fn: Prediction function applied to ``x``.
        y_target: Accepted to match the strategy calling convention; unused.
        rng_key: PRNG key used to draw the additive noise.
        **kwargs: Ignored.

    Returns:
        ``pred_fn(x)`` plus one standard-normal draw of shape ``(1,)``, which
        broadcasts over the prediction (the same offset for every element).
    """
    return pred_fn(x) + jrand.normal(rng_key, (1,))
# One PRNG key per input instance, plus a dummy all-ones target per instance.
rng_keys = jrand.split(jrand.PRNGKey(0), 1000)
y_targets = jnp.ones((1000, 100))
# One instance of each parallelism strategy; the batched variants are
# constructed with a batch size of 128.
iter_gen = IterativeStrategy()
vmap_gen = VmapStrategy()
pmap_gen = PmapStrategy()
bvmap_gen = BatchedVmapStrategy(128)
bpmap_gen = BatchedPmapStrategy(128)
Parallelism Strategy
relax.strategy.BaseStrategy
class relax.strategy.BaseStrategy ()
Base class for mapping strategies.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
relax.strategy.IterativeStrategy
class relax.strategy.IterativeStrategy ()
Iteratively generate counterfactuals.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
relax.strategy.VmapStrategy
class relax.strategy.VmapStrategy ()
Generate counterfactuals via jax.vmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
relax.strategy.PmapStrategy
class relax.strategy.PmapStrategy (n_devices=None, strategy='auto', **kwargs)
Generate counterfactuals via jax.pmap, distributing the inputs across devices.
Parameters:
- n_devices (int, default=None) – Number of devices. If None, use all available devices.
- strategy (str, default='auto') – Strategy to generate counterfactuals
- kwargs
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
relax.strategy.BatchedVmapStrategy
class relax.strategy.BatchedVmapStrategy (batch_size)
Auto-batching for generating counterfactuals via jax.vmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
relax.strategy.BatchedPmapStrategy
class relax.strategy.BatchedPmapStrategy (batch_size, n_devices=None)
Auto-batching for generating counterfactuals via jax.pmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate counterfactuals for a single input
- xs (jax.Array) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (jax.Array)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs

Returns:
(jax.Array) – Generated counterfactual explanations
# Generate counterfactuals for every row of `xs` with the noisy generator `f`,
# once per strategy.
cf_iter = iter_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
def f_mul(x, pred_fn=None, **kwargs):
    """Toy generator that returns several counterfactuals per input.

    Args:
        x: A single input instance.
        pred_fn: Prediction function applied to ``x``.
        **kwargs: Ignored (e.g. ``y_target``, ``rng_key``).

    Returns:
        ``pred_fn(x)`` repeated 5 times along a new leading axis: a 1-D
        prediction of length ``k`` becomes an array of shape ``(5, k)``.
    """
    cf = pred_fn(x)
    return einops.repeat(cf, 'k -> c k', c=5)
# `f_mul` is deterministic, so every strategy must produce the same
# counterfactuals as the iterative reference implementation.
cf_iter = iter_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bpmap = bpmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)

assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
# Multiple counterfactuals per input are stacked as (n_instances, 5, n_features).
assert cf_bvmap.shape == (xs.shape[0], 5, xs.shape[1])
relax.strategy.StrategyFactory
class relax.strategy.StrategyFactory ()
Factory class for parallelism strategies.
Methods
get_default_strategy ()
Get default strategy.
get_strategy (strategy)
Get strategy.
# Strategies can be looked up by name, taken as the default, or passed through
# as an existing strategy instance.
it = StrategyFactory.get_strategy('iter')
vm = StrategyFactory.get_strategy('vmap')
pm = StrategyFactory.get_strategy('pmap')
default = StrategyFactory.get_default_strategy()
cus = StrategyFactory.get_strategy(VmapStrategy())

assert isinstance(it, IterativeStrategy)
assert isinstance(vm, VmapStrategy)
assert isinstance(pm, PmapStrategy)
assert isinstance(default, VmapStrategy)
assert isinstance(cus, VmapStrategy)