from relax.utils import make_hk_module
Module
train_model.
Networks
Networks are haiku.Module subclasses that define model architectures.
BASENETWORK
CLASS relax.module.BaseNetwork ()
A BaseNetwork must accept an is_training argument when it is called.
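The contract can be illustrated with a small, dependency-free sketch. `BaseNetworkSketch` and `ToyNetwork` are hypothetical names, not part of relax. The sketch assumes the training flag is a keyword-only `is_training` argument and that it switches on train-only behavior such as dropout. A deterministic "drop every other value" rule stands in for real dropout, so the effect of the flag is easy to see.

```python
from abc import ABC, abstractmethod


class BaseNetworkSketch(ABC):
    """Hypothetical stand-in for BaseNetwork: every network takes `is_training`."""

    @abstractmethod
    def __call__(self, x, *, is_training: bool):
        """Return the output; `is_training` toggles train-only behavior."""


class ToyNetwork(BaseNetworkSketch):
    """Toy subclass with a deterministic, dropout-like training mode."""

    def __call__(self, x, *, is_training: bool):
        if not is_training:
            return list(x)  # inference: pass inputs through unchanged
        # training: zero every other value and rescale the rest by 2
        return [2.0 * v if i % 2 == 0 else 0.0 for i, v in enumerate(x)]


net = ToyNetwork()
print(net([1.0, 2.0, 3.0, 4.0], is_training=True))   # [2.0, 0.0, 6.0, 0.0]
print(net([1.0, 2.0, 3.0, 4.0], is_training=False))  # [1.0, 2.0, 3.0, 4.0]
```

Because `__call__` is abstract, the base class cannot be instantiated directly. Every concrete network therefore has to spell out how it handles `is_training`.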
DENSEBLOCK
CLASS relax.module.DenseBlock (output_size, dropout_rate=0.3, name=None)
A DenseBlock consists of a dense layer, followed by a Leaky ReLU activation and a dropout layer.
MLP
CLASS relax.module.MLP (sizes, dropout_rate=0.3, name=None)
An MLP consists of a stack of DenseBlock layers, one per entry in sizes.
Predictive Model
PREDICTIVEMODEL
CLASS relax.module.PredictiveModel (sizes, dropout_rate=0.3, name=None)
A basic predictive model for binary classification: an MLP followed by a single-unit linear output layer.
Use make_hk_module to create a haiku.Transformed model.
import haiku as hk
import jax
from jax import random
from relax.module import PredictiveModel

net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)
We make some random data.
key = hk.PRNGSequence(42)
xs = random.normal(next(key), (1000, 10))
We can then initialize the model.
params = net.init(next(key), xs, is_training=True)
We can view the model’s structure via jax.tree_util.tree_map.
jax.tree_util.tree_map(lambda x: x.shape, params)
{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
 'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
 'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
 'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}
The model output is produced via the apply function.
y = net.apply(params, next(key), xs, is_training=True)
For more usage of haiku.Module, please refer to the Haiku documentation.
Training Modules API
BASETRAININGMODULE
CLASS relax.module.BaseTrainingModule ()
Abstract base class (ABC) for training modules.
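This sketch illustrates the abstract-base-class pattern: a base declares abstract methods and each concrete training module implements them. The class and method names (`BaseTrainingModuleSketch`, `init_params`, `training_step`) are hypothetical and are not the library's actual interface. The toy subclass fits a single scalar to the batch mean by gradient descent on `0.5 * (params - mean) ** 2`.

```python
from abc import ABC, abstractmethod


class BaseTrainingModuleSketch(ABC):
    """Hypothetical abstract training module; method names are illustrative only."""

    @abstractmethod
    def init_params(self):
        """Return the initial parameters."""

    @abstractmethod
    def training_step(self, params, batch):
        """Return updated parameters after one step on `batch`."""


class MeanEstimator(BaseTrainingModuleSketch):
    """Toy subclass: fits a scalar to the batch mean by gradient descent."""

    def __init__(self, lr):
        self.lr = lr

    def init_params(self):
        return 0.0

    def training_step(self, params, batch):
        mean = sum(batch) / len(batch)
        grad = params - mean  # gradient of 0.5 * (params - mean) ** 2
        return params - self.lr * grad


module = MeanEstimator(lr=0.5)
params = module.init_params()
for _ in range(20):
    params = module.training_step(params, [2.0, 4.0])
print(round(params, 4))  # 3.0
```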
Predictive Training Module
PREDICTIVETRAININGMODULECONFIGS
CLASS relax.module.PredictiveTrainingModuleConfigs (lr, sizes, dropout_rate=0.3)
Configuration for PredictiveTrainingModule: the learning rate lr, the hidden-layer sizes, and dropout_rate (default 0.3).
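A configuration object with these three fields can be sketched as a plain dataclass. `PredictiveTrainingModuleConfigsSketch` is a hypothetical stand-in: only the field names and the 0.3 default come from the signature above. The validation in `__post_init__` is added for illustration and is not documented behavior of the real class.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PredictiveTrainingModuleConfigsSketch:
    """Hypothetical stand-in with the documented fields."""

    lr: float                  # learning rate
    sizes: List[int]           # hidden-layer widths, as in PredictiveModel
    dropout_rate: float = 0.3  # documented default

    def __post_init__(self):
        # Sanity checks added for illustration; not documented behavior.
        if self.lr <= 0:
            raise ValueError("lr must be positive")
        if not 0.0 <= self.dropout_rate < 1.0:
            raise ValueError("dropout_rate must be in [0, 1)")


configs = PredictiveTrainingModuleConfigsSketch(lr=0.003, sizes=[50, 20, 10])
print(configs.dropout_rate)  # 0.3
```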
PREDICTIVETRAININGMODULE
CLASS relax.module.PredictiveTrainingModule (m_configs)
A training module for predictive models.