```python
configs_dict = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status", "occupation"],
    "imutable_cols": ["age", "workclass", "marital_status"],
    "normalizer": MinMaxScaler(),
    "encoder": OneHotEncoder(sparse=False),
    "sample_frac": 0.1,
    "backend": "jax"
}
configs = TabularDataModuleConfigs(**configs_dict)
```
Data Module

`DataModule` for training parametric models, and for generating and benchmarking CF explanations.

Data Module Interfaces

High-level interfaces for `DataModule`. Docs to be added.
BASEDATAMODULE
CLASS relax.data.module.BaseDataModule ()
DataModule Interface
Tabular Data Module

`DataModule` for processing tabular data.
TRANSFORMERMIXINTYPE
CLASS relax.data.module.TransformerMixinType ()
Mixin class for all transformers in scikit-learn. If `get_feature_names_out` is defined, then `BaseEstimator` will automatically wrap `transform` and `fit_transform` to follow the `set_output` API. See the scikit-learn developer guide on `set_output` for details. `base.OneToOneFeatureMixin` and `base.ClassNamePrefixFeaturesOutMixin` are helpful mixins for defining `get_feature_names_out`.
TABULARDATAMODULECONFIGS
CLASS relax.data.module.TabularDataModuleConfigs (data_dir, data_name, continous_cols=[], discret_cols=[], imutable_cols=[], normalizer=None, encoder=None, sample_frac=None, backend='jax')
Configurator of `TabularDataModule`.

An example configurator of the adult dataset is given in the `configs_dict` snippet at the top of this page.
TABULARDATAMODULE
CLASS relax.data.module.TabularDataModule (data_config, data=None)
DataModule for tabular data.

To load `TabularDataModule` from `TabularDataModuleConfigs`:
```python
configs = TabularDataModuleConfigs(
    data_name='adult',
    data_dir='assets/data/s_adult.csv',
    continous_cols=['age', 'hours_per_week'],
    discret_cols=['workclass', 'education', 'marital_status', 'occupation'],
    imutable_cols=['age', 'workclass', 'marital_status'],
    sample_frac=0.1
)
dm = TabularDataModule(configs)
```
We can also explicitly pass a `pd.DataFrame` to `TabularDataModule`. In this case, `TabularDataModule` will use the passed `pd.DataFrame` instead of loading data from `data_dir` in `TabularDataModuleConfigs`.
```python
df = pd.read_csv('assets/data/s_adult.csv')[:1000]
dm = TabularDataModule(configs, data=df)
assert len(dm.data) == 1000  # dm contains `df`
```
TABULARDATAMODULE.DATA
relax.data.module.TabularDataModule.data ()
Loaded data in `pd.DataFrame`.
`TabularDataModule` loads either a csv file (specified by `data_dir` in `data_config`), or a directly passed DataFrame (specified as `data`). Either way, this data needs to satisfy the following conditions:

- The target column (i.e., the labels) must be the last column of the DataFrame, and the rest of the columns are features. This target column must be binary-valued (i.e., either `0` or `1`). In the example below, `income` is the target column, and the rest of the columns are features.
- `continous_cols` and `discret_cols` in `data_config` must be subsets of `data.columns`.
- Only the columns specified in `continous_cols` and `discret_cols` are used.
- `continous_cols` are loaded first, then `discret_cols`.

```python
dm.data.head()
```
| | age | hours_per_week | workclass | education | marital_status | occupation | income |
|---|---|---|---|---|---|---|---|
0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar | 1 |
1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar | 0 |
2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar | 1 |
3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar | 0 |
4 | 57.0 | 35.0 | Private | School | Married | Service | 0 |
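The conditions above can be sanity-checked on any candidate DataFrame before constructing the module. A minimal sketch (the helper name and the toy frame are hypothetical, not part of the library):

```python
import pandas as pd

def check_datamodule_conditions(df, continous_cols, discret_cols):
    """Hypothetical helper mirroring the conditions listed above."""
    target = df.columns[-1]                        # labels must be the last column
    assert set(df[target].unique()) <= {0, 1}      # labels must be binary-valued
    assert set(continous_cols) <= set(df.columns)  # cont cols must exist in the data
    assert set(discret_cols) <= set(df.columns)    # cat cols must exist in the data
    return target

toy = pd.DataFrame({
    "age": [42.0, 32.0],
    "workclass": ["Private", "Self-Employed"],
    "income": [1, 0],
})
print(check_datamodule_conditions(toy, ["age"], ["workclass"]))  # income
```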
TABULARDATAMODULE.TRANSFORM
relax.data.module.TabularDataModule.transform (data)
Transform data into numerical representations.
By default, we transform continuous features via `MinMaxScaler`, and discrete features via `OneHotEncoder`.
A tabular data point $x$ is encoded as

$$x = [\underbrace{x_{0}, x_{1}, \dots, x_{m}}_{\text{cont features}}, \underbrace{x_{m+1}^{c=1}, \dots, x_{m+p}^{c=1}}_{\text{cat feature }(1)}, \dots, \underbrace{x_{k-q}^{c=i}, \dots, x_{k}^{c=i}}_{\text{cat feature }(i)}]$$
```python
df = dm.data.head()
X, y = dm.transform(df)

assert isinstance(X, np.ndarray)
assert isinstance(y, np.ndarray)
assert y.shape == (len(X), 1)
```
TABULARDATAMODULE.INVERSE_TRANSFORM
relax.data.module.TabularDataModule.inverse_transform (x, y=None)
Scaled back into a `pd.DataFrame`.
`TabularDataModule.inverse_transform` scales numerical representations back to the original DataFrame.

```python
dm.inverse_transform(X, y)
```
| | age | hours_per_week | workclass | education | marital_status | occupation | income |
|---|---|---|---|---|---|---|---|
0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar | 1 |
1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar | 0 |
2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar | 1 |
3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar | 0 |
4 | 57.0 | 35.0 | Private | School | Married | Service | 0 |
If `y` is not passed, it will only scale back `X`.

```python
dm.inverse_transform(X)
```
| | age | hours_per_week | workclass | education | marital_status | occupation |
|---|---|---|---|---|---|---|
0 | 42.0 | 45.0 | Private | HS-grad | Married | Blue-Collar |
1 | 32.0 | 40.0 | Self-Employed | Some-college | Married | Blue-Collar |
2 | 35.0 | 40.0 | Private | Assoc | Single | White-Collar |
3 | 36.0 | 40.0 | Private | HS-grad | Single | Blue-Collar |
4 | 57.0 | 35.0 | Private | School | Married | Service |
TABULARDATAMODULE.APPLY_CONSTRAINTS
relax.data.module.TabularDataModule.apply_constraints (x, cf, hard=False)
Apply categorical normalization and immutability constraints
`TabularDataModule.apply_constraints` does two things:

- It ensures that generated counterfactuals respect the one-hot encoding format (i.e., $\sum_{p} x^{c=i}_{p} = 1$ for each categorical feature $i$).
- It ensures the immutability constraints (i.e., immutable features defined in `imutable_cols` will not be changed).
```python
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(random.PRNGKey(0), x.shape)
# normalized counterfactuals
cf_normed = dm.apply_constraints(x, cf)
```
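One common way to realize the one-hot constraint (a NumPy sketch, not necessarily what `apply_constraints` does internally; the `hard` flag mirrors the documented parameter):

```python
import numpy as np

def project_onehot(cf_block, hard=False):
    """Map an unconstrained categorical block back onto the one-hot simplex."""
    if hard:
        # hard=True: snap each row to an exact one-hot vector via argmax
        out = np.zeros_like(cf_block)
        out[np.arange(len(cf_block)), cf_block.argmax(axis=1)] = 1.0
        return out
    # hard=False: softmax, so each row sums to 1 but stays differentiable
    e = np.exp(cf_block - cf_block.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

cf_cat = np.array([[0.2, 1.5, -0.3]])
print(project_onehot(cf_cat, hard=True))  # [[0. 1. 0.]]
```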
TABULARDATAMODULE.APPLY_REGULARIZATION
relax.data.module.TabularDataModule.apply_regularization (x, cf)
Apply categorical constraints by adding regularization terms
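For intuition, such a regularization term could simply penalize one-hot blocks whose entries do not sum to 1 (a hedged sketch, not the library's actual formula):

```python
import numpy as np

def onehot_reg(cf_block):
    """Penalty that is zero exactly when each categorical block row sums to 1."""
    return float(np.square(cf_block.sum(axis=1) - 1.0).mean())

print(round(onehot_reg(np.array([[0.2, 1.5, -0.3]])), 2))  # 0.16
```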
```python
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(random.PRNGKey(0), x.shape)
# regularization term for the categorical constraints
reg_loss = dm.apply_regularization(x, cf)
```
SAMPLE
relax.data.module.sample (datamodule, frac=1.0)
Load Data
LOAD_DATA
relax.data.module.load_data (data_name, return_config=False, data_configs=None)
High-level util function for loading data and `data_config`.
`load_data` easily loads example datasets by passing the `data_name`. For example, you can load the adult dataset as:

```python
dm = load_data(data_name='adult')
```
Under the hood, `load_data` loads the default `data_configs`. To access this `data_configs`:

```python
dm, data_configs = load_data(data_name='adult', return_config=True)
```
If you want to override some of the data configs, you can pass them as an auxiliary argument in `data_configs`. For example, if you want to use only 10% of the data, you can:

```python
dm = load_data(
    data_name='adult', data_configs={'sample_frac': 0.1}
)
```
Supported Datasets
`load_data` currently supports the following datasets:

| Dataset | # Cont Features | # Cat Features | # of Data Points |
|---|---|---|---|
adult | 2 | 6 | 32561 |
heloc | 21 | 2 | 10459 |
oulad | 23 | 8 | 32593 |
credit | 20 | 3 | 30000 |
cancer | 30 | 0 | 569 |
student_performance | 2 | 14 | 649 |
titanic | 2 | 24 | 891 |
german | 7 | 13 | 1000 |
spam | 57 | 0 | 4601 |
ozone | 72 | 0 | 2534 |
qsar | 38 | 3 | 1055 |
bioresponse | 1776 | 0 | 3751 |
churn | 3 | 16 | 7043 |
road | 29 | 3 | 111762 |
```python
DEFAULT_DATA_CONFIGS.keys()
```

```
dict_keys(['adult', 'heloc', 'oulad', 'credit', 'cancer', 'student_performance', 'titanic', 'german', 'spam', 'ozone', 'qsar', 'bioresponse', 'churn', 'road'])
```