feats = FeaturesList([
Feature("age", np.random.normal(0, 1, (10, 1)),
transformation='minmax', is_immutable=True),
Feature("workclass", np.random.randint(0, 2, (10, 1)),
transformation='ohe'),
Feature("education", np.random.randint(0, 2, (10, 1)),
transformation='ordinal'),
])
labels = FeaturesList([
Feature("income", np.random.randint(0, 2, (10, 1)),
transformation='identity'),
])
df = features2pandas(feats, labels)
assert isinstance(df, pd.DataFrame)
assert df.shape == (10, 4)
Data Module
DataModule for training parametric models, generating and benchmarking CF explanations.
Data Module Interfaces
High-level interfaces for DataModule. Docs to be added.
relax.data_module.BaseDataModule
class relax.data_module.BaseDataModule (config, name=None)
DataModule Interface
Data Module
DataModule for processing data, training models, and benchmarking CF explanations.
Config
relax.data_module.DataModuleConfig
class relax.data_module.DataModuleConfig (data_dir=None, data_name=None, continous_cols=[], discret_cols=[], imutable_cols=[], continuous_transformation='minmax', discret_transformation='ohe', sample_frac=None, train_indices=[], test_indices=[])
Configurator of DataModule.
Parameters:
- data_dir (str) – The directory of the dataset.
- data_name (str) – The name of the DataModule.
- continous_cols (List[str], default=[]) – Continuous features/columns in the data.
- discret_cols (List[str], default=[]) – Categorical features/columns in the data.
- imutable_cols (List[str], default=[]) – Immutable features/columns in the data.
- continuous_transformation (Optional[str], default='minmax') – Transformation for continuous features. None indicates unknown.
- discret_transformation (Optional[str], default='ohe') – Transformation for categorical features. None indicates unknown.
- sample_frac (Optional[float]) – Fraction of the data to sample. Defaults to using the entire data.
- train_indices (List[int], default=[]) – Indices of training data.
- test_indices (List[int], default=[]) – Indices of testing data.
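The default transformations named in the config can be pictured with a small NumPy sketch: 'minmax' rescales a continuous column to [0, 1], while 'ohe' one-hot encodes a categorical column. The helper names below are illustrative only, not part of the relax API:

```python
import numpy as np

# Illustrative stand-ins for the 'minmax' and 'ohe' transformations.
def minmax(col):
    lo, hi = col.min(), col.max()
    return (col - lo) / (hi - lo)

def one_hot(col, n_categories):
    # Row i of the identity matrix is the one-hot vector for category i.
    return np.eye(n_categories)[col]

age = np.array([18.0, 30.0, 54.0])
workclass = np.array([0, 2, 1])
scaled = minmax(age)             # values now lie in [0, 1]
encoded = one_hot(workclass, 3)  # shape (3, 3), a single 1 per row
```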
Utils
Utility functions for DataModule.
relax.data_module.features2config
relax.data_module.features2config (features, name, return_dict=False)
Get DataModuleConfig from FeaturesList.
Parameters:
- features (<class 'relax.data_utils.features.FeaturesList'>) – FeaturesList to be converted
- name (<class 'str'>) – Name of the data used for DataModuleConfig
- return_dict (<class 'bool'>, default=False) – Whether to return a dict or a DataModuleConfig
Returns:
(typing.Union[__main__.DataModuleConfig, typing.Dict]) – Return configs
relax.data_module.features2pandas
relax.data_module.features2pandas (features, labels)
Convert FeaturesList to pandas dataframe.
Parameters:
- features (<class 'relax.data_utils.features.FeaturesList'>) – FeaturesList to be converted
- labels (<class 'relax.data_utils.features.FeaturesList'>) – Labels to be converted
Returns:
(<class 'pandas.core.frame.DataFrame'>) – Return pandas dataframe
Example:
relax.data_module.dataframe2labels
relax.data_module.dataframe2labels (data, config)
Convert pandas dataframe of labels to FeaturesList.
relax.data_module.dataframe2features
relax.data_module.dataframe2features (data, config)
Convert pandas dataframe of features to FeaturesList.
Main Data Module
Main module.
relax.data_module.DataModule
class relax.data_module.DataModule (features, label, config=None, data=None, **kwargs)
DataModule for tabular data.
Methods
load_from_path (path, config=None)
Load DataModule from a directory.
Parameters:
- path (<class 'str'>) – Path to the directory to load DataModule
- config (typing.Union[typing.Dict, __main__.DataModuleConfig], default=None) – Configs of DataModule. This argument is ignored.
Returns:
(<class '__main__.DataModule'>) – Initialized DataModule from path
from_config (config, data=None)
Create DataModule from configs and (optionally) data.
Parameters:
- config (typing.Union[typing.Dict, __main__.DataModuleConfig]) – Configs of DataModule
- data (<class 'pandas.core.frame.DataFrame'>, default=None) – Passed-in pd.DataFrame
Returns:
(<class '__main__.DataModule'>) – Initialized DataModule from configs and data
from_features (features, label, name=None)
Create DataModule from FeaturesList.
Parameters:
- features (<class 'relax.data_utils.features.FeaturesList'>) – Features of DataModule
- label (<class 'relax.data_utils.features.FeaturesList'>) – Labels of DataModule
- name (<class 'str'>, default=None) – Name of DataModule
Returns:
(<class '__main__.DataModule'>) – Initialized DataModule from features and labels
from_numpy (xs, ys, name=None, transformation='minmax')
Create DataModule from numpy arrays. Note that the xs are treated as continuous features.
Parameters:
- xs (<class 'numpy.ndarray'>) – Input data
- ys (<class 'numpy.ndarray'>) – Labels
- name (<class 'str'>, default=None) – Name of DataModule
- transformation (<class 'str'>, default=minmax) – Transformation applied to the features
Returns:
(<class '__main__.DataModule'>) – Initialized DataModule from numpy arrays
save (path)
Save DataModule to a directory.
Parameters:
- path (<class 'str'>) – Path to the directory to save DataModule
transform (data)
Transform data to jax.Array.
Parameters:
- data (typing.Union[pandas.core.frame.DataFrame, typing.Dict[str, jax.Array]]) – Data to be transformed
Returns:
(<class 'jax.Array'>) – Transformed data
inverse_transform (data, return_type='pandas')
Inverse transform data to pandas.DataFrame.
Parameters:
- data (<class 'jax.Array'>) – Data to be inverse transformed
- return_type (<class 'str'>, default=pandas) – Type of the returned data. Should be one of ['pandas', 'dict']
Returns:
(<class 'pandas.core.frame.DataFrame'>) – Inverse transformed data
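A round trip through transform and inverse_transform for a single min-max-scaled feature can be sketched as follows (plain NumPy, illustrative class name; the real DataModule applies such a transformation per feature and concatenates the results):

```python
import numpy as np

# Illustrative min-max transformation with an exact inverse.
class MinMaxSketch:
    def fit(self, col):
        self.lo, self.hi = col.min(), col.max()
        return self

    def transform(self, col):
        return (col - self.lo) / (self.hi - self.lo)

    def inverse_transform(self, scaled):
        return scaled * (self.hi - self.lo) + self.lo

raw = np.array([18.0, 30.0, 54.0])
t = MinMaxSketch().fit(raw)
z = t.transform(raw)                # values in [0, 1]
recovered = t.inverse_transform(z)  # matches the raw values
```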
apply_constraints (xs, cfs, hard=False, rng_key=None, **kwargs)
Apply constraints to counterfactuals.
Parameters:
- xs (<class 'jax.Array'>) – Input data
- cfs (<class 'jax.Array'>) – Counterfactuals to be constrained
- hard (<class 'bool'>, default=False) – Whether to apply hard constraints or not
- rng_key (jax.random.PRNGKey, default=None) – Random key
- kwargs (VAR_KEYWORD)
Returns:
(<class 'jax.Array'>) – Constrained counterfactuals
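The idea behind constraint application can be sketched in NumPy (an illustrative function, not the relax implementation; the real method also enforces categorical validity, e.g. for one-hot features):

```python
import numpy as np

# Sketch: keep counterfactuals in the valid scaled range and copy
# immutable features back from the original inputs.
def apply_constraints_sketch(xs, cfs, immutable_idx):
    cfs = np.clip(cfs, 0.0, 1.0)                  # stay within [0, 1]
    cfs[:, immutable_idx] = xs[:, immutable_idx]  # immutable features unchanged
    return cfs

xs = np.array([[0.2, 0.8], [0.5, 0.1]])
cfs = np.array([[1.4, 0.3], [-0.2, 0.9]])
out = apply_constraints_sketch(xs, cfs, immutable_idx=[0])
```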
compute_reg_loss (xs, cfs, hard=False)
Compute regularization loss.
Parameters:
- xs (<class 'jax.Array'>) – Input data
- cfs (<class 'jax.Array'>) – Counterfactuals to be constrained
- hard (<class 'bool'>, default=False) – Whether to apply hard constraints or not
Returns:
(<class 'float'>) – Regularization loss
set_transformations (feature_names_to_transformation)
Reset transformations for features.
Parameters:
- feature_names_to_transformation (typing.Dict[str, typing.Union[str, typing.Dict, relax.data_utils.transforms.BaseTransformation]]) – Dict[feature_name, Transformation]
Returns:
(<class '__main__.DataModule'>) – DataModule with updated transformations
sample (size, stage='train', key=None)
Sample data from DataModule.
Parameters:
- size (float | int) – Size of the sample. If float, should be 0<=size<=1.
- stage (<class 'str'>, default=train) – Stage of data to sample from. Should be one of ['train', 'valid', 'test']
- key (jax.random.PRNGKey, default=None) – Random key.
Returns:
(typing.Tuple[jax.Array, jax.Array]) – Sampled data
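The float-vs-int convention for size can be sketched like this (NumPy stand-in with a hypothetical function name; the real method draws from the chosen stage with a JAX random key):

```python
import numpy as np

# Sketch: a float size is a fraction of the data, an int is an absolute count.
def sample_sketch(xs, size, rng):
    n = int(size * len(xs)) if isinstance(size, float) else size
    idx = rng.choice(len(xs), size=n, replace=False)
    return xs[idx]

rng = np.random.default_rng(0)
xs = np.arange(100).reshape(100, 1)
frac_sample = sample_sketch(xs, 0.1, rng)  # 10% of 100 rows
count_sample = sample_sketch(xs, 25, rng)  # exactly 25 rows
```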
# Test initialization
config = DataModuleConfig.load_from_json("assets/adult/data/config.json")
config_1 = config.dict()
config_1.update({"imutable_cols": []})
dm = DataModule.from_config(config)
dm_1 = DataModule.from_config(config.dict())
assert dm_equals(dm, dm_1)
dm_2 = DataModule.from_path("assets/adult/data")
assert dm_equals(dm, dm_2)
dm_3 = DataModule.from_config(config_1)
assert dm_equals(dm, dm_3)
assert dm_3.config.imutable_cols == []
feats = FeaturesList.load_from_path("assets/adult/data/features")
label = FeaturesList.load_from_path("assets/adult/data/label")
dm_4 = DataModule.from_features(feats, label)
assert dm_equals(dm, dm_4, indices_equals=False)  # Indices are not supposed to be equal

# Test from_numpy
xs, ys = make_classification(n_samples=100, n_features=5, n_informative=3, random_state=0)
dm_5 = DataModule.from_numpy(xs, ys, name="test", transformation='identity')
config_5 = dm_5.config
assert dm_5.config.data_name == "test"
assert dm_5.data.shape == (100, 6)
assert np.allclose(dm_5.data.to_numpy(), np.concatenate([xs, ys.reshape(-1, 1)], axis=1))
assert np.allclose(
xs[config_5.train_indices],
dm_5['train'][0]
)
assert np.allclose(
xs[config_5.test_indices],
dm_5['test'][0]
)
dm_5.save('tmp/test')
dm_6 = DataModule.load_from_path('tmp/test')
assert dm_equals(dm_5, dm_6)
shutil.rmtree("tmp/test")

# Test save and load
dm.save("tmp/adult")
dm_5 = DataModule.load_from_path("tmp/adult")
assert dm_equals(dm, dm_5)
shutil.rmtree("tmp/adult")

# Test set_transformations
dm_6 = deepcopy(dm)
dm_6.set_transformations({"age": 'identity'})
assert dm_6.features['age'].transformation.name == 'identity'
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
dm_6.set_transformations({feat: 'ordinal' for feat in config.discret_cols})
assert dm_6.xs.shape == (dm.data.shape[0], len(config.continous_cols) + len(config.discret_cols))
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
test_fail(lambda: dm_6.set_transformations({1: 'identity'}), contains="Invalid idx type")
test_fail(lambda: dm_6.set_transformations({"❤": 'identity'}), contains="Invalid feature name")
test_fail(lambda: dm_6.set_transformations({"age": '❤'}), contains="Unknown transformation")
test_fail(lambda: dm_6.set_transformations('❤'), contains="Invalid feature_names_to_transformation type")
dm_6.set_transformations({"age": MinMaxTransformation()})
assert np.allclose(dm_6.xs[:, :1], dm.xs[:, :1])

# Test sample
sampled_xs, sampled_ys = dm.sample(0.1)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == int(0.1 * dm['train'][0].shape[0])
assert not jnp.all(sampled_xs == dm['train'][0][:sampled_xs.shape[0]])
sampled_xs, sampled_ys = dm.sample(100)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == 100
assert not jnp.all(sampled_xs == dm['train'][0][:100])
test_fail(lambda: dm.sample(1.1), contains='should be a floating number 0<=size<=1,')
test_fail(lambda: dm.sample('train'), contains='or an integer')
xs = dm['train'][0]
cfs = jrand.uniform(jrand.PRNGKey(0), shape=xs.shape, minval=0.01, maxval=0.99)
cfs = dm.apply_constraints(xs, cfs, hard=False)
assert cfs.shape == xs.shape
cfs = dm.apply_constraints(xs, cfs, hard=True)
assert cfs.shape == xs.shape

# Test transform
data = dm.transform(dm.data)
assert np.allclose(data, dm.xs)
Load Data
# from sklearn.datasets import make_classification
# xs, ys = make_classification(n_samples=1000, n_features=10)
# xs = pd.DataFrame(xs, columns=[f"col_{i}" for i in range(10)])
# ys = pd.DataFrame(ys, columns=['label'])
# data = pd.concat([xs, ys], axis=1)
# os.makedirs('assets/dummy/data', exist_ok=True)
# data.to_csv('assets/dummy/data/data.csv', index=False)
# config = DataModuleConfig(
# data_name="dummy",
# data_dir="assets/dummy/data/data.csv",
# continous_cols=[f"col_{i}" for i in range(10)]
# )
# dm = DataModule(config)
# dm.save('assets/dummy/data')

# for data_name in DEFAULT_DATA_CONFIGS.keys():
# print(f"Loading {data_name}...")
# shutil.rmtree(f'../relax-assets/{data_name}', ignore_errors=True)
# conf_path = DEFAULT_DATA_CONFIGS[data_name]['conf']
# config = load_json(conf_path)['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)
# dm.save(f'../relax-assets/{data_name}/data')

# for data_name in DEFAULT_DATA_CONFIGS.keys():
# print(f"Loading {data_name}...")
# DataModule.load_from_path(f'../relax-assets/{data_name}/data')

# config = load_json('assets/adult/configs.json')['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)
relax.data_module.load_data
relax.data_module.load_data (data_name, return_config=False, data_configs=None)
High-level util function for loading data and data_config.
Parameters:
- data_name (<class 'str'>) – The name of the data
- return_config (<class 'bool'>, default=False) – Deprecated
- data_configs (<class 'dict'>, default=None) – Data configs to override the default configuration
Returns:
(typing.Union[__main__.DataModule, typing.Tuple[__main__.DataModule, __main__.DataModuleConfig]]) – Return DataModule or (DataModule, DataModuleConfig)
relax.data_module.download_data_module_files
relax.data_module.download_data_module_files (data_name, data_parent_dir, download_original_data=False)
Parameters:
- data_name (<class 'str'>) – The name of the data
- data_parent_dir (<class 'pathlib.Path'>) – The directory to save data.
- download_original_data (<class 'bool'>, default=False) – Whether to download the original data or not
load_data easily loads example datasets by passing the data_name. For example, you can load the adult dataset as:
dm = load_data(data_name='adult')

Supported Datasets
load_data currently supports the following datasets:
| Dataset | # Cont Features | # Cat Features | # of Data Points |
|---|---|---|---|
| adult | 2 | 6 | 32561 |
| heloc | 21 | 2 | 10459 |
| oulad | 23 | 8 | 32593 |
| credit | 20 | 3 | 30000 |
| cancer | 30 | 0 | 569 |
| student_performance | 2 | 14 | 649 |
| titanic | 2 | 24 | 891 |
| german | 7 | 13 | 1000 |
| spam | 57 | 0 | 4601 |
| ozone | 72 | 0 | 2534 |
| qsar | 38 | 3 | 1055 |
| bioresponse | 1776 | 0 | 3751 |
| churn | 3 | 16 | 7043 |
| road | 29 | 3 | 111762 |
| dummy | 10 | 0 | 1000 |