Data Module

DataModule for training parametric models, and for generating and benchmarking counterfactual (CF) explanations.

Data Module Interfaces

High-level interfaces for DataModule. Docs to be added.


BASEDATAMODULE

CLASS relax.data.module.BaseDataModule ()

DataModule Interface

Tabular Data Module

DataModule for processing tabular data.


TRANSFORMERMIXINTYPE

CLASS relax.data.module.TransformerMixinType ()

Mixin class for all transformers in scikit-learn.

If get_feature_names_out is defined, then BaseEstimator will automatically wrap transform and fit_transform to follow the set_output API. See the scikit-learn developer API documentation on set_output for details.

base.OneToOneFeatureMixin and base.ClassNamePrefixFeaturesOutMixin are helpful mixins for defining get_feature_names_out.
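
Concretely, any scikit-learn transformer exposing the standard fit/transform API qualifies. As a hedged illustration, a StandardScaler could stand in for the default MinMaxScaler as the normalizer:

from sklearn.preprocessing import StandardScaler

normalizer = StandardScaler()  # may be passed unfitted; it will then be fit on the training data
X_scaled = normalizer.fit_transform([[0.0], [5.0], [10.0]])  # the fit/transform API in action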


TABULARDATAMODULECONFIGS

CLASS relax.data.module.TabularDataModuleConfigs (data_dir, data_name, continous_cols=[], discret_cols=[], imutable_cols=[], normalizer=None, encoder=None, sample_frac=None, backend='jax')

Configurator of TabularDataModule.

Parameters:
  • data_dir (str) – The directory of the dataset.
  • data_name (str) – The name of the TabularDataModule.
  • continous_cols (List[str], default=[]) – Continuous features/columns in the data.
  • discret_cols (List[str], default=[]) – Categorical features/columns in the data.
  • imutable_cols (List[str], default=[]) – Immutable features/columns in the data.
  • normalizer (Optional[TransformerMixinType]) – Sklearn scaler for continuous features. Can be unfitted, fitted, or None. If not fitted, the TabularDataModule will fit it using the training data. If fitted, no fitting will be applied. If None, no transformation will be applied. Defaults to MinMaxScaler().
  • encoder (Optional[TransformerMixinType]) – Sklearn encoder for categorical features. Can be unfitted, fitted, or None. If not fitted, the TabularDataModule will fit it using the training data. If fitted, no fitting will be applied. If None, no transformation will be applied. Defaults to OneHotEncoder(sparse=False).
  • sample_frac (Optional[float]) – Fraction of the data to sample. Defaults to using the entire dataset.
  • backend (str, default=jax) – Dataloader backend. Currently supports: ['jax', 'pytorch']

An example configurator of the adult dataset:

from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from relax.data.module import TabularDataModuleConfigs

configs_dict = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status","occupation"],
    "imutable_cols": ["age", "workclass", "marital_status"],
    "normalizer": MinMaxScaler(),
    "encoder": OneHotEncoder(sparse=False),
    "sample_frac": 0.1,
    "backend": "jax"
}
configs = TabularDataModuleConfigs(**configs_dict)

TABULARDATAMODULE

CLASS relax.data.module.TabularDataModule (data_config, data=None)

DataModule for tabular data

Parameters:
  • data_config (dict | TabularDataModuleConfigs) – Configurator of TabularDataModule
  • data (pd.DataFrame, default=None) – Data in pd.DataFrame. If data is None, the DataModule will load data from data_dir.

To load a TabularDataModule from a TabularDataModuleConfigs,

configs = TabularDataModuleConfigs(
    data_name='adult',
    data_dir='assets/data/s_adult.csv',
    continous_cols=['age', 'hours_per_week'],
    discret_cols=['workclass', 'education', 'marital_status', 'occupation'],
    imutable_cols=['age', 'workclass', 'marital_status'],
    sample_frac=0.1
)

dm = TabularDataModule(configs)

We can also explicitly pass a pd.DataFrame to TabularDataModule. In this case, TabularDataModule will use the passed pd.DataFrame, instead of loading data from data_dir in TabularDataModuleConfigs.

import pandas as pd

df = pd.read_csv('assets/data/s_adult.csv')[:1000]
dm = TabularDataModule(configs, data=df)
assert len(dm.data) == 1000 # dm contains `df`

TABULARDATAMODULE.DATA

relax.data.module.TabularDataModule.data ()

Loaded data in pd.DataFrame.

TabularDataModule loads either a csv file (specified via data_dir in data_config) or a directly passed DataFrame (specified as data). Either way, the data needs to satisfy the following conditions:

  • The target column (i.e., the labels) must be the last column of the DataFrame, and the remaining columns are features. The target column must be binary-valued (i.e., either 0 or 1).
    • In the example below, income is the target column, and the remaining columns are features.
  • continous_cols and discret_cols in data_config must be subsets of data.columns.
  • Only columns specified in continous_cols and discret_cols are used.
    • continous_cols are loaded first, then discret_cols.
dm.data.head()
age hours_per_week workclass education marital_status occupation income
0 42.0 45.0 Private HS-grad Married Blue-Collar 1
1 32.0 40.0 Self-Employed Some-college Married Blue-Collar 0
2 35.0 40.0 Private Assoc Single White-Collar 1
3 36.0 40.0 Private HS-grad Single Blue-Collar 0
4 57.0 35.0 Private School Married Service 0
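
For instance, a minimal toy DataFrame satisfying these conditions might look like the following hedged sketch (the values are made up for illustration; the empty data_dir is unused because data is passed directly):

import pandas as pd
from relax.data.module import TabularDataModule, TabularDataModuleConfigs

toy_df = pd.DataFrame({
    "age": [42.0, 32.0],                        # continuous feature
    "workclass": ["Private", "Self-Employed"],  # categorical feature
    "income": [1, 0],                           # binary target, last column
})
toy_configs = TabularDataModuleConfigs(
    data_dir='', data_name='toy',
    continous_cols=['age'], discret_cols=['workclass'],
)
toy_dm = TabularDataModule(toy_configs, data=toy_df)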

TABULARDATAMODULE.TRANSFORM

relax.data.module.TabularDataModule.transform (data)

Transform data into numerical representations.

Parameters:
  • data (pd.DataFrame) – Data to be transformed to numpy.ndarray
Returns:

    (Tuple[np.ndarray, np.ndarray]) – Return (X, y)

By default, we transform continuous features via MinMaxScaler and discrete features via OneHotEncoder.

A tabular data point x is encoded as x = [\underbrace{x_0, x_1, \ldots, x_m}_{\text{cont. features}},\ \underbrace{x_{m+1}^{c=1}, \ldots, x_{m+p}^{c=1}}_{\text{cat. feature }(1)},\ \ldots,\ \underbrace{x_{k-q}^{c=i}, \ldots, x_{k}^{c=i}}_{\text{cat. feature }(i)}]

df = dm.data.head()
X, y = dm.transform(df)

assert isinstance(X, np.ndarray)
assert isinstance(y, np.ndarray)
assert y.shape == (len(X), 1)
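
To make the layout of X explicit, here is a hedged sketch of the same encoding done manually with scikit-learn (fitted on df only, so the numbers may differ from dm.transform): scaled continuous columns come first, followed by the one-hot block of each categorical feature.

import numpy as np
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

cont = MinMaxScaler().fit_transform(df[['age', 'hours_per_week']])
cat = OneHotEncoder(sparse=False).fit_transform(
    df[['workclass', 'education', 'marital_status', 'occupation']]
)
X_manual = np.concatenate([cont, cat], axis=1)  # [cont features | one-hot blocks]
assert X_manual.shape[0] == len(df)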

TABULARDATAMODULE.INVERSE_TRANSFORM

relax.data.module.TabularDataModule.inverse_transform (x, y=None)

Scale transformed data back into a pd.DataFrame.

Parameters:
  • x (jnp.DeviceArray) – The transformed input to be scaled back
  • y (jnp.DeviceArray, default=None) – The transformed label to be scaled back. If None, the target columns will not be scaled back.
Returns:

    (pd.DataFrame) – Transformed pd.DataFrame.

TabularDataModule.inverse_transform scales numerical representations back to the original DataFrame.

dm.inverse_transform(X, y)
age hours_per_week workclass education marital_status occupation income
0 42.0 45.0 Private HS-grad Married Blue-Collar 1
1 32.0 40.0 Self-Employed Some-college Married Blue-Collar 0
2 35.0 40.0 Private Assoc Single White-Collar 1
3 36.0 40.0 Private HS-grad Single Blue-Collar 0
4 57.0 35.0 Private School Married Service 0

If y is not passed, it will only scale back X.

dm.inverse_transform(X)
age hours_per_week workclass education marital_status occupation
0 42.0 45.0 Private HS-grad Married Blue-Collar
1 32.0 40.0 Self-Employed Some-college Married Blue-Collar
2 35.0 40.0 Private Assoc Single White-Collar
3 36.0 40.0 Private HS-grad Single Blue-Collar
4 57.0 35.0 Private School Married Service
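
As a hedged round-trip check (using df and X, y from the transform example above), decoding should recover the original rows, with categorical columns recovered exactly:

df_back = dm.inverse_transform(X, y)
assert (df_back['workclass'].values == df['workclass'].values).all()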

TABULARDATAMODULE.APPLY_CONSTRAINTS

relax.data.module.TabularDataModule.apply_constraints (x, cf, hard=False)

Apply categorical normalization and immutability constraints

Parameters:
  • x (jnp.DeviceArray) – input
  • cf (jnp.DeviceArray) – Unnormalized counterfactuals
  • hard (bool, default=False) – Apply hard constraints or not
Returns:

    (jnp.DeviceArray)

TabularDataModule.apply_constraints does two things:

  1. It ensures that generated counterfactuals respect the one-hot encoding format (i.e., \sum_{j=p}^{q} x_{j}^{c=i} = 1, where columns p through q form the one-hot block of categorical feature i).
  2. It ensures the immutability constraints (i.e., immutable features defined in imutable_cols will not be changed).
from jax import random

x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
    random.PRNGKey(0), x.shape
)
# normalized counterfactuals
cf_normed = dm.apply_constraints(x, cf)
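
A hedged sanity check, assuming the adult layout above (2 continuous columns first) and that the categorical normalization makes each one-hot block sum to 1:

import jax.numpy as jnp

assert cf_normed.shape == x.shape
# 4 categorical features => the categorical part sums to 4 per row (assumption)
assert jnp.allclose(cf_normed[:, 2:].sum(axis=1), 4.0)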

TABULARDATAMODULE.APPLY_REGULARIZATION

relax.data.module.TabularDataModule.apply_regularization (x, cf)

Apply categorical constraints by adding regularization terms

Parameters:
  • x (jnp.DeviceArray) – Input
  • cf (jnp.DeviceArray) – Unnormalized counterfactuals
Returns:

    (float) – Return regularization loss

from jax import random

x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
    random.PRNGKey(0), x.shape
)
# regularization loss penalizing violations of the one-hot encoding format
reg_loss = dm.apply_regularization(x, cf)

SAMPLE

relax.data.module.sample (datamodule, frac=1.0)

Sample a fraction (frac) of the data in datamodule.
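
A hedged usage sketch, assuming sample returns the down-sampled DataModule:

from relax.data.module import sample

dm_small = sample(dm, frac=0.1)  # keep 10% of the data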


LOAD_DATA

relax.data.module.load_data (data_name, return_config=False, data_configs=None)

High-level util function for loading data and data_config.

Parameters:
  • data_name (str) – The name of the dataset
  • return_config (bool, default=False) – Whether to return the data_config
  • data_configs (dict, default=None) – Data configs to override default configuration
Returns:

    (TabularDataModule | Tuple[TabularDataModule, TabularDataModuleConfigs])

load_data easily loads example datasets by passing the data_name. For example, you can load the adult dataset as:

dm = load_data(data_name = 'adult')

Under the hood, load_data loads the default data_configs. To access these data_configs,

dm, data_configs = load_data(data_name = 'adult', return_config=True)

If you want to override some of the data configs, you can pass them via the data_configs argument. For example, if you want to use only 10% of the data, you can

dm = load_data(
    data_name = 'adult', data_configs={'sample_frac': 0.1}
)

Supported Datasets

load_data currently supports the following datasets:

Dataset               # Cont Features   # Cat Features   # of Data Points
adult                 2                 6                 32561
heloc                 21                2                 10459
oulad                 23                8                 32593
credit                20                3                 30000
cancer                30                0                 569
student_performance   2                 14                649
titanic               2                 24                891
german                7                 13                1000
spam                  57                0                 4601
ozone                 72                0                 2534
qsar                  38                3                 1055
bioresponse           1776              0                 3751
churn                 3                 16                7043
road                  29                3                 111762

The names of all supported datasets are the keys of DEFAULT_DATA_CONFIGS:

DEFAULT_DATA_CONFIGS.keys()
dict_keys(['adult', 'heloc', 'oulad', 'credit', 'cancer', 'student_performance', 'titanic', 'german', 'spam', 'ozone', 'qsar', 'bioresponse', 'churn', 'road'])
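
Any of these names can be passed directly to load_data:

for name in ['cancer', 'german']:
    dm = load_data(data_name=name)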