Data Module
DataModule for training parametric models, generating and benchmarking CF explanations.
Data Module Interfaces
High-level interfaces for DataModule. Docs to be added.
relax.data_module.BaseDataModule
class relax.data_module.BaseDataModule (config, name=None)
DataModule Interface
Data Module
DataModule for processing data, training models, and benchmarking CF explanations.
Config
relax.data_module.DataModuleConfig
class relax.data_module.DataModuleConfig (data_dir=None, data_name=None, continous_cols=[], discret_cols=[], imutable_cols=[], continuous_transformation='minmax', discret_transformation='ohe', sample_frac=None, train_indices=[], test_indices=[])
Configurator of DataModule.
Parameters:
- data_dir (str) – The directory of dataset.
- data_name (str) – The name of DataModule.
- continous_cols (List[str], default=[]) – Continuous features/columns in the data.
- discret_cols (List[str], default=[]) – Categorical features/columns in the data.
- imutable_cols (List[str], default=[]) – Immutable features/columns in the data.
- continuous_transformation (Optional[str], default='minmax') – Transformation for continuous features. None indicates unknown.
- discret_transformation (Optional[str], default='ohe') – Transformation for categorical features. None indicates unknown.
- sample_frac (Optional[float]) – Sample fraction of the data. Defaults to using the entire data.
- train_indices (List[int], default=[]) – Indices of training data.
- test_indices (List[int], default=[]) – Indices of testing data.
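For instance, a config for an adult-like tabular dataset might look as follows. This is a minimal sketch: the path and column names are illustrative, not taken from a bundled dataset, and note that continous_cols, discret_cols, and imutable_cols are spelled exactly as in the signature above.
from relax.data_module import DataModuleConfig

config = DataModuleConfig(
    data_dir="assets/adult/data/data.csv",      # illustrative path
    data_name="adult",
    continous_cols=["age", "hours_per_week"],   # illustrative column names
    discret_cols=["workclass", "education"],
    imutable_cols=["age"],
    continuous_transformation="minmax",
    discret_transformation="ohe",
)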
Utils
Utility functions for DataModule.
relax.data_module.features2config
relax.data_module.features2config (features, name, return_dict=False)
Get DataModuleConfig from FeaturesList.
Parameters:
- features (FeaturesList) – FeaturesList to be converted.
- name (str) – Name of the data used for DataModuleConfig.
- return_dict (bool, default=False) – Whether to return a dict or DataModuleConfig.
Returns:
(Union[DataModuleConfig, Dict]) – Return configs.
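A short usage sketch, assuming Feature and FeaturesList are importable from relax.data_utils.features as their class paths elsewhere on this page suggest:
import numpy as np
from relax.data_module import features2config
from relax.data_utils.features import Feature, FeaturesList

feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)), transformation='minmax'),
])
config = features2config(feats, name="toy")                         # DataModuleConfig
config_dict = features2config(feats, name="toy", return_dict=True)  # plain dict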
relax.data_module.features2pandas
relax.data_module.features2pandas (features, labels)
Convert FeaturesList to pandas dataframe.
Parameters:
- features (FeaturesList) – FeaturesList to be converted.
- labels (FeaturesList) – Labels to be converted.
Returns:
(pd.DataFrame) – Return pandas dataframe.
Example:
feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)),
            transformation='minmax', is_immutable=True),
    Feature("workclass", np.random.randint(0, 2, (10, 1)),
            transformation='ohe'),
    Feature("education", np.random.randint(0, 2, (10, 1)),
            transformation='ordinal'),
])
labels = FeaturesList([
    Feature("income", np.random.randint(0, 2, (10, 1)),
            transformation='identity'),
])
df = features2pandas(feats, labels)
assert isinstance(df, pd.DataFrame)
assert df.shape == (10, 4)
relax.data_module.dataframe2labels
relax.data_module.dataframe2labels (data, config)
Convert pandas dataframe of labels to FeaturesList.
relax.data_module.dataframe2features
relax.data_module.dataframe2features (data, config)
Convert pandas dataframe of features to FeaturesList.
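A hedged round-trip sketch of the two converters. It assumes df is the dataframe produced by features2pandas above and config is a matching DataModuleConfig (e.g. from features2config); whether each converter takes the full dataframe or only its own columns is an assumption here, not confirmed by the signatures.
# Assumed: `df` and `config` describe the same columns.
feats = dataframe2features(df, config)
labels = dataframe2labels(df, config)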
Main Data Module
Main module.
relax.data_module.DataModule
class relax.data_module.DataModule (features, label, config=None, data=None, **kwargs)
DataModule for tabular data.
Methods
load_from_path (path, config=None)
Load DataModule from a directory.
Parameters:
- path (str) – Path to the directory to load DataModule.
- config (Union[Dict, DataModuleConfig], default=None) – Configs of DataModule. This argument is ignored.
Returns:
(DataModule) – Initialized DataModule from path.
from_config (config, data=None)
Parameters:
- config (Union[Dict, DataModuleConfig]) – Configs of DataModule.
- data (pd.DataFrame, default=None) – Passed-in pd.DataFrame.
Returns:
(DataModule) – Initialized DataModule from configs and data.
from_features (features, label, name=None)
Create DataModule from FeaturesList.
Parameters:
- features (FeaturesList) – Features of DataModule.
- label (FeaturesList) – Labels of DataModule.
- name (str, default=None) – Name of DataModule.
Returns:
(DataModule) – Initialized DataModule from features and labels.
from_numpy (xs, ys, name=None, transformation='minmax')
Create DataModule from numpy arrays. Note that the xs are treated as continuous features.
Parameters:
- xs (np.ndarray) – Input data.
- ys (np.ndarray) – Labels.
- name (str, default=None) – Name of DataModule.
- transformation (str, default='minmax') – Transformation applied to the (continuous) features.
Returns:
(DataModule) – Initialized DataModule from numpy arrays.
save (path)
Save DataModule to a directory.
Parameters:
- path (str) – Path to the directory to save DataModule.
transform (data)
Transform data to jax.Array.
Parameters:
- data (Union[pd.DataFrame, Dict[str, jax.Array]]) – Data to be transformed.
Returns:
(jax.Array) – Transformed data.
inverse_transform (data, return_type='pandas')
Inverse transform data to pandas.DataFrame.
Parameters:
- data (jax.Array) – Data to be inverse transformed.
- return_type (str, default='pandas') – Type of the returned data. Should be one of ['pandas', 'dict'].
Returns:
(pd.DataFrame) – Inverse transformed data.
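For example, transform and inverse_transform round-trip the module's own data. A sketch, assuming dm is an initialized DataModule:
arr = dm.transform(dm.data)                        # pandas -> jax.Array
df_back = dm.inverse_transform(arr)                # jax.Array -> pandas.DataFrame
as_dict = dm.inverse_transform(arr, return_type='dict')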
apply_constraints (xs, cfs, hard=False, rng_key=None, **kwargs)
Apply constraints to counterfactuals.
Parameters:
- xs (jax.Array) – Input data.
- cfs (jax.Array) – Counterfactuals to be constrained.
- hard (bool, default=False) – Whether to apply hard constraints or not.
- rng_key (jax.random.PRNGKey, default=None) – Random key.
- kwargs
Returns:
(jax.Array) – Constrained counterfactuals.
compute_reg_loss (xs, cfs, hard=False)
Compute regularization loss.
Parameters:
- xs (jax.Array) – Input data.
- cfs (jax.Array) – Counterfactuals to be constrained.
- hard (bool, default=False) – Whether to apply hard constraints or not.
Returns:
(float)
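A minimal sketch of how this might be called during counterfactual optimization; the perturbed inputs below are a hypothetical stand-in for generated counterfactuals:
xs = dm['train'][0]
cfs = xs + 0.05                       # hypothetical stand-in for generated CFs
reg = dm.compute_reg_loss(xs, cfs)    # scalar regularization penalty (float)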
set_transformations (feature_names_to_transformation)
Reset transformations for features.
Parameters:
- feature_names_to_transformation (Dict[str, Union[str, Dict, BaseTransformation]]) – Dict[feature_name, Transformation].
Returns:
(DataModule)
sample (size, stage='train', key=None)
Sample data from DataModule.
Parameters:
- size (float | int) – Size of the sample. If float, should be 0<=size<=1.
- stage (str, default='train') – Stage of data to sample from. Should be one of ['train', 'valid', 'test'].
- key (jax.random.PRNGKey, default=None) – Random key.
Returns:
(Tuple[jax.Array, jax.Array]) – Sampled data.
# Test initialization
config = DataModuleConfig.load_from_json("assets/adult/data/config.json")
config_1 = config.dict()
config_1.update({"imutable_cols": []})
dm = DataModule.from_config(config)
dm_1 = DataModule.from_config(config.dict())
assert dm_equals(dm, dm_1)
dm_2 = DataModule.from_path("assets/adult/data")
assert dm_equals(dm, dm_2)
dm_3 = DataModule.from_config(config_1)
assert dm_equals(dm, dm_3)
assert dm_3.config.imutable_cols == []
feats = FeaturesList.load_from_path("assets/adult/data/features")
label = FeaturesList.load_from_path("assets/adult/data/label")
dm_4 = DataModule.from_features(feats, label)
assert dm_equals(dm, dm_4)
# Test from_numpy
xs, ys = make_classification(n_samples=100, n_features=5, n_informative=3, random_state=0)
dm_5 = DataModule.from_numpy(xs, ys, name="test", transformation='identity')
config_5 = dm_5.config
assert dm_5.config.data_name == "test"
assert dm_5.data.shape == (100, 6)
assert np.allclose(dm_5.data.to_numpy(), np.concatenate([xs, ys.reshape(-1, 1)], axis=1))
assert np.allclose(
    xs[config_5.train_indices],
    dm_5['train'][0]
)
assert np.allclose(
    xs[config_5.test_indices],
    dm_5['test'][0]
)
dm_5.save('tmp/test')
dm_6 = DataModule.load_from_path('tmp/test')
assert dm_equals(dm_5, dm_6)
shutil.rmtree("tmp/test")
# Test save and load
dm.save("tmp/adult")
dm_5 = DataModule.load_from_path("tmp/adult")
assert dm_equals(dm, dm_5)
shutil.rmtree("tmp/adult")
# Test set_transformations
dm_6 = deepcopy(dm)
dm_6.set_transformations({"age": 'identity'})
assert dm_6.features['age'].transformation.name == 'identity'
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
dm_6.set_transformations({feat: 'ordinal' for feat in config.discret_cols})
assert dm_6.xs.shape == (dm.data.shape[0], len(config.continous_cols) + len(config.discret_cols))
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
test_fail(lambda: dm_6.set_transformations({1: 'identity'}), contains="Invalid idx type")
test_fail(lambda: dm_6.set_transformations({"❤": 'identity'}), contains="Invalid feature name")
test_fail(lambda: dm_6.set_transformations({"age": '❤'}), contains="Unknown transformation")
test_fail(lambda: dm_6.set_transformations('❤'), contains="Invalid feature_names_to_transformation type")
dm_6.set_transformations({"age": MinMaxTransformation()})
assert np.allclose(dm_6.xs[:, :1], dm.xs[:, :1])
# Test sample
sampled_xs, sampled_ys = dm.sample(0.1)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == int(0.1 * dm['train'][0].shape[0])
assert not jnp.all(sampled_xs == dm['train'][0][:sampled_xs.shape[0]])
sampled_xs, sampled_ys = dm.sample(100)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == 100
assert not jnp.all(sampled_xs == dm['train'][0][:100])
test_fail(lambda: dm.sample(1.1), contains='should be a floating number 0<=size<=1,')
test_fail(lambda: dm.sample('train'), contains='or an integer')
xs = dm['train'][0]
cfs = jrand.uniform(jrand.PRNGKey(0), shape=xs.shape, minval=0.01, maxval=0.99)
cfs = dm.apply_constraints(xs, cfs, hard=False)
assert cfs.shape == xs.shape
cfs = dm.apply_constraints(xs, cfs, hard=True)
assert cfs.shape == xs.shape
# Test transform
data = dm.transform(dm.data)
assert np.allclose(data, dm.xs)
Load Data
# from sklearn.datasets import make_classification
# xs, ys = make_classification(n_samples=1000, n_features=10)
# xs = pd.DataFrame(xs, columns=[f"col_{i}" for i in range(10)])
# ys = pd.DataFrame(ys, columns=['label'])
# data = pd.concat([xs, ys], axis=1)
# os.makedirs('assets/dummy/data', exist_ok=True)
# data.to_csv('assets/dummy/data/data.csv', index=False)
# config = DataModuleConfig(
# data_name="dummy",
# data_dir="assets/dummy/data/data.csv",
# continous_cols=[f"col_{i}" for i in range(10)]
# )
# dm = DataModule(config)
# dm.save('assets/dummy/data')
# for data_name in DEFAULT_DATA_CONFIGS.keys():
# print(f"Loading {data_name}...")
# shutil.rmtree(f'../relax-assets/{data_name}', ignore_errors=True)
# conf_path = DEFAULT_DATA_CONFIGS[data_name]['conf']
# config = load_json(conf_path)['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)
# dm.save(f'../relax-assets/{data_name}/data')
# for data_name in DEFAULT_DATA_CONFIGS.keys():
# print(f"Loading {data_name}...")
# DataModule.load_from_path(f'../relax-assets/{data_name}/data')
# config = load_json('assets/adult/configs.json')['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)
relax.data_module.load_data
relax.data_module.load_data (data_name, return_config=False, data_configs=None)
High-level util function for loading data and data_config.
Parameters:
- data_name (str) – The name of data.
- return_config (bool, default=False) – Deprecated.
- data_configs (dict, default=None) – Data configs to override default configuration.
Returns:
(Union[DataModule, Tuple[DataModule, DataModuleConfig]]) – Return DataModule or (DataModule, DataModuleConfig).
relax.data_module.download_data_module_files
relax.data_module.download_data_module_files (data_name, data_parent_dir, download_original_data=False)
Parameters:
- data_name (str) – The name of data.
- data_parent_dir (pathlib.Path) – The directory to save data.
- download_original_data (bool, default=False) – Download original data or not.
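A usage sketch; the target directory name here is illustrative:
from pathlib import Path
from relax.data_module import download_data_module_files

download_data_module_files('adult', Path('./relax-assets'))  # illustrative directory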
load_data easily loads example datasets by passing the data_name. For example, you can load the adult dataset as:
dm = load_data(data_name='adult')
Supported Datasets
load_data currently supports the following datasets:
| Dataset | # Cont Features | # Cat Features | # of Data Points |
|---|---|---|---|
| adult | 2 | 6 | 32561 |
| heloc | 21 | 2 | 10459 |
| oulad | 23 | 8 | 32593 |
| credit | 20 | 3 | 30000 |
| cancer | 30 | 0 | 569 |
| student_performance | 2 | 14 | 649 |
| titanic | 2 | 24 | 891 |
| german | 7 | 13 | 1000 |
| spam | 57 | 0 | 4601 |
| ozone | 72 | 0 | 2534 |
| qsar | 38 | 3 | 1055 |
| bioresponse | 1776 | 0 | 3751 |
| churn | 3 | 16 | 7043 |
| road | 29 | 3 | 111762 |
| dummy | 10 | 0 | 1000 |