Parallelism Strategy

relax.strategy.BaseStrategy

class relax.strategy.BaseStrategy ()

Base class for mapping strategy.

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations

relax.strategy.IterativeStrategy

class relax.strategy.IterativeStrategy ()

Iteratively generate counterfactuals.

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations
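
As a rough illustration of what iterative generation amounts to, the sketch below calls a per-input function once per row in a plain Python loop and stacks the results. It is not the library's implementation: `iterative_map` and `toy_fn` are hypothetical names, and the keyword names `y_target` and `rng_key` are assumed from the example functions used elsewhere on this page.

```python
import jax.numpy as jnp
import jax.random as jrand


def iterative_map(fn, xs, pred_fn, y_targets, rng_keys, **kwargs):
    """Hypothetical helper: call `fn` once per input and stack the outputs."""
    cfs = [
        fn(x, pred_fn=pred_fn, y_target=y, rng_key=key, **kwargs)
        for x, y, key in zip(xs, y_targets, rng_keys)
    ]
    return jnp.stack(cfs)


def toy_fn(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    # Toy stand-in for a counterfactual generator: nudge x toward the target.
    return x + 0.1 * (y_target - pred_fn(x))


xs = jnp.arange(12.0).reshape(4, 3)
y_targets = jnp.ones((4, 3))
rng_keys = jrand.split(jrand.PRNGKey(0), 4)  # one key per input
cfs = iterative_map(toy_fn, xs, lambda x: x, y_targets, rng_keys)
```

A loop like this is the simplest reference behaviour to compare the vectorized strategies against, at the cost of one dispatch per input.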

relax.strategy.VmapStrategy

class relax.strategy.VmapStrategy ()

Generate counterfactuals via jax.vmap.

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations
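
The idea of mapping a single-input function over a batch with `jax.vmap` can be sketched as follows. This is a minimal sketch rather than the library's code: `vmap_map` and `noisy_fn` are hypothetical names, and it assumes `rng_keys` is a stacked array of keys with one row per input.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def vmap_map(fn, xs, pred_fn, y_targets, rng_keys, **kwargs):
    """Hypothetical helper: vectorize `fn` over the leading axis of each input."""

    def single(x, y_target, rng_key):
        return fn(x, pred_fn=pred_fn, y_target=y_target, rng_key=rng_key, **kwargs)

    # in_axes defaults to 0, so xs, y_targets and rng_keys are mapped row by row.
    return jax.vmap(single)(xs, y_targets, rng_keys)


def noisy_fn(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    # Toy stand-in for a counterfactual generator: prediction plus small noise.
    return pred_fn(x) + 0.01 * jrand.normal(rng_key, x.shape)


xs = jnp.arange(12.0).reshape(4, 3)
y_targets = jnp.ones((4, 3))
rng_keys = jrand.split(jrand.PRNGKey(0), 4)  # stacked keys, one per row
cfs = vmap_map(noisy_fn, xs, lambda x: 2.0 * x, y_targets, rng_keys)
```

Closing over `pred_fn` and `kwargs` keeps them out of the mapped arguments, so only the per-input arrays are batched.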

relax.strategy.PmapStrategy

class relax.strategy.PmapStrategy (n_devices=None, strategy='auto', **kwargs)

Generate counterfactuals via jax.pmap.

Parameters:

  • n_devices (int, default=None) – Number of devices. If None, use all available devices
  • strategy (str, default='auto') – Strategy to generate counterfactuals
  • kwargs

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations
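
One common way to spread a batch across devices is to reshape it to `(n_devices, per_device, ...)` and run `jax.pmap` over the outer axis with `jax.vmap` inside. The sketch below shows that pattern under stated assumptions; whether the library does exactly this is not confirmed by this page. `pmap_map` and `toy_fn` are hypothetical names, and the sketch requires the batch size to be divisible by the device count.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def pmap_map(fn, xs, pred_fn, y_targets, rng_keys, n_devices=None, **kwargs):
    """Hypothetical helper: shard the batch across devices, vmap within each shard."""
    n_devices = n_devices or jax.local_device_count()
    n = xs.shape[0]
    if n % n_devices != 0:
        raise ValueError(f"batch size {n} is not divisible by {n_devices} devices")

    def single(x, y_target, rng_key):
        return fn(x, pred_fn=pred_fn, y_target=y_target, rng_key=rng_key, **kwargs)

    def shard(a):
        # (n, ...) -> (n_devices, n // n_devices, ...)
        return a.reshape(n_devices, n // n_devices, *a.shape[1:])

    out = jax.pmap(jax.vmap(single))(shard(xs), shard(y_targets), shard(rng_keys))
    # Merge the device axis back into the batch axis.
    return out.reshape(n, *out.shape[2:])


def toy_fn(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    # Toy stand-in for a counterfactual generator: nudge x toward the target.
    return x + 0.1 * (y_target - pred_fn(x))


n = 2 * jax.local_device_count()  # keeps the batch divisible on any machine
xs = jnp.arange(n * 3.0).reshape(n, 3)
y_targets = jnp.ones((n, 3))
rng_keys = jrand.split(jrand.PRNGKey(0), n)
cfs = pmap_map(toy_fn, xs, lambda x: x, y_targets, rng_keys)
```

On a machine with a single device this reduces to a plain `vmap`, which makes it easy to check against the other strategies.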

relax.strategy.BatchedVmapStrategy

class relax.strategy.BatchedVmapStrategy (batch_size)

Auto-batching for generating counterfactuals via jax.vmap.

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations
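
Batched vectorization can be sketched by padding the inputs to a multiple of `batch_size`, splitting them into fixed-size chunks, applying `jax.vmap` to each chunk in sequence with `jax.lax.map`, and trimming the padding afterwards. This is an illustrative sketch under those assumptions, not the library's implementation; `batched_vmap_map` and `toy_fn` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand


def batched_vmap_map(fn, xs, pred_fn, y_targets, rng_keys, batch_size, **kwargs):
    """Hypothetical helper: vmap over fixed-size chunks, processed one at a time."""
    n = xs.shape[0]
    pad = (-n) % batch_size  # rows needed to reach a multiple of batch_size

    def single(x, y_target, rng_key):
        return fn(x, pred_fn=pred_fn, y_target=y_target, rng_key=rng_key, **kwargs)

    def to_batches(a):
        # Pad by repeating the last row, then reshape to (n_batches, batch_size, ...).
        if pad:
            a = jnp.concatenate([a, jnp.repeat(a[-1:], pad, axis=0)], axis=0)
        return a.reshape(-1, batch_size, *a.shape[1:])

    batches = (to_batches(xs), to_batches(y_targets), to_batches(rng_keys))
    # lax.map runs the chunks sequentially, so only one chunk is live at a time.
    out = jax.lax.map(lambda args: jax.vmap(single)(*args), batches)
    # Flatten the chunk axis and drop the padded rows.
    return out.reshape(-1, *out.shape[2:])[:n]


def toy_fn(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    # Toy stand-in for a counterfactual generator: nudge x toward the target.
    return x + 0.1 * (y_target - pred_fn(x))


xs = jnp.arange(30.0).reshape(10, 3)
y_targets = jnp.ones((10, 3))
rng_keys = jrand.split(jrand.PRNGKey(0), 10)
cfs = batched_vmap_map(toy_fn, xs, lambda x: x, y_targets, rng_keys, batch_size=4)
```

Chunking trades some throughput for a bounded peak memory footprint, which matters when a full-batch `vmap` would not fit on the device.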

relax.strategy.BatchedPmapStrategy

class relax.strategy.BatchedPmapStrategy (batch_size, n_devices=None)

Auto-batching for generating counterfactuals via jax.pmap.

Methods

__call__ (fn, xs, pred_fn, y_targets, rng_keys, **kwargs)

Call self as a function.

Parameters:

  • fn (typing.Callable) – Function that generates a counterfactual for a single input
  • xs (jax.Array) – Input instances to be explained
  • pred_fn (typing.Callable[[jax.Array], jax.Array])
  • y_targets (jax.Array)
  • rng_keys (typing.Iterable[PRNGKey])
  • kwargs

Returns:

    (jax.Array) – Generated counterfactual explanations

import os

# Set before JAX initializes its backends so the host CPU is exposed as 8 devices.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import einops
import jax.numpy as jnp
import jax.random as jrand
from jax import jit
from relax.strategy import (
    IterativeStrategy, VmapStrategy, PmapStrategy,
    BatchedVmapStrategy, BatchedPmapStrategy,
)

w = jrand.normal(jrand.PRNGKey(0), (100, 100))
xs = jrand.normal(jrand.PRNGKey(0), (1000, 100))

@jit
def pred_fn(x): return jnp.dot(x, w.T)

# Per-input function that returns a single counterfactual.
def f(x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
    return pred_fn(x) + jrand.normal(rng_key, (1,))

rng_keys = jrand.split(jrand.PRNGKey(0), 1000)
y_targets = jnp.ones((1000, 100))

iter_gen = IterativeStrategy()
vmap_gen = VmapStrategy()
pmap_gen = PmapStrategy()
bvmap_gen = BatchedVmapStrategy(128)
bpmap_gen = BatchedPmapStrategy(128)

# Each strategy is called with the same arguments.
cf_iter = iter_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)

# Per-input function that returns 5 counterfactuals, shape (5, k).
def f_mul(x, pred_fn=None, **kwargs):
    cf = pred_fn(x)
    return einops.repeat(cf, 'k -> c k', c=5)

cf_iter = iter_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_vmap = vmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_pmap = pmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bvmap = bvmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)
cf_bpmap = bpmap_gen(f_mul, xs, pred_fn=pred_fn, y_targets=y_targets, rng_keys=rng_keys)

# All strategies should agree with the iterative reference.
assert jnp.allclose(cf_iter, cf_vmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bvmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_pmap, atol=1e-4)
assert jnp.allclose(cf_iter, cf_bpmap, atol=1e-4)
assert cf_bvmap.shape == (xs.shape[0], 5, xs.shape[1])

relax.strategy.StrategyFactory

class relax.strategy.StrategyFactory ()

Factory class for Parallelism Strategy.

Methods

get_default_strategy ()

Get default strategy.

get_strategy (strategy)

Get strategy.

from relax.strategy import (
    StrategyFactory, IterativeStrategy, VmapStrategy, PmapStrategy,
)

# Look up strategies by name.
it = StrategyFactory.get_strategy('iter')
vm = StrategyFactory.get_strategy('vmap')
pm = StrategyFactory.get_strategy('pmap')
default = StrategyFactory.get_default_strategy()
# A strategy instance can be passed in place of a name.
cus = StrategyFactory.get_strategy(VmapStrategy())

assert isinstance(it, IterativeStrategy)
assert isinstance(vm, VmapStrategy)
assert isinstance(pm, PmapStrategy)
assert isinstance(default, VmapStrategy)
assert isinstance(cus, VmapStrategy)
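
The lookup behaviour exercised above (names map to strategy classes, instances pass through, and the default is the vmap strategy) can be sketched with a small registry. This is only an illustration of the pattern and makes no claim about how `StrategyFactory` is written; the stub classes and `resolve_strategy` are hypothetical names used so the sketch runs without the library installed.

```python
class IterativeStub:
    """Hypothetical stand-in for an iterative strategy."""


class VmapStub:
    """Hypothetical stand-in for a vmap strategy."""


class PmapStub:
    """Hypothetical stand-in for a pmap strategy."""


# Names mirror the ones used in the example above.
_NAME_TO_STRATEGY = {"iter": IterativeStub, "vmap": VmapStub, "pmap": PmapStub}


def resolve_strategy(strategy=None):
    """Return a strategy instance from a name, an instance, or None (default)."""
    if strategy is None:
        return VmapStub()
    if isinstance(strategy, str):
        try:
            return _NAME_TO_STRATEGY[strategy]()
        except KeyError:
            names = sorted(_NAME_TO_STRATEGY)
            raise ValueError(f"Unknown strategy {strategy!r}; expected one of {names}") from None
    # Anything else is assumed to already be a strategy instance.
    return strategy


custom = PmapStub()
resolved = [resolve_strategy(s) for s in ("iter", "vmap", "pmap", None, custom)]
```

Accepting either a name or an instance lets callers use short strings for the common cases while still allowing a configured strategy object to be supplied directly.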