Data Module
DataModule for training parametric models, generating and benchmarking CF explanations.
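A minimal end-to-end sketch (it only uses calls documented later on this page; the adult dataset and the batch size of 128 mirror the examples below):

from relax.data.module import load_data

# Load a bundled dataset and pull one batch of encoded features/labels.
dm = load_data(data_name='adult')
x, y = next(iter(dm.test_dataloader(batch_size=128)))

# Map the encoded batch back to a human-readable DataFrame.
df_batch = dm.inverse_transform(x, y)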
Data Module Interfaces
High-level interfaces for DataModule. Docs to be added.
BASEDATAMODULE
CLASS relax.data.module.BaseDataModule ()
DataModule Interface
Tabular Data Module
DataModule for processing tabular data.
TRANSFORMERMIXINTYPE
CLASS relax.data.module.TransformerMixinType ()
Mixin class for all transformers in scikit-learn.
If get_feature_names_out is defined, then BaseEstimator will automatically wrap transform and fit_transform to follow the set_output API. See the scikit-learn developer documentation on set_output for details.
base.OneToOneFeatureMixin and base.ClassNamePrefixFeaturesOutMixin are helpful mixins for defining get_feature_names_out.
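For instance, scikit-learn's StandardScaler already satisfies this interface (fit/transform plus get_feature_names_out), so transformers like it can presumably be used for the normalizer and encoder fields of the configurator below. A small standalone sketch:

from sklearn.preprocessing import StandardScaler

# StandardScaler implements fit/transform and get_feature_names_out,
# so it satisfies the transformer interface described above.
scaler = StandardScaler().fit([[0.0], [10.0]])
print(scaler.transform([[5.0]]))        # [[0.]] -- centred and scaled
print(scaler.get_feature_names_out())   # ['x0']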
TABULARDATAMODULECONFIGS
CLASS relax.data.module.TabularDataModuleConfigs (data_dir, data_name, continous_cols=[], discret_cols=[], imutable_cols=[], normalizer=None, encoder=None, sample_frac=None, backend='jax')
Configurator of TabularDataModule.
An example configurator of the adult dataset:

from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from relax.data.module import TabularDataModuleConfigs

configs_dict = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status", "occupation"],
    "imutable_cols": ["age", "workclass", "marital_status"],
    "normalizer": MinMaxScaler(),
    "encoder": OneHotEncoder(sparse=False),
    "sample_frac": 0.1,
    "backend": "jax"
}
configs = TabularDataModuleConfigs(**configs_dict)
TABULARDATAMODULE
CLASS relax.data.module.TabularDataModule (data_config, data=None)
DataModule for tabular data
To load TabularDataModule from TabularDataModuleConfigs,
configs = TabularDataModuleConfigs(
    data_name='adult',
    data_dir='assets/data/s_adult.csv',
    continous_cols=['age', 'hours_per_week'],
    discret_cols=['workclass', 'education', 'marital_status', 'occupation'],
    imutable_cols=['age', 'workclass', 'marital_status'],
    sample_frac=0.1
)
dm = TabularDataModule(configs)

We can also explicitly pass a pd.DataFrame to TabularDataModule. In this case, TabularDataModule will use the passed pd.DataFrame instead of loading data from data_dir in TabularDataModuleConfigs.
df = pd.read_csv('assets/data/s_adult.csv')[:1000]
dm = TabularDataModule(configs, data=df)
assert len(dm.data) == 1000  # dm contains `df`

TABULARDATAMODULE.DATA
relax.data.module.TabularDataModule.data ()
Loaded data in pd.DataFrame.
TabularDataModule loads either a csv file (specified via data_dir in data_config) or a directly passed pd.DataFrame (specified as data). Either way, the data needs to satisfy the following conditions (a sanity-check sketch follows the list):
- It requires the target column (i.e., the labels) to be the last column of the DataFrame, and the rest of the columns to be features. The target column needs to be binary-valued (i.e., it is either 0 or 1). In the example below, income is the target column, and the rest of the columns are features.
- It requires continous_cols and discret_cols in data_config to be subsets of data.columns.
- It only uses the columns specified in continous_cols and discret_cols.
- It loads continous_cols first, then discret_cols.
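One way to sanity-check a custom csv against these conditions before constructing the module (a sketch; configs refers to the configurator built above and is assumed to expose its fields as attributes):

import pandas as pd

df = pd.read_csv('assets/data/s_adult.csv')

# The target column is the last one and must be binary-valued.
target_col = df.columns[-1]
assert set(df[target_col].unique()) <= {0, 1}

# Continuous and discrete columns must be subsets of the feature columns.
assert set(configs.continous_cols) <= set(df.columns)
assert set(configs.discret_cols) <= set(df.columns)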
dm.data.head()

|   | age | hours_per_week | workclass | education | marital_status | occupation | income |
|---|---|---|---|---|---|---|---|
| 0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar | 1 |
| 1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar | 0 |
| 2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar | 1 |
| 3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar | 0 |
| 4 | 57.0 | 35.0 | Private | School | Married | Service | 0 |
TABULARDATAMODULE.TRANSFORM
relax.data.module.TabularDataModule.transform (data)
Transform data into numerical representations.
By default, we transform continuous features via MinMaxScaler, and discrete features via OneHotEncoder.
A tabular data point x is encoded as $x = [\underbrace{x_{0}, x_{1}, \ldots, x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1}, \ldots, x_{m+p}^{c=1}}_{\text{cat feature (1)}}, \ldots, \underbrace{x_{k-q}^{c=i}, \ldots, x_{k}^{c=i}}_{\text{cat feature (i)}}]$
df = dm.data.head()
X, y = dm.transform(df)
assert isinstance(X, np.ndarray)
assert isinstance(y, np.ndarray)
assert y.shape == (len(X), 1)
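To make the layout above concrete, here is a standalone sketch of the same scheme using scikit-learn directly (illustrative only; this is not the library's internal code, and sparse=False mirrors the encoder configured at the top of this page):

import numpy as np
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

cont = dm.data[['age', 'hours_per_week']]
cat = dm.data[['workclass', 'education', 'marital_status', 'occupation']]

x_cont = MinMaxScaler().fit_transform(cont)              # m continuous columns come first
x_cat = OneHotEncoder(sparse=False).fit_transform(cat)   # then one one-hot block per categorical feature
x_enc = np.concatenate([x_cont, x_cat], axis=1)

assert x_enc.shape[1] == x_cont.shape[1] + x_cat.shape[1]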
TABULARDATAMODULE.INVERSE_TRANSFORM
relax.data.module.TabularDataModule.inverse_transform (x, y=None)
Scaled back into pd.DataFrame.
TabularDataModule.inverse_transform scales numerical representations back to the original DataFrame.
dm.inverse_transform(X, y)

|   | age | hours_per_week | workclass | education | marital_status | occupation | income |
|---|---|---|---|---|---|---|---|
| 0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar | 1 |
| 1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar | 0 |
| 2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar | 1 |
| 3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar | 0 |
| 4 | 57.0 | 35.0 | Private | School | Married | Service | 0 |
If y is not passed, it will only scale back X.
dm.inverse_transform(X)

|   | age | hours_per_week | workclass | education | marital_status | occupation |
|---|---|---|---|---|---|---|
| 0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar |
| 1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar |
| 2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar |
| 3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar |
| 4 | 57.0 | 35.0 | Private | School | Married | Service |
TABULARDATAMODULE.APPLY_CONSTRAINTS
relax.data.module.TabularDataModule.apply_constraints (x, cf, hard=False)
Apply categorical normalization and immutability constraints
TabularDataModule.apply_constraints does two things:
- It ensures that generated counterfactuals respect the one-hot encoding format (i.e., $\sum_{p \to q} x^{c=i}_{p} = 1$).
- It ensures the immutability constraints (i.e., immutable features defined in imutable_cols will not be changed).
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
random.PRNGKey(0), x.shape
)
# normalized counterfactuals
cf_normed = dm.apply_constraints(x, cf)
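The hard flag in the signature above suggests a stricter projection onto exact one-hot vectors; a sketch of invoking it (its precise semantics are an assumption here):

# project each categorical block onto an exact one-hot vector (assumed behaviour of hard=True)
cf_hard = dm.apply_constraints(x, cf, hard=True)
assert cf_hard.shape == x.shape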
TABULARDATAMODULE.APPLY_REGULARIZATION
relax.data.module.TabularDataModule.apply_regularization (x, cf)
Apply categorical constraints by adding regularization terms
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
random.PRNGKey(0), x.shape
)
# regularization term for the categorical features
reg_loss = dm.apply_regularization(x, cf)

SAMPLE
relax.data.module.sample (datamodule, frac=1.0)
Load Data
LOAD_DATA
relax.data.module.load_data (data_name, return_config=False, data_configs=None)
High-level util function for loading data and data_config.
load_data easily loads example datasets by passing the data_name. For example, you can load the adult dataset as:
dm = load_data(data_name = 'adult')

Under the hood, load_data loads the default data_configs. To access this data_configs:
dm, data_configs = load_data(data_name = 'adult', return_config=True)

If you want to override some of the data configs, you can pass the overrides via the data_configs argument. For example, if you want to use only 10% of the data, you can
dm = load_data(
data_name = 'adult', data_configs={'sample_frac': 0.1}
)

Supported Datasets
load_data currently supports the following datasets:
| Dataset | # Cont Features | # Cat Features | # of Data Points |
|---|---|---|---|
| adult | 2 | 6 | 32561 |
| heloc | 21 | 2 | 10459 |
| oulad | 23 | 8 | 32593 |
| credit | 20 | 3 | 30000 |
| cancer | 30 | 0 | 569 |
| student_performance | 2 | 14 | 649 |
| titanic | 2 | 24 | 891 |
| german | 7 | 13 | 1000 |
| spam | 57 | 0 | 4601 |
| ozone | 72 | 0 | 2534 |
| qsar | 38 | 3 | 1055 |
| bioresponse | 1776 | 0 | 3751 |
| churn | 3 | 16 | 7043 |
| road | 29 | 3 | 111762 |
DEFAULT_DATA_CONFIGS.keys()

dict_keys(['adult', 'heloc', 'oulad', 'credit', 'cancer', 'student_performance', 'titanic', 'german', 'spam', 'ozone', 'qsar', 'bioresponse', 'churn', 'road'])
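Since the supported names are exactly the keys of DEFAULT_DATA_CONFIGS, every bundled dataset can be loaded in a loop (a sketch; loading all of them may take a while, and dm.data is the pd.DataFrame documented above):

for name in DEFAULT_DATA_CONFIGS.keys():
    dm = load_data(data_name=name)
    print(name, dm.data.shape)  # dataset name and (rows, columns)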