Feature and Features List

relax.data_utils.features.Feature

[source]

class relax.data_utils.features.Feature (name, data, transformation, transformed_data=None, is_immutable=False, is_categorical=None)

The `Feature` class, which represents a single column in the dataset.

# A continuous feature scaled with the 'minmax' transformation.
feat_cont = Feature(
    'continuous',
    np.random.randn(100, 1),
    'minmax',
    is_immutable=False,
)
scaled = feat_cont.transformed_data
assert scaled.shape == (100, 1)
# Every scaled value is expected to lie within [0, 1].
assert scaled.min() >= 0
assert scaled.max() <= 1
# Inverting the transformation should recover the raw data.
recovered = feat_cont.inverse_transform(feat_cont.transformed_data)
assert jnp.allclose(recovered, feat_cont.data)
assert feat_cont.is_categorical is False

# `with_transformed_data` should return a separate Feature instance that
# carries the same raw data and an equivalent transformation.
feat_cont_1 = feat_cont.with_transformed_data(feat_cont.transformed_data)
assert isinstance(feat_cont_1, Feature)
assert feat_cont_1 is not feat_cont
assert np.allclose(feat_cont_1.data, feat_cont.data)
original_config = feat_cont.transformation.to_dict()
copied_config = feat_cont_1.transformation.to_dict()
assert original_config == copied_config

# A categorical feature with three levels, encoded with the 'ohe' transformation.
categories = ['a', 'b', 'c']
feat_cat = Feature(
    'category',
    np.random.choice(categories, size=(100, 1)),
    'ohe',
    is_immutable=False,
)
encoded = feat_cat.transformed_data
# One output column per category.
assert encoded.shape == (100, len(categories))
# Decoding the encoded data should give back the original labels.
decoded = feat_cat.inverse_transform(encoded)
assert np.all(decoded == feat_cat.data)
assert feat_cat.is_categorical

# Feed hand-built one-hot rows (category indices 0, 1, 2 repeated twice) and
# check that the resulting feature decodes them to the matching labels.
one_hot_rows = jax.nn.one_hot(jnp.array([0, 1, 2] * 2), 3)
feat_cat_1 = feat_cat.with_transformed_data(one_hot_rows)
assert feat_cat_1 is not feat_cat
expected_labels = np.array(['a', 'b', 'c'] * 2).reshape(-1, 1)
assert np.array_equal(feat_cat_1.data, expected_labels)

# Test serialization: a Feature rebuilt from its dict form should match the original.
serialized = feat_cont.to_dict()
feat_cont_1 = Feature.from_dict(serialized)
assert feat_cont_1.name == feat_cont.name
for attr in ('data', 'transformed_data'):
    assert np.allclose(getattr(feat_cont_1, attr), getattr(feat_cont, attr))
assert feat_cont_1.is_immutable == feat_cont.is_immutable
# Test set_transformation
feat_cat = Feature(
    'category',
    np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    'ohe',
    is_immutable=False,
)
# Starts out with the 'ohe' transformation: one column per category.
assert feat_cat.transformation.name == 'ohe'
assert feat_cat.transformed_data.shape == (100, 3)

# After switching to 'ordinal' the transformed data has a single column,
# and the feature is still categorical and mutable.
feat_cat.set_transformation('ordinal')
assert feat_cat.transformation.name == 'ordinal'
assert feat_cat.is_categorical
assert feat_cat.transformed_data.shape == (100, 1)
assert feat_cat.is_immutable is False

relax.data_utils.features.FeaturesList

[source]

class relax.data_utils.features.FeaturesList (features, *args, **kwargs)

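A collection of `Feature` objects. As the examples below show, it exposes the features' combined transformed data and supports indexing by feature name, transforming and inverse-transforming, applying constraints, computing a regularization loss, conversion to pandas, and saving/loading.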

# Read the CSV and declare which columns are treated as continuous / categorical.
df = pd.read_csv('../assets/adult/data/data.csv')
cont_feats = ['age', 'hours_per_week']
cat_feats = ["workclass", "education", "marital_status","occupation", "race", "gender"]

# Continuous columns use 'minmax', categorical columns use 'ohe'; each column
# is passed in as an (n, 1) array.
continuous_features = [
    Feature(col, df[col].to_numpy().reshape(-1, 1), 'minmax') for col in cont_feats
]
categorical_features = [
    Feature(col, df[col].to_numpy().reshape(-1, 1), 'ohe') for col in cat_feats
]
feats_list = FeaturesList(continuous_features + categorical_features)
assert feats_list.transformed_data.shape == (32561, 29)
# Test __getitem__
# Indexing with a single name: 'age' corresponds to the first transformed column.
age_feature = feats_list['age']
assert np.allclose(age_feature.transformed_data, feats_list.transformed_data[:, :1])
# Indexing with a list of names: these three features correspond to the first
# six transformed columns.
selected = FeaturesList(feats_list[['age', 'hours_per_week', 'workclass']])
assert np.allclose(selected.transformed_data, feats_list.transformed_data[:, :6])
# Test with_transformed_data
# Build a new list from 100 randomly chosen transformed rows and check that it
# converts back to the same rows of the original frame.
transformed_xs = feats_list.transformed_data
indices = np.random.choice(len(transformed_xs), size=100)
feats_list_1 = feats_list.with_transformed_data(transformed_xs[indices])

expected_rows = feats_list.to_pandas().iloc[indices].reset_index(drop=True)
actual_rows = feats_list_1.to_pandas()
pd.testing.assert_frame_equal(
    expected_rows,
    actual_rows,
    check_exact=False,
    check_dtype=False,
    check_index_type=False,
)
def test_set_transformations(transformation, correct_shape):
    """Check that the categorical features can be switched to `transformation`.

    Works on a deep copy of the module-level `feats_list`: every feature named
    in `cat_feats` is assigned `transformation`, then the resulting shape, the
    per-feature transformation names and flags are asserted. Finally
    `apply_constraints` is called in both soft and hard mode as a smoke test.

    Args:
        transformation: A transformation name, or a `BaseTransformation`
            instance (in which case its `name` attribute is compared).
        correct_shape: Expected shape of `transformed_data` after the switch.
    """
    feats_copy = deepcopy(feats_list)
    feats_copy.set_transformations(dict.fromkeys(cat_feats, transformation))
    assert feats_copy.transformed_data.shape == correct_shape

    if isinstance(transformation, BaseTransformation):
        expected_name = transformation.name
    else:
        expected_name = transformation

    for feat in feats_copy:
        if feat.name not in cat_feats:
            # Continuous features must be left untouched.
            assert feat.transformation.name == 'minmax'
            assert feat.is_categorical is False
        else:
            assert feat.transformation.name == expected_name
            assert feat.is_categorical
        assert feat.is_immutable is False

    # Smoke test: constraints must be applicable with the new encoding width.
    n_columns = correct_shape[-1]
    candidates = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, n_columns))
    for hard in (False, True):
        feats_copy.apply_constraints(feats_copy.transformed_data[:100], candidates, hard=hard)
# Transformations given by name, paired with the expected transformed shape.
for transformation_name, expected_shape in (
    ('ordinal', (32561, 8)),
    ('ohe', (32561, 29)),
    ('gumbel', (32561, 29)),
):
    test_set_transformations(transformation_name, expected_shape)
# TODO: [bug] raise error when set_transformations is called with
# SoftmaxTransformation() or GumbelSoftmaxTransformation(),
# instead of "ohe" or "gumbel".
# Transformations given as instances; each one is constructed right before use.
for transformation_cls in (SoftmaxTransformation, GumbelSoftmaxTransformation):
    test_set_transformations(transformation_cls(), (32561, 29))
# Test transform and inverse_transform
# Every column of df except the last, as a dict mapping column name -> (n, 1) array.
raw_columns = df.iloc[:, :-1].to_dict(orient='list')
df_dict = {col: np.array(values).reshape(-1, 1) for col, values in raw_columns.items()}
# Transforming the raw columns should reproduce feats_list.transformed_data.
transformed_data = feats_list.transform(df_dict)
assert np.all(np.equal(feats_list.transformed_data, transformed_data))
# Inverse-transforming should give back the raw columns.
inverse_transformed_data = feats_list.inverse_transform(transformed_data)
recovered_frame = pd.DataFrame.from_dict(
    {col: arr.reshape(-1) for col, arr in inverse_transformed_data.items()}
)
original_frame = pd.DataFrame.from_dict(
    {col: arr.reshape(-1) for col, arr in df_dict.items()}
)
pd.testing.assert_frame_equal(
    recovered_frame,
    original_frame,
    check_dtype=False,
    check_exact=False,
)
# Test apply_constraints and compute_reg_loss
x = np.random.randn(10, 29)
constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=False)
assert constraint_cfs.shape == (10, 29)
# Columns 2 onward belong to the six categorical features; the values of each
# row are expected to sum to 6 across those columns.
assert np.allclose(
    constraint_cfs[:, 2:].sum(axis=-1),
    np.ones((10,)) * 6
)
# The first two columns belong to the continuous ('minmax') features and must
# stay within [0, 1]. Note the comma: `[:, :2]` selects columns, whereas
# `[: :2]` is a row slice with step 2.
assert constraint_cfs[:, :2].min() >= 0 and constraint_cfs[:, :2].max() <= 1
# Hard constraints must keep the same shape.
assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=True).shape == (10, 29)

# The regularization loss is a scalar: positive for the unconstrained random
# input, and approximately zero once the constraints have been applied.
reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, x)
assert jnp.ndim(reg_loss) == 0
assert np.all(reg_loss > 0)
assert np.allclose(feats_list.compute_reg_loss(x, constraint_cfs), 0)
# Test `to_pandas`: the exported frame should match the raw columns in df_dict.
feats_pd = feats_list.to_pandas()
expected_pd = pd.DataFrame.from_dict(
    {col: arr.reshape(-1) for col, arr in df_dict.items()}
)
pd.testing.assert_frame_equal(feats_pd, expected_pd, check_dtype=False)
# Test save and load
save_dir = 'tmp/data_module/'
feats_list.save(save_dir)
feats_list_1 = FeaturesList.load_from_path(save_dir)
# Clean up the temporary folder.
shutil.rmtree(save_dir)
# Cross-check every feature's transformed data against scikit-learn's encoders.
sk_ohe = skp.OneHotEncoder(sparse_output=False)
sk_minmax = skp.MinMaxScaler()

for feat in feats_list:
    if feat.name in cont_feats:
        reference = sk_minmax.fit_transform(feat.data)
        failure_message = f"Failed at {feat.name}. "
    else:
        reference = sk_ohe.fit_transform(feat.data)
        failure_message = f"Failed at {feat.name}"
    assert np.allclose(reference, feat.transformed_data), failure_message