import jax
import haiku as hk
from jax import random
from relax.legacy.utils import make_hk_module
from relax.legacy.module import PredictiveModel
Networks
Networks are haiku.Module subclasses, which define model architectures.
relax.legacy.module.BaseNetwork
class relax.legacy.module.BaseNetwork ()
BaseNetwork requires an is_training argument.
relax.legacy.module.DenseBlock
class relax.legacy.module.DenseBlock (output_size, dropout_rate=0.3, name=None)
A DenseBlock consists of a dense layer, followed by a Leaky ReLU activation and a dropout layer.
Parameters:
- output_size (int) – Output dimensionality.
- dropout_rate (float, default=0.3) – Dropout rate.
- name (str | None, default=None) – Name of the module.
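To make the behavior concrete, here is a minimal sketch of what a DenseBlock forward pass could look like in Haiku. This is not the library source: the name DenseBlockSketch and the exact dropout wiring are assumptions, but it illustrates the dense → Leaky ReLU → dropout pattern and the is_training convention noted for BaseNetwork above.

import haiku as hk
import jax

class DenseBlockSketch(hk.Module):
    """Sketch of a DenseBlock: dense layer, Leaky ReLU, dropout (training only)."""

    def __init__(self, output_size: int, dropout_rate: float = 0.3, name: str | None = None):
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self, x, is_training: bool = True):
        x = hk.Linear(self.output_size)(x)  # dense layer
        x = jax.nn.leaky_relu(x)            # Leaky ReLU activation
        if is_training:
            # drop activations only during training
            x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
        return x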
relax.ml_model.MLP
class relax.ml_model.MLP (sizes, dropout_rate=0.3, name=None)
An MLP consists of a list of DenseBlock layers.
Parameters:
- sizes (Iterable[int]) – Sequence of layer sizes.
- dropout_rate (float, default=0.3) – Dropout rate.
- name (str | None, default=None) – Name of the module.
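Continuing the sketch, an MLP plausibly chains one DenseBlock per entry in sizes. MLPSketch below is hypothetical and reuses DenseBlockSketch from above.

class MLPSketch(hk.Module):
    """Sketch of MLP: one DenseBlock per entry in `sizes`."""

    def __init__(self, sizes, dropout_rate: float = 0.3, name: str | None = None):
        super().__init__(name=name)
        self.sizes = sizes
        self.dropout_rate = dropout_rate

    def __call__(self, x, is_training: bool = True):
        for size in self.sizes:
            x = DenseBlockSketch(size, self.dropout_rate)(x, is_training)
        return x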
Predictive Model
relax.legacy.module.PredictiveModel
class relax.legacy.module.PredictiveModel (sizes, dropout_rate=0.3, name=None)
A basic predictive model for binary classification.
Parameters:
- sizes (List[int]) – Sequence of layer sizes.
- dropout_rate (float, default=0.3) – Dropout rate.
- name (Optional[str], default=None) – Name of the module.
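The parameter tree printed later on this page (three dense blocks followed by a single-output linear layer) suggests a structure like the sketch below; the final sigmoid for binary classification is an assumption.

class PredictiveModelSketch(hk.Module):
    """Sketch of PredictiveModel: MLP features plus a 1-unit linear head."""

    def __init__(self, sizes, dropout_rate: float = 0.3, name: str | None = None):
        super().__init__(name=name)
        self.sizes = sizes
        self.dropout_rate = dropout_rate

    def __call__(self, x, is_training: bool = True):
        x = MLPSketch(self.sizes, self.dropout_rate)(x, is_training)
        x = hk.Linear(1)(x)       # single output unit, matching the (10, 1) head below
        return jax.nn.sigmoid(x)  # assumption: probability output for binary classification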
Use make_hk_module to create a haiku.Transformed model.
net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)
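Assuming make_hk_module is a thin wrapper around hk.transform (its body is not shown in these docs), the call above is roughly equivalent to:

net = hk.transform(
    lambda x, is_training: PredictiveModel(
        sizes=[50, 20, 10], dropout_rate=0.3
    )(x, is_training=is_training)
)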
We generate some random data.
key = hk.PRNGSequence(42)
xs = random.normal(next(key), (1000, 10))
We can then initialize the model.
params = net.init(next(key), xs, is_training=True)
We can view the model's structure via jax.tree_map.

jax.tree_map(lambda x: x.shape, params)
{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}
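As a quick sanity check, the same tree can be reduced to a total parameter count with standard JAX utilities; the arithmetic below follows directly from the shapes printed above.

n_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
# (10*50 + 50) + (50*20 + 20) + (20*10 + 10) + (10*1 + 1) = 1791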
Model output is produced via the apply function.
y = net.apply(params, next(key), xs, is_training=True)
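At inference time, is_training=False presumably disables dropout (an assumption based on the is_training convention above):

y_eval = net.apply(params, next(key), xs, is_training=False)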
For more on haiku.Module usage, please refer to the Haiku documentation.
Training Modules API
relax.legacy.module.BaseTrainingModule
class relax.legacy.module.BaseTrainingModule ()
Abstract base class from which training modules inherit.
Predictive Training Module
relax.legacy.module.PredictiveTrainingModuleConfigs
class relax.legacy.module.PredictiveTrainingModuleConfigs (lr, sizes, dropout_rate=0.3)
Configurator of PredictiveTrainingModule.
Parameters:
- lr (float) – Learning rate.
- sizes (List[int]) – Sequence of layer sizes.
- dropout_rate (float, default=0.3) – Dropout rate.
relax.legacy.module.PredictiveTrainingModule
class relax.legacy.module.PredictiveTrainingModule (m_configs)
A training module for predictive models.
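Putting the pieces together, construction follows the two signatures above; the learning rate here is an arbitrary example value, and the subsequent training step (e.g. via train_model) is not shown:

configs = PredictiveTrainingModuleConfigs(lr=0.001, sizes=[50, 20, 10], dropout_rate=0.3)  # lr is an example value
training_module = PredictiveTrainingModule(configs)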