from relax.legacy.module import PredictiveTrainingModule, PredictiveModelConfigs
from relax.data_module import load_data
Training
relax.legacy.trainer.TrainingConfigs
class relax.legacy.trainer.TrainingConfigs (n_epochs, batch_size, monitor_metrics=None, seed=42, log_dir='log', logger_name='debug', log_on_step=False, max_n_checkpoints=3)
Configurator of train_model.
Parameters:
- n_epochs (int) – Number of epochs.
- batch_size (int) – Batch size.
- monitor_metrics (Optional[str]) – Metric monitored to evaluate the training result after each epoch.
- seed (int, default=42) – Seed for random number generation.
- log_dir (str, default=log) – Name of the directory that holds the data logged during training.
- logger_name (str, default=debug) – Name of the sub-directory, under the log directory, that holds the data logged during training.
- log_on_step (bool, default=False) – Log the evaluation metrics at the current step.
- max_n_checkpoints (int, default=3) – Maximum number of checkpoints stored.
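As a rough illustration of how these fields fit together, the sketch below mirrors the documented signature with a plain dataclass. `SketchTrainingConfigs` and its `log_path` property are hypothetical stand-ins, not part of relax. Joining `log_dir` and `logger_name` into a single path is an assumption drawn from the parameter descriptions, not a confirmed detail of the library.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class SketchTrainingConfigs:
    """Hypothetical stand-in that mirrors the documented TrainingConfigs fields."""
    n_epochs: int
    batch_size: int
    monitor_metrics: Optional[str] = None
    seed: int = 42
    log_dir: str = "log"
    logger_name: str = "debug"
    log_on_step: bool = False
    max_n_checkpoints: int = 3

    @property
    def log_path(self) -> Path:
        # Assumption: logger_name names a sub-directory of log_dir.
        return Path(self.log_dir) / self.logger_name


# Only the two required fields are given; the rest fall back to the defaults.
configs = SketchTrainingConfigs(n_epochs=10, batch_size=128)
print(configs.log_path.as_posix())
```

With relax installed, the corresponding call would presumably be `TrainingConfigs(n_epochs=10, batch_size=128)`, following the signature above.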
relax.legacy.trainer.train_model_with_states
relax.legacy.trainer.train_model_with_states (training_module, params, opt_state, data_module, t_configs)
Train models with params and opt_state.
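To make the idea of training from explicit states concrete, here is a minimal pure-Python sketch of a loop that threads `params` and `opt_state` through every update and returns both. It is not the relax implementation: `sgd_momentum_step`, `train_with_states`, the one-weight model, and the toy data are all hypothetical and chosen only to keep the example self-contained.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]
OptState = Dict[str, float]


def sgd_momentum_step(
    params: Params, opt_state: OptState, x: float, y: float,
    lr: float = 0.05, beta: float = 0.5,
) -> Tuple[Params, OptState]:
    """One momentum-SGD update for the model y_hat = w * x with squared error."""
    grad_w = 2.0 * (params["w"] * x - y) * x
    velocity = beta * opt_state["w"] + grad_w
    # Return new objects instead of mutating the inputs (functional style).
    return {"w": params["w"] - lr * velocity}, {"w": velocity}


def train_with_states(
    params: Params, opt_state: OptState,
    data: List[Tuple[float, float]], n_epochs: int,
) -> Tuple[Params, OptState]:
    """Hypothetical loop: both states go in, both updated states come out."""
    for _ in range(n_epochs):
        for x, y in data:
            params, opt_state = sgd_momentum_step(params, opt_state, x, y)
    return params, opt_state


# Toy data generated by y = 2 * x, so w should approach 2.
data = [(1.0, 2.0), (2.0, 4.0)]
params, opt_state = train_with_states({"w": 0.0}, {"w": 0.0}, data, n_epochs=50)
print(params, opt_state)
```

Passing the states in explicitly means a run can be resumed from a previously returned `(params, opt_state)` pair rather than always starting from a fresh initialisation.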
relax.legacy.trainer.train_model
relax.legacy.trainer.train_model (training_module, data_module, batch_size=128, epochs=1, **fit_kwargs)
Train models.
Parameters:
- training_module (relax.legacy.module.BaseTrainingModule) – Training module.
- data_module (relax.data_module.DataModule) – Data module.
- batch_size (int, default=128) – Batch size.
- epochs (int, default=1) – Number of epochs.
- fit_kwargs – Keyword arguments for keras.Model.fit.
Returns:
(typing.Tuple[collections.abc.Mapping[str, collections.abc.Mapping[str, jax.Array]], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, typing.Iterable[ForwardRef('ArrayTree')], typing.Mapping[typing.Any, ForwardRef('ArrayTree')]]]) – The params and opt_state after training.
Examples
A simple example of training a predictive model.
datamodule = load_data('adult')

params, opt_state = train_model(
    PredictiveTrainingModule({'sizes': [64, 32, 16], 'lr': 0.003}),
    datamodule,
)
/home/birk/code/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
warnings.warn(
Epoch 0: 100%|██████████| 191/191 [00:01<00:00, 106.57batch/s, train/train_loss=0.08575804]
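The progress bar above counts batches, not samples. A minimal sketch of that arithmetic follows, assuming the final partial batch is kept. The value of `n_train` is hypothetical, since the size of the training split is not stated here; it is chosen only to show how a count such as 191 can arise with the default `batch_size=128`.

```python
import math

batch_size = 128   # default batch_size of train_model
n_train = 24_420   # hypothetical training-set size, not taken from the docs

# Assumption: the last, smaller batch is kept rather than dropped.
steps_per_epoch = math.ceil(n_train / batch_size)
print(steps_per_epoch)
```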
from relax.ml_model import MLModule
model = MLModule()
model.train(datamodule, batch_size=128, epochs=1)
191/191 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6769 - loss: 0.6131
<relax.ml_model.MLModule>