Training

Functions for training models

relax.legacy.trainer.TrainingConfigs

class relax.legacy.trainer.TrainingConfigs (n_epochs, batch_size, monitor_metrics=None, seed=42, log_dir='log', logger_name='debug', log_on_step=False, max_n_checkpoints=3)

Training configuration used by train_model.

Parameters:

  • n_epochs (int) – Number of epochs.
  • batch_size (int) – Batch size.
  • monitor_metrics (Optional[str]) – Name of the metric monitored after each epoch to evaluate the training result. If it is not set, no checkpoints are stored.
  • seed (int, default=42) – Seed for random number generation.
  • log_dir (str, default=log) – Name of the directory that holds data logged during training.
  • logger_name (str, default=debug) – Name of the subdirectory of log_dir that holds data logged during training.
  • log_on_step (bool, default=False) – If True, log the evaluation metrics at each training step.
  • max_n_checkpoints (int, default=3) – Maximum number of checkpoints to keep.
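To make the parameter semantics concrete, here is a minimal plain-Python sketch. TrainingConfigsSketch, log_path, and CheckpointKeeper are hypothetical names, not part of relax. The sketch assumes that logs go to log_dir/logger_name and that, when monitor_metrics is set, only the max_n_checkpoints best checkpoints are kept, with a lower metric value treated as better (also an assumption). The "no monitored metric, no checkpoints" branch mirrors the warning printed in the example further down.

```python
import heapq
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class TrainingConfigsSketch:
    """Hypothetical stand-in with the same fields and defaults as TrainingConfigs."""
    n_epochs: int
    batch_size: int
    monitor_metrics: Optional[str] = None
    seed: int = 42
    log_dir: str = "log"
    logger_name: str = "debug"
    log_on_step: bool = False
    max_n_checkpoints: int = 3

    @property
    def log_path(self) -> str:
        # Assumption: logger_name is a subdirectory of log_dir.
        return os.path.join(self.log_dir, self.logger_name)


class CheckpointKeeper:
    """Hypothetical retention policy: keep the max_n best checkpoints by one metric."""

    def __init__(self, monitor_metrics: Optional[str], max_n: int) -> None:
        self.monitor_metrics = monitor_metrics
        self.max_n = max_n
        # Heap of (-score, epoch): the worst kept checkpoint (largest score) sits on top.
        self._heap: List[Tuple[float, int]] = []

    def update(self, epoch: int, metrics: Dict[str, float]) -> None:
        if self.monitor_metrics is None:
            return  # Without a monitored metric, nothing is stored.
        score = metrics[self.monitor_metrics]  # Assumption: lower is better.
        heapq.heappush(self._heap, (-score, epoch))
        if len(self._heap) > self.max_n:
            heapq.heappop(self._heap)  # Discard the worst kept checkpoint.

    def kept_epochs(self) -> List[int]:
        return sorted(epoch for _, epoch in self._heap)


cfg = TrainingConfigsSketch(n_epochs=5, batch_size=128, monitor_metrics="train/train_loss")
keeper = CheckpointKeeper(cfg.monitor_metrics, cfg.max_n_checkpoints)
for epoch, loss in enumerate([0.9, 0.5, 0.7, 0.4, 0.6]):
    keeper.update(epoch, {"train/train_loss": loss})

print(cfg.log_path)          # log/debug on POSIX systems
print(keeper.kept_epochs())  # [1, 3, 4]
```

A bounded heap keeps retention O(log max_n) per epoch; the real checkpoint manager may well use a different policy.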

relax.legacy.trainer.train_model_with_states

relax.legacy.trainer.train_model_with_states (training_module, params, opt_state, data_module, t_configs)

Train a model starting from the given params and opt_state.
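Passing params and opt_state in lets training continue from an existing state instead of re-initializing. The sketch below illustrates that idea with a toy one-weight linear model trained by SGD with momentum, where opt_state holds the momentum buffer. It is plain Python and does not reproduce relax's implementation; train_with_states, the model, and the optimizer are illustrative assumptions.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]
OptState = Dict[str, float]


def sgd_momentum_step(params: Params, opt_state: OptState, grads: Params,
                      lr: float = 0.5, beta: float = 0.5) -> Tuple[Params, OptState]:
    # The momentum buffer lives in opt_state, so it must be carried between calls.
    new_state = {k: beta * opt_state[k] + grads[k] for k in params}
    new_params = {k: params[k] - lr * new_state[k] for k in params}
    return new_params, new_state


def grad_mse(params: Params, batch: List[Tuple[float, float]]) -> Params:
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    w = params["w"]
    g = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
    return {"w": g}


def train_with_states(params: Params, opt_state: OptState,
                      data: List[Tuple[float, float]], n_epochs: int,
                      batch_size: int) -> Tuple[Params, OptState]:
    # Hypothetical analogue of train_model_with_states: start from the given state.
    for _ in range(n_epochs):
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            params, opt_state = sgd_momentum_step(params, opt_state, grad_mse(params, batch))
    return params, opt_state


data = [(x / 10, 3.0 * x / 10) for x in range(1, 11)]  # y = 3x
params, opt_state = {"w": 0.0}, {"w": 0.0}
params, opt_state = train_with_states(params, opt_state, data, n_epochs=20, batch_size=4)
# A second call resumes from the returned state rather than starting over.
params, opt_state = train_with_states(params, opt_state, data, n_epochs=20, batch_size=4)
print(round(params["w"], 3))
```

Returning the updated opt_state alongside params is what makes the second call a true continuation: the momentum buffer is not reset to zero.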

relax.legacy.trainer.train_model

relax.legacy.trainer.train_model (training_module, data_module, batch_size=128, epochs=1, **fit_kwargs)

Train models.

Parameters:

  • training_module (relax.legacy.module.BaseTrainingModule) – Training module.
  • data_module (relax.data_module.DataModule) – Data module.
  • batch_size (int, default=128) – Batch size.
  • epochs (int, default=1) – Number of epochs.
  • fit_kwargs – Additional keyword arguments for keras.Model.fit.

Returns:

    (Tuple[Mapping[str, Mapping[str, jax.Array]], ArrayTree]) – The trained parameters and optimizer state, as (params, opt_state).
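The params part of the return value is a two-level mapping, module name to parameter name to array. As an illustration only, the sketch below walks such a nested mapping and counts its entries. The module and parameter names are illustrative Haiku-style names, and nested Python lists stand in for jax.Array; summarize_params and shape_of are hypothetical helpers.

```python
import math
from typing import Any, Dict, Mapping, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    # Shape of a rectangular nested list, standing in for jax.Array.shape.
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def summarize_params(
    params: Mapping[str, Mapping[str, Any]],
) -> Tuple[Dict[str, Tuple[int, ...]], int]:
    # Walk the {module_name: {param_name: array}} layout and count entries.
    shapes: Dict[str, Tuple[int, ...]] = {}
    for module_name, module_params in params.items():
        for param_name, value in module_params.items():
            shapes[f"{module_name}/{param_name}"] = shape_of(value)
    total = sum(math.prod(s) for s in shapes.values())
    return shapes, total


# Toy 2 -> 3 -> 1 network with illustrative module and parameter names.
params = {
    "linear": {"w": [[0.0] * 3 for _ in range(2)], "b": [0.0] * 3},
    "linear_1": {"w": [[0.0] for _ in range(3)], "b": [0.0]},
}
shapes, total = summarize_params(params)
print(shapes)
print(total)  # 13
```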

Examples

A simple example of training a predictive model.

from relax.legacy.trainer import train_model
from relax.legacy.module import PredictiveTrainingModule, PredictiveModelConfigs
from relax.data_module import load_data
datamodule = load_data('adult')

params, opt_state = train_model(
    PredictiveTrainingModule({'sizes': [64, 32, 16], 'lr': 0.003}), 
    datamodule,
)
/home/birk/code/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
  warnings.warn(
Epoch 0: 100%|██████████| 191/191 [00:01<00:00, 106.57batch/s, train/train_loss=0.08575804] 

The same data module can also be used to train relax.ml_model.MLModule:

from relax.ml_model import MLModule
model = MLModule()
model.train(datamodule, batch_size=128, epochs=1)
191/191 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6769 - loss: 0.6131
<relax.ml_model.MLModule>
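Both runs above report 191 steps for a single epoch with batch_size=128. Steps per epoch follow from the training-set size. The sketch below assumes the final partial batch is kept (ceiling division), which is an assumption about both trainers, and recovers the range of training-set sizes consistent with 191 steps. steps_per_epoch and sizes_for_steps are hypothetical helpers.

```python
from typing import Tuple


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    # Integer ceiling division: assumes the last, smaller batch is not dropped.
    return (n_samples + batch_size - 1) // batch_size


def sizes_for_steps(steps: int, batch_size: int) -> Tuple[int, int]:
    # Smallest and largest n_samples that give exactly `steps` batches.
    return (steps - 1) * batch_size + 1, steps * batch_size


lo, hi = sizes_for_steps(191, 128)
print(lo, hi)  # 24321 24448
```

If the trainer instead dropped the last incomplete batch, the same 191 steps would correspond to 24448 to 24575 samples.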