Base API
relax.methods.base.CFModule
class relax.methods.base.CFModule (config, name=None, apply_constraints_fn=None, compute_reg_loss_fn=None, **kwargs)
Base class for all counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (x, pred_fn=None, y_target=None, rng_key=None, **kwargs)
Parameters:
- x (jax.Array)
- pred_fn (typing.Callable, default=None)
- y_target (jax.Array, default=None)
- rng_key (jax.random.PRNGKey, default=None)
- kwargs
Returns:
(jax.Array) – The counterfactual of x.
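The method list above suggests that set_apply_constraints_fn and set_compute_reg_loss_fn store callables which apply_constraints and compute_reg_loss later invoke. The sketch below illustrates that reading with a dependency-free stand-in. It is not relax code: SketchCFModule, ShiftCF and clip_unit are hypothetical names, the fallback behaviour when no callable is set is an assumption, and plain Python lists stand in for jax.Array.

```python
class SketchCFModule:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self, config, name=None,
                 apply_constraints_fn=None, compute_reg_loss_fn=None):
        self.config = config
        self.name = name or type(self).__name__
        self._apply_constraints_fn = apply_constraints_fn
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def set_apply_constraints_fn(self, apply_constraints_fn):
        self._apply_constraints_fn = apply_constraints_fn

    def set_compute_reg_loss_fn(self, compute_reg_loss_fn):
        self._compute_reg_loss_fn = compute_reg_loss_fn

    def apply_constraints(self, *args, **kwargs):
        # Assumption: with no callable set, return the first argument unchanged.
        if self._apply_constraints_fn is None:
            return args[0]
        return self._apply_constraints_fn(*args, **kwargs)

    def compute_reg_loss(self, *args, **kwargs):
        # Assumption: with no callable set, the regularization loss is zero.
        if self._compute_reg_loss_fn is None:
            return 0.0
        return self._compute_reg_loss_fn(*args, **kwargs)

    def generate_cf(self, x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
        raise NotImplementedError


class ShiftCF(SketchCFModule):
    """Toy subclass: shift every feature by config['step'], then constrain."""

    def generate_cf(self, x, pred_fn=None, y_target=None, rng_key=None, **kwargs):
        cf = [v + self.config["step"] for v in x]
        return self.apply_constraints(cf)


def clip_unit(cf):
    # Hypothetical constraint: keep every feature inside [0, 1].
    return [min(1.0, max(0.0, v)) for v in cf]


module = ShiftCF({"step": 0.5})
module.set_apply_constraints_fn(clip_unit)
cf = module.generate_cf([0.25, 0.75])
print(cf)
```

In this sketch the subclass overrides only generate_cf, and the constraint logic is injected from outside, so the same toy module can be reused with a different constraint function without subclassing again.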
relax.methods.base.ParametricCFModule
class relax.methods.base.ParametricCFModule (config, name=None, apply_constraints_fn=None, compute_reg_loss_fn=None, **kwargs)
Base class for parametric counterfactual modules.
Methods
set_apply_constraints_fn (apply_constraints_fn)
set_compute_reg_loss_fn (compute_reg_loss_fn)
apply_constraints (*args, **kwargs)
compute_reg_loss (*args, **kwargs)
save (path)
load_from_path (path)
before_generate_cf (*args, **kwargs)
generate_cf (x, pred_fn=None, y_target=None, rng_key=None, **kwargs)
Parameters:
- x (jax.Array)
- pred_fn (typing.Callable, default=None)
- y_target (jax.Array, default=None)
- rng_key (jax.random.PRNGKey, default=None)
- kwargs
Returns:
(jax.Array) – The counterfactual of x.