Module

Modules for defining the model architecture and training procedure; these are passed to train_model.

Networks

Networks are haiku.Module subclasses that define model architectures.


source

BASENETWORK

CLASS relax.module.BaseNetwork ()

BaseNetwork requires an is_training argument.
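A hypothetical subclass for illustration, assuming the network is called as net(x, is_training) (consistent with the apply calls shown below); the exact abstract interface is defined in the source:

import jax
import haiku as hk
from relax.module import BaseNetwork

class TinyNetwork(BaseNetwork):
    """A hypothetical example network."""

    def __call__(self, x, is_training: bool = True):
        x = jax.nn.leaky_relu(hk.Linear(16)(x))
        # apply dropout only during training
        dropout_rate = 0.1 if is_training else 0.0
        x = hk.dropout(hk.next_rng_key(), dropout_rate, x)
        return hk.Linear(1)(x)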


source

DENSEBLOCK

CLASS relax.module.DenseBlock (output_size, dropout_rate=0.3, name=None)

A DenseBlock consists of a dense layer followed by a Leaky ReLU activation and a dropout layer.

Parameters:
  • output_size (int) – Output dimensionality.
  • dropout_rate (float, default=0.3) – Dropout rate.
  • name (str | None, default=None) – Name of the module.
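A minimal usage sketch (assuming DenseBlock is called as block(x, is_training), matching the convention above):

import haiku as hk
from jax import random
from relax.module import DenseBlock

def forward(x, is_training):
    # a single 50-unit block: dense -> Leaky ReLU -> dropout
    return DenseBlock(output_size=50, dropout_rate=0.3)(x, is_training)

block = hk.transform(forward)
key = hk.PRNGSequence(0)
x = random.normal(next(key), (8, 10))
params = block.init(next(key), x, is_training=True)
out = block.apply(params, next(key), x, is_training=True)  # shape (8, 50)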

source

MLP

CLASS relax.module.MLP (sizes, dropout_rate=0.3, name=None)

An MLP consists of a stack of DenseBlock layers.

Parameters:
  • sizes (Iterable[int]) – Sequence of layer sizes.
  • dropout_rate (float, default=0.3) – Dropout rate.
  • name (str | None, default=None) – Name of the module.
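For example, MLP(sizes=[50, 20, 10]) stacks three DenseBlocks with output sizes 50, 20, and 10. A minimal sketch, assuming the same (x, is_training) calling convention as DenseBlock:

from relax.module import MLP

def features(x, is_training):
    # three DenseBlocks with outputs of size 50, 20, then 10
    return MLP(sizes=[50, 20, 10], dropout_rate=0.3)(x, is_training)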

Predictive Model


source

PREDICTIVEMODEL

CLASS relax.module.PredictiveModel (sizes, dropout_rate=0.3, name=None)

A basic predictive model for binary classification.

Parameters:
  • sizes (List[int]) – Sequence of layer sizes.
  • dropout_rate (float, default=0.3) – Dropout rate.
  • name (Optional[str], default=None) – Name of the module.

Use make_hk_module to create a haiku.Transformed model.
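The walkthrough below assumes the following imports:

import jax
import haiku as hk
from jax import random
from relax.module import PredictiveModel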

from relax.utils import make_hk_module
net = make_hk_module(PredictiveModel, sizes=[50, 20, 10], dropout_rate=0.3)

We make some random data.

key = hk.PRNGSequence(42)
xs = random.normal(next(key), (1000, 10))

We can then initialize the model.

params = net.init(next(key), xs, is_training=True)

We can view the model’s structure via jax.tree_map.

jax.tree_map(lambda x: x.shape, params)
{'predictive_model/linear': {'b': (1,), 'w': (10, 1)},
 'predictive_model/mlp/dense_block/linear': {'b': (50,), 'w': (10, 50)},
 'predictive_model/mlp/dense_block_1/linear': {'b': (20,), 'w': (50, 20)},
 'predictive_model/mlp/dense_block_2/linear': {'b': (10,), 'w': (20, 10)}}

Model output is produced via the apply function.

y = net.apply(params, next(key), xs, is_training=True)
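At inference time, dropout is typically disabled by passing is_training=False (a sketch, assuming the flag only toggles dropout):

y_pred = net.apply(params, next(key), xs, is_training=False)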

For more on using haiku.Module, please refer to the Haiku documentation.

Training Modules API


source

BASETRAININGMODULE

CLASS relax.module.BaseTrainingModule ()

Abstract base class for training modules.

Predictive Training Module


source

PREDICTIVETRAININGMODULECONFIGS

CLASS relax.module.PredictiveTrainingModuleConfigs (lr, sizes, dropout_rate=0.3)

Configuration for PredictiveTrainingModule.

Parameters:
  • lr (float) – Learning rate.
  • sizes (List[int]) – Sequence of layer sizes.
  • dropout_rate (float, default=0.3) – Dropout rate.
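For example, a configuration matching the model above (lr=0.01 is an arbitrary illustrative value):

from relax.module import PredictiveTrainingModuleConfigs

configs = PredictiveTrainingModuleConfigs(lr=0.01, sizes=[50, 20, 10], dropout_rate=0.3)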

source

PREDICTIVETRAININGMODULE

CLASS relax.module.PredictiveTrainingModule (m_configs)

A training module for predictive models.
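A construction sketch (assuming m_configs accepts the PredictiveTrainingModuleConfigs shown above); the resulting module is then passed to train_model as described at the top of this page:

from relax.module import PredictiveTrainingModule, PredictiveTrainingModuleConfigs

m_configs = PredictiveTrainingModuleConfigs(lr=0.01, sizes=[50, 20, 10], dropout_rate=0.3)
module = PredictiveTrainingModule(m_configs)
# `module` can now be passed to train_model; see the train_model documentation
# for the full call signature.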