# Example/test: MinMaxTransformation scales continuous features and round-trips.
xs = np.random.randn(100, 1)
minmax_t = MinMaxTransformation()
transformed_xs = minmax_t.fit_transform(xs)
# inverse_transform must undo transform (within float tolerance).
assert np.allclose(minmax_t.inverse_transform(transformed_xs), xs)
# Min-max scaling is a continuous (non-categorical) transformation.
assert minmax_t.is_categorical is False
x = np.random.randn(100, 1)
cf_constrained = minmax_t.apply_constraints(xs, x)
# Constrained values must stay inside the scaled [0, 1] range.
assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)
# Test from_dict and to_dict: a transformation rebuilt from its serialized
# form must behave identically to the original.
scaler_1 = MinMaxTransformation().from_dict(minmax_t.to_dict())
assert np.allclose(minmax_t.transform(xs), scaler_1.transform(xs))
Feature Transformation
relax.data_utils.transforms.BaseTransformation
class relax.data_utils.transforms.BaseTransformation (name, transformer=None)
Base class for all transformations.
relax.data_utils.transforms.MinMaxTransformation
class relax.data_utils.transforms.MinMaxTransformation ()
Base class for all transformations.
relax.data_utils.transforms.OneHotTransformation
class relax.data_utils.transforms.OneHotTransformation ()
relax.data_utils.transforms.GumbelSoftmaxTransformation
class relax.data_utils.transforms.GumbelSoftmaxTransformation (tau=0.1)
Apply Gumbel softmax tricks for categorical transformation.
relax.data_utils.transforms.SoftmaxTransformation
class relax.data_utils.transforms.SoftmaxTransformation ()
Base class for all transformations.
def test_ohe_t(ohe_cls):
    """Exercise a one-hot-style categorical transformation class.

    Checks transform/apply_constraints (soft and hard modes),
    compute_reg_loss, and to_dict/from_dict serialization round-trip.
    """
    xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
    ohe_t = ohe_cls().fit(xs)
    transformed_xs = ohe_t.transform(xs)
    rng_key = jax.random.PRNGKey(get_config().global_seed)
    assert ohe_t.is_categorical
    x = jax.random.uniform(rng_key, shape=(100, 3))
    # Test hard=False, which applies the softmax function.
    soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)
    # Softmax output rows must be valid probability distributions.
    assert jnp.allclose(soft.sum(axis=-1), 1)
    assert jnp.all(soft >= 0)
    assert jnp.all(soft <= 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))
    # Omitting rng_key must give the same soft projection (deterministic path).
    assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))
    # Test hard=True, which enforces the one-hot constraint.
    hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)
    # Every row must contain both a 1 and a 0 (i.e. be strictly one-hot).
    assert np.all([1 in x for x in hard])
    assert np.all([0 in x for x in hard])
    assert jnp.allclose(hard.sum(axis=-1), 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))
    # Test compute_reg_loss: result is a scalar.
    assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0
    # Test from_dict and to_dict serialization round-trip.
    ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())
    assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))

test_ohe_t(SoftmaxTransformation)
test_ohe_t(GumbelSoftmaxTransformation)
relax.data_utils.transforms.IdentityTransformation
class relax.data_utils.transforms.IdentityTransformation ()
Base class for all transformations.
relax.data_utils.transforms.OrdinalTransformation
class relax.data_utils.transforms.OrdinalTransformation ()
Base class for all transformations.
# Example/test: OrdinalTransformation encodes categories as integer codes.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
transformed_xs = encoder.transform(xs)
# inverse_transform must recover the original category labels exactly.
assert np.all(encoder.inverse_transform(transformed_xs) == xs)
assert encoder.is_categorical
# Test from_dict and to_dict serialization round-trip.
encoder_1 = OrdinalTransformation().from_dict(encoder.to_dict())
assert np.allclose(encoder.transform(xs), encoder_1.transform(xs))
# Example/test: IdentityTransformation passes data through unchanged.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
transformed_xs = scaler.fit_transform(xs)
# Identity: output must equal input exactly.
assert np.all(transformed_xs == xs)
# Test from_dict and to_dict serialization round-trip.
scaler_1 = IdentityTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))