# Ask XLA to expose 8 host-platform devices so the pmap-based strategies
# below have multiple devices to map over.
# NOTE(review): XLA presumably reads this flag when the backend initializes;
# confirm this line runs before JAX first touches its devices.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Toy fixture: a (100, 100) weight matrix and 1000 input rows of 100 features.
# NOTE(review): `w`, `xs` and `rng_keys` are all derived from PRNGKey(0);
# that is harmless for comparing strategies against each other, but the
# draws are not independent.
w = jrand.normal(jrand.PRNGKey(0), (100, 100))
xs = jrand.normal(jrand.PRNGKey(0), (1000, 100))


@jit
def pred_fn(x):
    """Return the linear prediction ``jnp.dot(x, w.T)`` for input ``x``."""
    return jnp.dot(x, w.T)


def f(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    """Toy per-instance generator used to exercise the strategies.

    Args:
        x: A single input instance.
        pred_fn: Callable applied to ``x``.
        y_target: Accepted for interface compatibility; unused here.
        rng_key: PRNG key used to draw the additive noise term.
        **kwargs: Ignored.

    Returns:
        ``pred_fn(x)`` plus one normal sample of shape ``(1,)``, which
        broadcasts over the prediction.
    """
    return pred_fn(x) + jrand.normal(rng_key, (1,))


# One PRNG key and one target row per input instance.
rng_keys = jrand.split(jrand.PRNGKey(0), 1000)
y_targets = jnp.ones((1000, 100))

# One instance of every strategy; the batched variants use batch_size=128.
iter_gen = IterativeStrategy()
vmap_gen = VmapStrategy()
pmap_gen = PmapStrategy()
bvmap_gen = BatchedVmapStrategy(128)
bpmap_gen = BatchedPmapStrategy(128)

Parallelism Strategy
relax.strategy.BaseStrategy
class relax.strategy.BaseStrategy ()
Base class for mapping strategy.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
relax.strategy.IterativeStrategy
class relax.strategy.IterativeStrategy ()
Iteratively generate counterfactuals.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
relax.strategy.VmapStrategy
class relax.strategy.VmapStrategy ()
Generate counterfactuals via jax.vmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
relax.strategy.PmapStrategy
class relax.strategy.PmapStrategy (n_devices=None, strategy='auto', **kwargs)
Generate counterfactuals via jax.pmap.
Parameters:
- n_devices (<class 'int'>, default=None) – Number of devices. If None, use all available devices
- strategy (<class 'str'>, default=auto) – Strategy to generate counterfactuals
- kwargs (VAR_KEYWORD)
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
relax.strategy.BatchedVmapStrategy
class relax.strategy.BatchedVmapStrategy (batch_size)
Auto-batching for generating counterfactuals via jax.vmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
relax.strategy.BatchedPmapStrategy
class relax.strategy.BatchedPmapStrategy (batch_size, n_devices=None)
Auto-batching for generating counterfactuals via jax.pmap.
Methods
call (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)
Call self as a function.
Parameters:
- fn (typing.Callable) – Function to generate cf for a single input
- xs (<class 'jax.Array'>) – Input instances to be explained
- pred_fn (typing.Callable[[jax.Array], jax.Array])
- y_targets (<class 'jax.Array'>)
- rng_keys (typing.Iterable[PRNGKey])
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Generated counterfactual explanations
# Run the single-output generator `f` through four of the strategies.
cf_iter = iter_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)


def f_mul(x, pred_fn=None, **kwargs):
    """Toy generator that returns several outputs per input.

    Args:
        x: A single input instance.
        pred_fn: Callable applied to ``x``.
        **kwargs: Ignored (absorbs ``y_target``, ``rng_key`` and the like).

    Returns:
        ``pred_fn(x)`` repeated 5 times along a new leading axis
        (einops pattern ``'k -> c k'`` with ``c=5``).
    """
    cf = pred_fn(x)
    return einops.repeat(cf, 'k -> c k', c=5)


# Run the multi-output generator through every strategy; the results above
# are overwritten on purpose.
cf_iter = iter_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bpmap = bpmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)

# Every strategy must agree with the iterative reference result.
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
assert cf_bvmap.shape == (xs.shape[0], 5, xs.shape[1])

relax.strategy.StrategyFactory
class relax.strategy.StrategyFactory ()
Factory class for Parallelism Strategy.
Methods
get_default_strategy ()
Get default strategy.
get_strategy (strategy)
Get strategy.
# String aliases resolve to the matching strategy class.
it = StrategyFactory.get_strategy('iter')
assert isinstance(it, IterativeStrategy)

vm = StrategyFactory.get_strategy('vmap')
assert isinstance(vm, VmapStrategy)

pm = StrategyFactory.get_strategy('pmap')
assert isinstance(pm, PmapStrategy)

# The default strategy is a VmapStrategy.
default = StrategyFactory.get_default_strategy()
assert isinstance(default, VmapStrategy)

# Passing a strategy instance instead of an alias yields a VmapStrategy too.
cus = StrategyFactory.get_strategy(VmapStrategy())
assert isinstance(cus, VmapStrategy)