# Continuous feature: 100 standard-normal samples scaled with 'minmax'.
feat_cont = Feature(
    name='continuous',
    data=np.random.randn(100, 1),
    transformation='minmax',
    is_immutable=False,
)
cont_transformed = feat_cont.transformed_data
# The transformed column keeps its shape and lies inside [0, 1].
assert cont_transformed.shape == (100, 1)
assert cont_transformed.min() >= 0
assert cont_transformed.max() <= 1
# Inverting the transformation recovers the raw data.
cont_recovered = feat_cont.inverse_transform(cont_transformed)
assert jnp.allclose(cont_recovered, feat_cont.data)
assert feat_cont.is_categorical is False

# `with_transformed_data` yields a distinct Feature carrying the same raw
# data and an equivalent transformation.
feat_cont_1 = feat_cont.with_transformed_data(cont_transformed)
assert isinstance(feat_cont_1, Feature)
assert feat_cont_1 is not feat_cont
assert np.allclose(feat_cont_1.data, feat_cont.data)
assert feat_cont_1.transformation.to_dict() == feat_cont.transformation.to_dict()
# Categorical feature: 100 samples drawn from three labels, one-hot encoded.
feat_cat = Feature(
    name='category',
    data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    transformation='ohe',
    is_immutable=False,
)
cat_transformed = feat_cat.transformed_data
# One output column per label.
assert cat_transformed.shape == (100, 3)
# The one-hot encoding decodes back to the original labels.
cat_recovered = feat_cat.inverse_transform(cat_transformed)
assert np.all(cat_recovered == feat_cat.data)
assert feat_cat.is_categorical

# Feed hand-written one-hot rows (class ids 0, 1, 2, 0, 1, 2) and check they
# decode to the labels 'a', 'b', 'c', 'a', 'b', 'c'.
one_hot_rows = jax.nn.one_hot(jnp.array([0, 1, 2] * 2), 3)
feat_cat_1 = feat_cat.with_transformed_data(one_hot_rows)
assert feat_cat_1 is not feat_cat
expected_labels = np.array(list('abcabc')).reshape(-1, 1)
assert np.array_equal(feat_cat_1.data, expected_labels)
# Test serialization: a Feature rebuilt from its dict form must carry the
# same name, raw data and transformed data as the original.
d = feat_cont.to_dict()
feat_cont_1 = Feature.from_dict(d)
assert feat_cont_1.name == feat_cont.name
for array_attr in ('data', 'transformed_data'):
    assert np.allclose(
        getattr(feat_cont_1, array_attr),
        getattr(feat_cont, array_attr),
    )
assert feat_cont_1.is_immutable == feat_cont.is_immutable

Feature and Features List
relax.data_utils.features.Feature
class relax.data_utils.features.Feature (name, data, transformation, transformed_data=None, is_immutable=False, is_categorical=None)
The feature class, which represents a column in the dataset.
# Test set_transformation: switching a categorical feature from one-hot
# ('ohe', one column per label) to 'ordinal' (a single column) must update
# both the transformation and the transformed data in place.
feat_cat = Feature(
    name='category',
    data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    transformation='ohe',
    is_immutable=False,
)
assert (feat_cat.transformation.name, feat_cat.transformed_data.shape) == ('ohe', (100, 3))

feat_cat.set_transformation('ordinal')
assert feat_cat.transformation.name == 'ordinal'
# The feature stays categorical after the switch.
assert feat_cat.is_categorical
assert feat_cat.transformed_data.shape == (100, 1)
assert feat_cat.is_immutable is False

