Base API

relax.methods.base.CFModule

class relax.methods.base.CFModule (config, name=None, apply_constraints_fn=None, compute_reg_loss_fn=None, **kwargs)

Base class for all counterfactual modules.

Methods

set_apply_constraints_fn (apply_constraints_fn)

set_compute_reg_loss_fn (compute_reg_loss_fn)

apply_constraints (*args, **kwargs)

compute_reg_loss (*args, **kwargs)
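The four hook methods above point to a pluggable design. A constraint function and a regularization-loss function can be passed to the constructor or swapped in later with the setters, and apply_constraints and compute_reg_loss then forward their arguments to them. The sketch below illustrates that pattern with a hypothetical ToyCFModule; it is not relax code. The (x, cf) argument order, the pass-through default when no constraint function is set, the zero-loss default, and the example hooks clip_to_unit_box and l2_distance are all assumptions.

```python
from typing import Callable, Optional

import jax.numpy as jnp


class ToyCFModule:
    """Hypothetical stand-in mirroring CFModule's hook methods (not relax code)."""

    def __init__(self, config: dict, name: Optional[str] = None,
                 apply_constraints_fn: Optional[Callable] = None,
                 compute_reg_loss_fn: Optional[Callable] = None):
        self.config = config
        self.name = name or type(self).__name__
        self._apply_constraints_fn = apply_constraints_fn
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def set_apply_constraints_fn(self, apply_constraints_fn: Callable) -> None:
        self._apply_constraints_fn = apply_constraints_fn

    def set_compute_reg_loss_fn(self, compute_reg_loss_fn: Callable) -> None:
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def apply_constraints(self, x, cf, *args, **kwargs):
        # Assumed default: with no constraint function, return the candidate unchanged.
        if self._apply_constraints_fn is None:
            return cf
        return self._apply_constraints_fn(x, cf, *args, **kwargs)

    def compute_reg_loss(self, x, cf, *args, **kwargs):
        # Assumed default: with no regularizer, contribute zero loss.
        if self._compute_reg_loss_fn is None:
            return 0.0
        return self._compute_reg_loss_fn(x, cf, *args, **kwargs)


# Hypothetical hooks: keep features in [0, 1] and penalize squared distance to x.
def clip_to_unit_box(x, cf):
    return jnp.clip(cf, 0.0, 1.0)


def l2_distance(x, cf):
    return jnp.sum((cf - x) ** 2)


module = ToyCFModule(config={})
module.set_apply_constraints_fn(clip_to_unit_box)
module.set_compute_reg_loss_fn(l2_distance)

x = jnp.array([0.2, 0.8])
cf = jnp.array([-0.5, 1.3])
constrained = module.apply_constraints(x, cf)   # clipped to [0.0, 1.0]
reg = module.compute_reg_loss(x, constrained)   # 0.2**2 + 0.2**2 = 0.08
```

Keeping the hooks as swappable callables lets the same generation logic be reused with different feasibility rules, for example constraints derived from a particular dataset.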

save (path)

load_from_path (path)
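A minimal sketch of a save/load round trip, assuming path names a directory and that only the configuration and name need persisting. ToySavableModule, the config.json file name, and treating load_from_path as a classmethod that rebuilds an instance are illustrative assumptions; the on-disk format actually used by relax is not described here.

```python
import json
import os
import tempfile
from typing import Optional


class ToySavableModule:
    """Hypothetical module that persists only its config (not the relax implementation)."""

    def __init__(self, config: dict, name: Optional[str] = None):
        self.config = config
        self.name = name or type(self).__name__

    def save(self, path: str) -> None:
        # Assumption: `path` is a directory; the state is written to config.json inside it.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump({"name": self.name, "config": self.config}, f)

    @classmethod
    def load_from_path(cls, path: str) -> "ToySavableModule":
        # Assumption: rebuilds a new instance from whatever save() wrote.
        with open(os.path.join(path, "config.json")) as f:
            state = json.load(f)
        return cls(state["config"], name=state["name"])


with tempfile.TemporaryDirectory() as tmp:
    module = ToySavableModule({"n_steps": 100, "lr": 0.1}, name="toy")
    module.save(tmp)
    restored = ToySavableModule.load_from_path(tmp)
```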

before_generate_cf (*args, **kwargs)

generate_cf (x, pred_fn=None, y_target=None, rng_key=None, **kwargs)

Parameters:

  • x (jax.Array) – Input instance(s) to explain.
  • pred_fn (typing.Callable, default=None) – Prediction function of the model to explain.
  • y_target (jax.Array, default=None) – Target prediction for the counterfactual.
  • rng_key (jax.random.PRNGKey, default=None) – JAX random key.
  • kwargs – Additional keyword arguments.

Returns:

    (jax.Array) – The counterfactual of x.
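One common way to implement a method with this signature is gradient-based search in the spirit of Wachter et al.: start from x and descend on a loss that pushes pred_fn(cf) toward y_target while penalizing distance from x. The sketch below assumes pred_fn returns positive-class probabilities. It uses binary cross-entropy plus a squared L2 penalty, and uses rng_key only to jitter the starting point. The argument names and order follow the documented signature, but the hyperparameters n_steps, lr, and lam and the toy logistic model are hypothetical, and this is not necessarily what any relax method does.

```python
import jax
import jax.numpy as jnp


def generate_cf(x, pred_fn, y_target, rng_key, n_steps=200, lr=0.1, lam=0.1):
    """Gradient-based counterfactual search (a sketch, not a relax method).

    Minimizes mean BCE(pred_fn(cf), y_target) + lam * mean ||cf - x||^2
    with plain gradient descent.
    """
    def loss_fn(cf):
        p = jnp.clip(pred_fn(cf), 1e-6, 1 - 1e-6)  # avoid log(0)
        bce = -(y_target * jnp.log(p) + (1 - y_target) * jnp.log(1 - p))
        dist = jnp.sum((cf - x) ** 2, axis=-1)
        return jnp.mean(bce) + lam * jnp.mean(dist)

    grad_fn = jax.grad(loss_fn)
    # rng_key is used only to jitter the starting point slightly.
    cf0 = x + 0.01 * jax.random.normal(rng_key, x.shape)

    def step(cf, _):
        return cf - lr * grad_fn(cf), None

    cf, _ = jax.lax.scan(step, cf0, None, length=n_steps)
    return cf


# Hypothetical logistic model; its prediction at the origin is sigmoid(-1), about 0.27.
w, b = jnp.array([2.0, -1.0]), -1.0


def pred_fn(inputs):
    return jax.nn.sigmoid(inputs @ w + b)


x = jnp.array([[0.0, 0.0]])
cf = generate_cf(x, pred_fn, y_target=jnp.array([1.0]), rng_key=jax.random.PRNGKey(0))
```

With this toy model the prediction moves from about 0.27 to above 0.5. Raising lam keeps cf closer to x at the cost of a weaker push toward y_target.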

relax.methods.base.ParametricCFModule

class relax.methods.base.ParametricCFModule (config, name=None, apply_constraints_fn=None, compute_reg_loss_fn=None, **kwargs)

Base class for parametric counterfactual modules.
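Assuming "parametric" here means the module owns trainable parameters that are fitted before generation, so that generating a counterfactual becomes a single forward pass rather than a per-instance optimization, the idea can be sketched as below. ToyParametricCF, its linear generator cf = x + x @ W + c, the train method, and its hyperparameters are hypothetical. The training loss is binary cross-entropy toward y_target plus a squared L2 penalty on the change.

```python
import jax
import jax.numpy as jnp


class ToyParametricCF:
    """Hypothetical parametric CF module learning cf = x + x @ W + c (not relax code)."""

    def __init__(self, n_features, lam=0.1, lr=0.1, n_steps=300):
        self.params = {"W": jnp.zeros((n_features, n_features)), "c": jnp.zeros(n_features)}
        self.lam, self.lr, self.n_steps = lam, lr, n_steps

    @staticmethod
    def _forward(params, x):
        return x + x @ params["W"] + params["c"]

    def train(self, xs, pred_fn, y_target):
        # Fit the generator once on a batch; later calls need no per-instance optimization.
        def loss_fn(params):
            cf = self._forward(params, xs)
            p = jnp.clip(pred_fn(cf), 1e-6, 1 - 1e-6)  # avoid log(0)
            bce = -(y_target * jnp.log(p) + (1 - y_target) * jnp.log(1 - p))
            return jnp.mean(bce) + self.lam * jnp.mean(jnp.sum((cf - xs) ** 2, axis=-1))

        grad_fn = jax.jit(jax.grad(loss_fn))
        for _ in range(self.n_steps):
            grads = grad_fn(self.params)
            self.params = jax.tree_util.tree_map(
                lambda p, g: p - self.lr * g, self.params, grads)
        return self

    def generate_cf(self, x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
        # Amortized generation: one forward pass; the extra arguments are unused here.
        return self._forward(self.params, x)


# Hypothetical logistic model; every training point below starts with prediction < 0.5.
w, b = jnp.array([2.0, -1.0]), -1.0


def pred_fn(inputs):
    return jax.nn.sigmoid(inputs @ w + b)


xs = jnp.array([[0.0, 0.0], [0.1, 0.2], [-0.1, 0.1], [0.2, 0.3]])
module = ToyParametricCF(n_features=2).train(xs, pred_fn, y_target=jnp.ones(4))
cfs = module.generate_cf(xs)
```

The design choice is amortization: the training cost is paid once, after which counterfactuals for new inputs cost one matrix multiply.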

Methods

set_apply_constraints_fn (apply_constraints_fn)

set_compute_reg_loss_fn (compute_reg_loss_fn)

apply_constraints (*args, **kwargs)

compute_reg_loss (*args, **kwargs)

save (path)

load_from_path (path)

before_generate_cf (*args, **kwargs)

generate_cf (x, pred_fn=None, y_target=None, rng_key=None, **kwargs)

Parameters:

  • x (jax.Array) – Input instance(s) to explain.
  • pred_fn (typing.Callable, default=None) – Prediction function of the model to explain.
  • y_target (jax.Array, default=None) – Target prediction for the counterfactual.
  • rng_key (jax.random.PRNGKey, default=None) – JAX random key.
  • kwargs – Additional keyword arguments.

Returns:

    (jax.Array) – The counterfactual of x.