Module
Networks and training modules used in train_model.
Networks
Networks are haiku.Module subclasses that define model architectures.
BASENETWORK
CLASS relax.module.BaseNetwork ()
BaseNetwork requires an is_training argument (see the DenseBlock sketch below for how such a flag is typically used).
DENSEBLOCK
CLASS relax.module.DenseBlock (output_size, dropout_rate=0.3, name=None)
A DenseBlock consists of a dense layer, followed by a Leaky ReLU activation and a dropout layer.
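For illustration, a block with this structure could be written in Haiku roughly as follows. This is a sketch of the computation, not the relax source; it assumes the is_training flag described above for BaseNetwork.
import haiku as hk
import jax

class DenseBlockSketch(hk.Module):
    """Illustrative only: linear layer -> Leaky ReLU -> dropout."""
    def __init__(self, output_size, dropout_rate=0.3, name=None):
        super().__init__(name=name)
        self.output_size = output_size
        self.dropout_rate = dropout_rate

    def __call__(self, x, is_training):
        x = hk.Linear(self.output_size)(x)
        x = jax.nn.leaky_relu(x)
        # Dropout is only active at training time.
        rate = self.dropout_rate if is_training else 0.0
        return hk.dropout(hk.next_rng_key(), rate, x)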
MLP
CLASS relax.module.MLP (sizes, dropout_rate=0.3, name=None)
An MLP consists of a list of DenseBlock layers.
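For example, MLP(sizes=[50, 20, 10]) stacks three DenseBlocks with output sizes 50, 20, and 10. Below is a brief usage sketch following the same make_hk_module pattern shown for PredictiveModel further down; it assumes the is_training keyword is forwarded exactly as in that example.
import haiku as hk
from jax import random

from relax.module import MLP
from relax.utils import make_hk_module

mlp_net = make_hk_module(MLP, sizes=[50, 20, 10], dropout_rate=0.3)
key = hk.PRNGSequence(0)
x = random.normal(next(key), (8, 5))
params = mlp_net.init(next(key), x, is_training=True)
out = mlp_net.apply(params, next(key), x, is_training=True)  # last DenseBlock has output_size 10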
Predictive Model
PREDICTIVEMODEL
CLASS relax.module.PredictiveModel (sizes, dropout_rate=0.3, name=None)
A basic predictive model for binary classification.
Use make_hk_module to create a haiku.Transformed model.
import haiku as hk
import jax
from jax import random

from relax.module import PredictiveModel
from relax.utils import make_hk_module

net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)
We make some random data.
key = hk.PRNGSequence(42)
xs = random.normal(next(key), (1000, 10))
We can then initialize the model.
params = net.init(next(key), xs, is_training=True)
We can view the model's structure via jax.tree_map.
jax.tree_map(lambda x: x.shape, params)
{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}
Model output is produced via the apply function.
y = net.apply(params, next(key), xs, is_training=True)
For more usage of haiku.Module, please refer to the Haiku documentation.
Training Modules API
BASETRAININGMODULE
CLASS relax.module.BaseTrainingModule ()
Abstract base class that provides a standard way to create training modules using inheritance.
Predictive Training Module
PREDICTIVETRAININGMODULECONFIGS
CLASS relax.module.PredictiveTrainingModuleConfigs (lr, sizes, dropout_rate=0.3)
Configurator of PredictiveTrainingModule.
PREDICTIVETRAININGMODULE
CLASS relax.module.PredictiveTrainingModule (m_configs)
A training module for predictive models.
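A minimal construction sketch, using only the signatures listed above; the hyperparameter values are arbitrary examples.
from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs

m_configs = PredictiveTrainingModuleConfigs(lr=0.01, sizes=[50, 20, 10], dropout_rate=0.3)
module = PredictiveTrainingModule(m_configs)
# The resulting module can then be passed to train_model; see the training
# documentation for train_model's full signature.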