relax.data_utils.features.FeaturesList
class relax.data_utils.features.FeaturesList (features, *args, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
# Build a FeaturesList from the adult dataset: the continuous columns come
# first (min-max scaled), followed by the categorical columns (one-hot).
df = pd.read_csv('../assets/adult/data/data.csv')
cont_feats = ['age', 'hours_per_week']
cat_feats = [
    "workclass", "education", "marital_status",
    "occupation", "race", "gender",
]

adult_feats = []
for trans_name, feat_names in (('minmax', cont_feats), ('ohe', cat_feats)):
    for col_name in feat_names:
        # Each Feature receives its column as an (n_rows, 1) array.
        column = df[col_name].to_numpy().reshape(-1, 1)
        adult_feats.append(Feature(col_name, column, trans_name))
feats_list = FeaturesList(adult_feats)

# All per-feature encodings are concatenated into 29 columns.
assert feats_list.transformed_data.shape == (32561, 29)
# test __getitem__
all_transformed = feats_list.transformed_data

# Indexing by a single name: 'age' is the first feature, so its transformed
# data must equal the first column of the full matrix.
age_feat = feats_list['age']
assert np.allclose(age_feat.transformed_data, all_transformed[:, :1])

# Indexing by a list of names: the result can be wrapped in a FeaturesList,
# and the first three features together occupy the first six columns.
first_three = FeaturesList(feats_list[['age', 'hours_per_week', 'workclass']])
assert np.allclose(first_three.transformed_data, all_transformed[:, :6])
# Test with_transformed_data
# Rebuild a FeaturesList from 100 randomly chosen transformed rows and check
# that it decodes to the same rows of the original table.
transformed_xs = feats_list.transformed_data
indices = np.random.choice(len(transformed_xs), size=100)
feats_list_1 = feats_list.with_transformed_data(transformed_xs[indices])
expected_frame = feats_list.to_pandas().iloc[indices].reset_index(drop=True)
pd.testing.assert_frame_equal(
    expected_frame,
    feats_list_1.to_pandas(),
    check_exact=False,
    check_dtype=False,
    check_index_type=False,
)


def test_set_transformations(transformation, correct_shape):
    """Apply `transformation` to every categorical feature and verify it.

    Works on a deep copy of the module-level `feats_list`, so the shared
    list is left untouched.

    Args:
        transformation: Either a transformation name or a
            `BaseTransformation` instance, assigned to each name in
            `cat_feats`.
        correct_shape: Expected shape of the transformed data afterwards.
    """
    feats_copy = deepcopy(feats_list)
    feats_copy.set_transformations(dict.fromkeys(cat_feats, transformation))
    assert feats_copy.transformed_data.shape == correct_shape

    # A transformation instance is compared through its `name`; a plain
    # string is already the name.
    if isinstance(transformation, BaseTransformation):
        expected_name = transformation.name
    else:
        expected_name = transformation

    for feat in feats_copy:
        if feat.name not in cat_feats:
            # Continuous features must keep their original transformation.
            assert feat.transformation.name == 'minmax'
            assert feat.is_categorical is False
        else:
            assert feat.transformation.name == expected_name
            assert feat.is_categorical
        assert feat.is_immutable is False

    # Smoke test: both soft and hard constraints must run on inputs whose
    # width matches the new transformed layout.
    x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))
    xs_head = feats_copy.transformed_data[:100]
    for hard in (False, True):
        _ = feats_copy.apply_constraints(xs_head, x, hard=hard)


test_set_transformations('ordinal', (32561, 8))
# Name-based transformations that keep the 29-column one-hot layout.
for transformation_name in ('ohe', 'gumbel'):
    test_set_transformations(transformation_name, (32561, 29))
# TODO: [bug] raise error when set_transformations is called with
# SoftmaxTransformation() or GumbelSoftmaxTransformation(),
# instead of "ohe" or "gumbel".
for transformation_cls in (SoftmaxTransformation, GumbelSoftmaxTransformation):
    test_set_transformations(transformation_cls(), (32561, 29))
# Test transform and inverse_transform
def _columns_to_frame(columns):
    """Flatten each (n, 1) column array to 1-D and assemble a DataFrame."""
    return pd.DataFrame.from_dict(
        {key: values.reshape(-1) for key, values in columns.items()}
    )


