Feature Transformation

relax.data_utils.transforms.BaseTransformation

[source]

class relax.data_utils.transforms.BaseTransformation (name, transformer=None)

Base class for all transformations.

relax.data_utils.transforms.MinMaxTransformation

[source]

class relax.data_utils.transforms.MinMaxTransformation ()

Min-max transformation for continuous (non-categorical) features; `apply_constraints` keeps values within [0, 1].

# Round trip: inverting the fitted transformation must recover the input.
xs = np.random.randn(100, 1)
minmax_t = MinMaxTransformation()
scaled_xs = minmax_t.fit_transform(xs)
recovered_xs = minmax_t.inverse_transform(scaled_xs)
assert np.allclose(recovered_xs, xs)
assert minmax_t.is_categorical is False

# Constrained values must stay inside the closed interval [0, 1].
candidates = np.random.randn(100, 1)
cf_constrained = minmax_t.apply_constraints(xs, candidates)
assert np.all(cf_constrained >= 0)
assert np.all(cf_constrained <= 1)

# Serialization round trip: a transformation rebuilt from the `to_dict`
# output must transform the data identically to the fitted one.
fresh_t = MinMaxTransformation()
restored_t = fresh_t.from_dict(minmax_t.to_dict())
assert np.allclose(minmax_t.transform(xs), restored_t.transform(xs))

relax.data_utils.transforms.OneHotTransformation

[source]

class relax.data_utils.transforms.OneHotTransformation ()

relax.data_utils.transforms.GumbelSoftmaxTransformation

[source]

class relax.data_utils.transforms.GumbelSoftmaxTransformation (tau=0.1)

Apply the Gumbel-softmax trick for categorical transformation.

relax.data_utils.transforms.SoftmaxTransformation

[source]

class relax.data_utils.transforms.SoftmaxTransformation ()

One-hot transformation for categorical features that applies the softmax function when `hard=False`.

def test_ohe_t(ohe_cls):
    """Check the shared contract of a one-hot style transformation class.

    Args:
        ohe_cls: Transformation class under test. It is instantiated with no
            arguments and must provide `fit`, `transform`,
            `apply_constraints`, `compute_reg_loss`, `to_dict`, `from_dict`
            and the `is_categorical` attribute.

    Raises:
        AssertionError: If any of the checked properties does not hold.
    """
    xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
    ohe_t = ohe_cls().fit(xs)
    transformed_xs = ohe_t.transform(xs)
    rng_key = jax.random.PRNGKey(get_config().global_seed)
    assert ohe_t.is_categorical

    x = jax.random.uniform(rng_key, shape=(100, 3))
    # Test hard=False which applies softmax function.
    soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)
    # Every row must be a valid probability vector.
    assert jnp.allclose(soft.sum(axis=-1), 1)
    assert jnp.all(soft >= 0)
    assert jnp.all(soft <= 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))
    # The soft output must be the same whether or not an rng_key is passed.
    assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))

    # Test hard=True which enforces one-hot constraint.
    hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)
    # Each row must contain at least one 1 and at least one 0. The checks are
    # vectorized so the outer `x` is not shadowed by a comprehension variable.
    assert jnp.all(jnp.any(hard == 1, axis=-1))
    assert jnp.all(jnp.any(hard == 0, axis=-1))
    assert jnp.allclose(hard.sum(axis=-1), 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))

    # Test compute_reg_loss: the result must be a scalar.
    assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0

    # Test from_dict and to_dict
    ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())
    assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))


# Run the same checks against both transformation classes.
test_ohe_t(SoftmaxTransformation)
test_ohe_t(GumbelSoftmaxTransformation)

relax.data_utils.transforms.IdentityTransformation

[source]

class relax.data_utils.transforms.IdentityTransformation ()

Identity transformation that returns the input features unchanged.

relax.data_utils.transforms.OrdinalTransformation

[source]

class relax.data_utils.transforms.OrdinalTransformation ()

Ordinal transformation for categorical features.

# Round trip: decoding the encoded categories must restore the original labels.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
encoded_xs = encoder.transform(xs)
decoded_xs = encoder.inverse_transform(encoded_xs)
assert np.all(decoded_xs == xs)
assert encoder.is_categorical

# Serialization round trip: an encoder rebuilt from the `to_dict` output
# must transform the data identically to the fitted one.
fresh_encoder = OrdinalTransformation()
restored_encoder = fresh_encoder.from_dict(encoder.to_dict())
assert np.allclose(encoder.transform(xs), restored_encoder.transform(xs))

# The identity transformation must return the data unchanged.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
passthrough_xs = scaler.fit_transform(xs)
assert np.all(passthrough_xs == xs)

# Serialization round trip: a transformation rebuilt from the `to_dict`
# output must transform the data identically to the fitted one.
fresh_scaler = IdentityTransformation()
restored_scaler = fresh_scaler.from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), restored_scaler.transform(xs))