Data Module

DataModule for training parametric models, and for generating and benchmarking counterfactual (CF) explanations.

Data Module Interfaces

High-level interfaces for DataModule. Docs to be added.

relax.data_module.BaseDataModule

[source]

class relax.data_module.BaseDataModule (config, name=None)

DataModule Interface

Data Module

DataModule for processing data, training models, and benchmarking CF explanations.

Config

relax.data_module.DataModuleConfig

[source]

class relax.data_module.DataModuleConfig (data_dir=None, data_name=None, continous_cols=[], discret_cols=[], imutable_cols=[], continuous_transformation='minmax', discret_transformation='ohe', sample_frac=None, train_indices=[], test_indices=[])

Configuration for DataModule.

Parameters:

  • data_dir (str) – Path to the dataset (e.g., a CSV file).
  • data_name (str) – The name of the DataModule.
  • continous_cols (List[str], default=[]) – Continuous features/columns in the data.
  • discret_cols (List[str], default=[]) – Categorical features/columns in the data.
  • imutable_cols (List[str], default=[]) – Immutable features/columns in the data.
  • continuous_transformation (Optional[str], default=minmax) – Transformation for continuous features. None indicates unknown.
  • discret_transformation (Optional[str], default=ohe) – Transformation for categorical features. None indicates unknown.
  • sample_frac (Optional[float]) – Fraction of the data to sample. Defaults to using the entire dataset.
  • train_indices (List[int], default=[]) – Indices of the training data.
  • test_indices (List[int], default=[]) – Indices of the testing data.
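For illustration, a minimal config might look like the following (a sketch; the column names are illustrative, and the data_dir path follows the asset layout used in the examples further down). Note that continous_cols, discret_cols, and imutable_cols are the library's actual parameter spellings:

from relax.data_module import DataModuleConfig

config = DataModuleConfig(
    data_name="adult",
    data_dir="assets/adult/data/data.csv",
    continous_cols=["age", "hours_per_week"],
    discret_cols=["workclass", "education"],
    imutable_cols=["age"],  # features a counterfactual must not change
    # continuous_transformation='minmax' and discret_transformation='ohe'
    # are the defaults, so they are omitted here.
)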

Utils

Utility functions for DataModule.

relax.data_module.features2config

[source]

relax.data_module.features2config (features, name, return_dict=False)

Get DataModuleConfig from FeaturesList.

Parameters:

  • features (FeaturesList) – FeaturesList to be converted
  • name (str) – Name of the data used for DataModuleConfig
  • return_dict (bool, default=False) – Whether to return a dict or a DataModuleConfig

Returns:

    (Union[DataModuleConfig, Dict]) – The resulting config
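Example (a sketch reusing the Feature/FeaturesList pattern from the features2pandas example below, and assuming Feature lives alongside FeaturesList in relax.data_utils.features; the feature names and data are illustrative):

import numpy as np
from relax.data_utils.features import Feature, FeaturesList
from relax.data_module import features2config, DataModuleConfig

feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)),
            transformation='minmax', is_immutable=True),
    Feature("workclass", np.random.randint(0, 2, (10, 1)),
            transformation='ohe'),
])
config = features2config(feats, name="dummy")
assert isinstance(config, DataModuleConfig)
# With return_dict=True the same information comes back as a plain dict.
config_dict = features2config(feats, name="dummy", return_dict=True)
assert isinstance(config_dict, dict)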

relax.data_module.features2pandas

[source]

relax.data_module.features2pandas (features, labels)

Convert FeaturesList to pandas dataframe.

Parameters:

  • features (FeaturesList) – FeaturesList to be converted
  • labels (FeaturesList) – Labels to be converted

Returns:

    (pd.DataFrame) – The converted pandas DataFrame

Example:

feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)), 
            transformation='minmax', is_immutable=True),
    Feature("workclass", np.random.randint(0, 2, (10, 1)), 
            transformation='ohe'),
    Feature("education", np.random.randint(0, 2, (10, 1)), 
            transformation='ordinal'),    
])
labels = FeaturesList([
    Feature("income", np.random.randint(0, 2, (10, 1)), 
            transformation='identity'),
])
df = features2pandas(feats, labels)
assert isinstance(df, pd.DataFrame)
assert df.shape == (10, 4)

relax.data_module.dataframe2labels

[source]

relax.data_module.dataframe2labels (data, config)

Convert pandas dataframe of labels to FeaturesList.

relax.data_module.dataframe2features

[source]

relax.data_module.dataframe2features (data, config)

Convert pandas dataframe of features to FeaturesList.
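As a sketch of the intended round trip (assuming both helpers select the relevant columns from a full dataframe based on the config, with the label being the column(s) not listed as features, and that the adult config's data_dir points at the raw CSV, as in the asset-generation code further down):

import pandas as pd
from relax.data_module import DataModuleConfig, dataframe2features, dataframe2labels

config = DataModuleConfig.load_from_json("assets/adult/data/config.json")
df = pd.read_csv(config.data_dir)
feats = dataframe2features(df, config)   # continous_cols + discret_cols
labels = dataframe2labels(df, config)    # remaining label column(s)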

Main Data Module

Main module.

relax.data_module.DataModule

[source]

class relax.data_module.DataModule (features, label, config=None, data=None, **kwargs)

DataModule for tabular data.

Methods

[source]

load_from_path (path, config=None)

Load DataModule from a directory.

Parameters:

  • path (str) – Path to the directory to load DataModule from
  • config (Union[Dict, DataModuleConfig], default=None) – Configs of DataModule. This argument is ignored.

Returns:

    (DataModule) – Initialized DataModule from path

[source]

from_config (config, data=None)

Create DataModule from a config and, optionally, a passed-in dataframe.

Parameters:

  • config (Union[Dict, DataModuleConfig]) – Configs of DataModule
  • data (pd.DataFrame, default=None) – Passed-in pd.DataFrame

Returns:

    (DataModule) – Initialized DataModule from configs and data

[source]

from_features (features, label, name=None)

Create DataModule from FeaturesList.

Parameters:

  • features (FeaturesList) – Features of DataModule
  • label (FeaturesList) – Labels of DataModule
  • name (str, default=None) – Name of DataModule

Returns:

    (DataModule) – Initialized DataModule from features and labels

[source]

from_numpy (xs, ys, name=None, transformation='minmax')

Create DataModule from numpy arrays. Note that the xs are treated as continuous features.

Parameters:

  • xs (np.ndarray) – Input data
  • ys (np.ndarray) – Labels
  • name (str, default=None) – Name of DataModule
  • transformation (str, default=minmax) – Transformation to apply to the input features

Returns:

    (DataModule) – Initialized DataModule from numpy arrays
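A brief usage sketch, mirroring the tests further down (make_classification comes from scikit-learn):

import numpy as np
from sklearn.datasets import make_classification
from relax.data_module import DataModule

xs, ys = make_classification(n_samples=100, n_features=5, random_state=0)
dm = DataModule.from_numpy(xs, ys, name="synthetic")  # 'minmax' scaling by default
train_xs, train_ys = dm['train']  # index by stage, as in the tests below
assert train_xs.shape[1] == 5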

[source]

save (path)

Save DataModule to a directory.

Parameters:

  • path (str) – Path to the directory to save DataModule

[source]

transform (data)

Transform data to jax.Array.

Parameters:

  • data (Union[pd.DataFrame, Dict[str, jax.Array]]) – Data to be transformed

Returns:

    (jax.Array) – Transformed data

[source]

inverse_transform (data, return_type='pandas')

Inverse transform data to pandas.DataFrame.

Parameters:

  • data (jax.Array) – Data to be inverse transformed
  • return_type (str, default=pandas) – Type of the returned data. Should be one of ['pandas', 'dict']

Returns:

    (pd.DataFrame) – Inverse-transformed data
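transform and inverse_transform are intended as inverses of each other (up to the fitted transformations), so a round trip should approximately recover the original dataframe. A sketch, where dm is a DataModule as constructed above:

import pandas as pd

arr = dm.transform(dm.data)       # pd.DataFrame -> jax.Array
df = dm.inverse_transform(arr)    # jax.Array -> pd.DataFrame
assert isinstance(df, pd.DataFrame)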

[source]

apply_constraints (xs, cfs, hard=False, rng_key=None, **kwargs)

Apply constraints to counterfactuals.

Parameters:

  • xs (jax.Array) – Input data
  • cfs (jax.Array) – Counterfactuals to be constrained
  • hard (bool, default=False) – Whether to apply hard constraints or not
  • rng_key (jax.random.PRNGKey, default=None) – Random key
  • **kwargs

Returns:

    (jax.Array) – Constrained counterfactuals

[source]

compute_reg_loss (xs, cfs, hard=False)

Compute regularization loss.

Parameters:

  • xs (jax.Array) – Input data
  • cfs (jax.Array) – Counterfactuals to be constrained
  • hard (bool, default=False) – Whether to apply hard constraints or not

Returns:

    (float) – Regularization loss

[source]

set_transformations (feature_names_to_transformation)

Reset transformations for features.

Parameters:

  • feature_names_to_transformation (Dict[str, Union[str, Dict, BaseTransformation]]) – Mapping of feature name to transformation

Returns:

    (DataModule) – DataModule with updated transformations

[source]

sample (size, stage='train', key=None)

Sample data from DataModule.

Parameters:

  • size (float | int) – Size of the sample. If float, should be 0<=size<=1.
  • stage (str, default=train) – Stage of data to sample from. Should be one of ['train', 'valid', 'test']
  • key (jax.random.PRNGKey, default=None) – Random key.

Returns:

    (Tuple[jax.Array, jax.Array]) – Sampled data

# Test initialization
config = DataModuleConfig.load_from_json("assets/adult/data/config.json")
config_1 = config.dict()
config_1.update({"imutable_cols": []})
dm = DataModule.from_config(config)
dm_1 = DataModule.from_config(config.dict())
assert dm_equals(dm, dm_1)
dm_2 = DataModule.load_from_path("assets/adult/data")
assert dm_equals(dm, dm_2)
dm_3 = DataModule.from_config(config_1)
assert dm_equals(dm, dm_3)
assert dm_3.config.imutable_cols == []
feats = FeaturesList.load_from_path("assets/adult/data/features")
label = FeaturesList.load_from_path("assets/adult/data/label")
dm_4 = DataModule.from_features(feats, label)
assert dm_equals(dm, dm_4)
# Test from_numpy
xs, ys = make_classification(n_samples=100, n_features=5, n_informative=3, random_state=0)
dm_5 = DataModule.from_numpy(xs, ys, name="test", transformation='identity')
config_5 = dm_5.config
assert dm_5.config.data_name == "test"
assert dm_5.data.shape == (100, 6)
assert np.allclose(dm_5.data.to_numpy(), np.concatenate([xs, ys.reshape(-1, 1)], axis=1))
assert np.allclose(
    xs[config_5.train_indices],
    dm_5['train'][0]
)
assert np.allclose(
    xs[config_5.test_indices],
    dm_5['test'][0]
)
dm_5.save('tmp/test')
dm_6 = DataModule.load_from_path('tmp/test')
assert dm_equals(dm_5, dm_6)
shutil.rmtree("tmp/test")
# Test save and load
dm.save("tmp/adult")
dm_5 = DataModule.load_from_path("tmp/adult")
assert dm_equals(dm, dm_5)
shutil.rmtree("tmp/adult")
# Test set_transformations
dm_6 = deepcopy(dm)
dm_6.set_transformations({"age": 'identity'})
assert dm_6.features['age'].transformation.name == 'identity'
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
dm_6.set_transformations({feat: 'ordinal' for feat in config.discret_cols})
assert dm_6.xs.shape == (dm.data.shape[0], len(config.continous_cols) + len(config.discret_cols))

assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())

test_fail(lambda: dm_6.set_transformations({1: 'identity'}), contains="Invalid idx type")
test_fail(lambda: dm_6.set_transformations({"❤": 'identity'}), contains="Invalid feature name")
test_fail(lambda: dm_6.set_transformations({"age": '❤'}), contains="Unknown transformation")
test_fail(lambda: dm_6.set_transformations('❤'), contains="Invalid feature_names_to_transformation type")

dm_6.set_transformations({"age": MinMaxTransformation()})
assert np.allclose(dm_6.xs[:, :1], dm.xs[:, :1])
# Test sample
sampled_xs, sampled_ys = dm.sample(0.1)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == int(0.1 * dm['train'][0].shape[0])
assert not jnp.all(sampled_xs == dm['train'][0][:sampled_xs.shape[0]])

sampled_xs, sampled_ys = dm.sample(100)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == 100
assert not jnp.all(sampled_xs == dm['train'][0][:100])

test_fail(lambda: dm.sample(1.1), contains='should be a floating number 0<=size<=1,')
test_fail(lambda: dm.sample('train'), contains='or an integer')

xs = dm['train'][0]
cfs = jrand.uniform(jrand.PRNGKey(0), shape=xs.shape, minval=0.01, maxval=0.99)
cfs = dm.apply_constraints(xs, cfs, hard=False)
assert cfs.shape == xs.shape

cfs = dm.apply_constraints(xs, cfs, hard=True)
assert cfs.shape == xs.shape
# Test transform
data = dm.transform(dm.data)
assert np.allclose(data, dm.xs)

Load Data

# from sklearn.datasets import make_classification

# xs, ys = make_classification(n_samples=1000, n_features=10)
# xs = pd.DataFrame(xs, columns=[f"col_{i}" for i in range(10)])
# ys = pd.DataFrame(ys, columns=['label'])
# data = pd.concat([xs, ys], axis=1)
# os.makedirs('assets/dummy/data', exist_ok=True)
# data.to_csv('assets/dummy/data/data.csv', index=False)
# config = DataModuleConfig(
#     data_name="dummy", 
#     data_dir="assets/dummy/data/data.csv", 
#     continous_cols=[f"col_{i}" for i in range(10)]
# )
# dm = DataModule(config)
# dm.save('assets/dummy/data')
# for data_name in DEFAULT_DATA_CONFIGS.keys():
#     print(f"Loading {data_name}...")
#     shutil.rmtree(f'../relax-assets/{data_name}', ignore_errors=True)
#     conf_path = DEFAULT_DATA_CONFIGS[data_name]['conf']
#     config = load_json(conf_path)['data_configs']
#     dm_config = DataModuleConfig(**config)
#     dm = DataModule(dm_config)
#     dm.save(f'../relax-assets/{data_name}/data')
# for data_name in DEFAULT_DATA_CONFIGS.keys():
#     print(f"Loading {data_name}...")
#     DataModule.load_from_path(f'../relax-assets/{data_name}/data')
# config = load_json('assets/adult/configs.json')['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)

relax.data_module.load_data

[source]

relax.data_module.load_data (data_name, return_config=False, data_configs=None)

High-level utility function for loading data and its config.

Parameters:

  • data_name (str) – The name of data
  • return_config (bool, default=False) – Deprecated
  • data_configs (dict, default=None) – Data configs to override the default configuration

Returns:

    (Union[DataModule, Tuple[DataModule, DataModuleConfig]]) – Return DataModule or (DataModule, DataModuleConfig)

relax.data_module.download_data_module_files

[source]

relax.data_module.download_data_module_files (data_name, data_parent_dir, download_original_data=False)

Download the stored DataModule files for data_name into data_parent_dir.

Parameters:

  • data_name (str) – The name of data
  • data_parent_dir (pathlib.Path) – The directory to save data.
  • download_original_data (bool, default=False) – Whether to download the original data

load_data easily loads example datasets by passing the data_name. For example, you can load the adult dataset as:

dm = load_data(data_name='adult')
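data_configs can be used to override fields of the default DataModuleConfig (see its parameters above). A sketch, assuming overrides such as sample_frac are honored at load time:

dm = load_data('adult', data_configs={'sample_frac': 0.1})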

Supported Datasets

load_data currently supports the following datasets:

Dataset                # Cont Features   # Cat Features   # of Data Points
adult                  2                 6                32561
heloc                  21                2                10459
oulad                  23                8                32593
credit                 20                3                30000
cancer                 30                0                569
student_performance    2                 14               649
titanic                2                 24               891
german                 7                 13               1000
spam                   57                0                4601
ozone                  72                0                2534
qsar                   38                3                1055
bioresponse            1776              0                3751
churn                  3                 16               7043
road                   29                3                111762
dummy                  10                0                1000