# Convert df to dict[str, np.ndarray]; every column except the last one is
# kept (presumably the last column is the target label -- confirm against
# the dataset).
raw_columns = df.iloc[:, :-1].to_dict(orient='list')
df_dict = {}
for col_name, col_values in raw_columns.items():
    df_dict[col_name] = np.array(col_values).reshape(-1, 1)

# feats_list.transform(df_dict) should be the same as feats_list.transformed_data
transformed_data = feats_list.transform(df_dict)
assert np.equal(feats_list.transformed_data, transformed_data).all()

# feats_list.inverse_transform(transformed_data) should be the same as df_dict
inverse_transformed_data = feats_list.inverse_transform(transformed_data)
pd.testing.assert_frame_equal(
    _columns_to_frame(inverse_transformed_data),
    _columns_to_frame(df_dict),
    check_dtype=False,
    check_exact=False,
)
# Test apply_constraints and compute_reg_loss
# Unconstrained random candidates with the same width as the transformed data.
x = np.random.randn(10, 29)
constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=False)
assert constraint_cfs.shape == (10, 29)
# Columns 2 onward hold the six categorical features; after the soft
# constraint the row sum over those columns must be 6.
assert np.allclose(
    constraint_cfs[:, 2:].sum(axis=-1),
    np.ones((10,)) * 6
)
# Columns 0-1 hold the two min-max scaled continuous features and must stay
# inside [0, 1]. The index was previously written `[: :2]`, which parses as
# `[::2]` (every other row, all columns) rather than the first two columns.
assert constraint_cfs[:, :2].min() >= 0 and constraint_cfs[:, :2].max() <= 1
assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=True).shape == (10, 29)
# NOTE(review): the first argument has every row of the dataset while `x`
# has only 10 -- presumably only the second argument contributes to the
# loss; confirm against `compute_reg_loss`.
reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, x)
# The regularization loss is a scalar, positive for unconstrained inputs ...
assert jnp.ndim(reg_loss) == 0
assert np.all(reg_loss > 0)
# ... and zero once the constraints have been applied.
assert np.allclose(feats_list.compute_reg_loss(x, constraint_cfs), 0)
# Test `to_pandas`
# `to_pandas` must reproduce the raw feature columns held in `df_dict`,
# each flattened from (n, 1) to 1-D.
feats_pd = feats_list.to_pandas()
expected_pd = pd.DataFrame.from_dict(
    {col_name: col_values.ravel() for col_name, col_values in df_dict.items()}
)
pd.testing.assert_frame_equal(feats_pd, expected_pd, check_dtype=False)
# Test save and load
# Round-trip the FeaturesList through a scratch directory. The cleanup runs
# in `finally` so the directory is removed even when saving or loading
# raises; `ignore_errors=True` keeps a failed cleanup (for example, the
# directory was never created) from masking the original exception.
save_dir = 'tmp/data_module/'
try:
    feats_list.save(save_dir)
    feats_list_1 = FeaturesList.load_from_path(save_dir)
finally:
    # remove tmp folder
    shutil.rmtree(save_dir, ignore_errors=True)

# Reference one-hot encoder used by the scikit-learn comparison below.
sk_ohe = skp.OneHotEncoder(sparse_output=False)
# Cross-check every feature's transformed data against scikit-learn:
# MinMaxScaler for the continuous features, the one-hot encoder `sk_ohe`
# for the categorical ones.
sk_minmax = skp.MinMaxScaler()
for feat in feats_list:
    if feat.name in cont_feats:
        reference, failure_msg = sk_minmax, f"Failed at {feat.name}. "
    else:
        reference, failure_msg = sk_ohe, f"Failed at {feat.name}"
    # The reference transformer is re-fitted on this feature's raw data.
    expected = reference.fit_transform(feat.data)
    assert np.allclose(expected, feat.transformed_data), failure_msg