```python
from relax.legacy.module import PredictiveTrainingModule, PredictiveModelConfigs
from relax.data_module import load_data
```

Training
relax.legacy.trainer.TrainingConfigs
class relax.legacy.trainer.TrainingConfigs (n_epochs, batch_size, monitor_metrics=None, seed=42, log_dir='log', logger_name='debug', log_on_step=False, max_n_checkpoints=3)
Configurator of train_model.
Parameters:
- n_epochs (int) – Number of epochs.
- batch_size (int) – Batch size.
- monitor_metrics (Optional[str]) – Metric monitored to evaluate the training result after each epoch.
- seed (int, default=42) – Seed for generating random numbers.
- log_dir (str, default='log') – Name of the directory that holds logged data during training.
- logger_name (str, default='debug') – Name of the subdirectory under log_dir that holds logged data during training.
- log_on_step (bool, default=False) – Log the evaluation metrics at the current step.
- max_n_checkpoints (int, default=3) – Maximum number of checkpoints stored.
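For illustration, a minimal configuration might look like the sketch below. The values are placeholders rather than recommended settings, and the monitored metric name is taken from the training log shown later on this page.

```python
from relax.legacy.trainer import TrainingConfigs

# Minimal sketch: train for 5 epochs with batch size 128 and monitor the
# training loss; all other fields keep their documented defaults.
t_configs = TrainingConfigs(
    n_epochs=5,
    batch_size=128,
    monitor_metrics='train/train_loss',  # metric name assumed from the example output below
)
```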
relax.legacy.trainer.train_model_with_states
relax.legacy.trainer.train_model_with_states (training_module, params, opt_state, data_module, t_configs)
Train models with params and opt_state.
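A hedged sketch of how this fits together: it assumes training_module, params, opt_state, and datamodule already exist (for instance from the train_model example at the end of this page), and it assumes the function returns an updated (params, opt_state) pair like train_model does.

```python
from relax.legacy.trainer import TrainingConfigs, train_model_with_states

# Sketch only: continue training from existing parameters and optimizer state.
# `training_module`, `params`, `opt_state`, and `datamodule` are assumed to be
# defined already; the (params, opt_state) return shape is also an assumption.
t_configs = TrainingConfigs(n_epochs=1, batch_size=128)
params, opt_state = train_model_with_states(
    training_module, params, opt_state, datamodule, t_configs
)
```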
relax.legacy.trainer.train_model
relax.legacy.trainer.train_model (training_module, data_module, batch_size=128, epochs=1, **fit_kwargs)
Train models.
Parameters:
- training_module (relax.legacy.module.BaseTrainingModule) – Training module
- data_module (relax.data_module.DataModule) – Data module
- batch_size (int, default=128) – Batch size
- epochs (int, default=1) – Number of epochs
- fit_kwargs (VAR_KEYWORD) – Positional arguments for keras.Model.fit
Returns:
(Tuple[Mapping[str, Mapping[str, jax.Array]], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – The trained parameters and the final optimizer state, unpacked as params, opt_state in the example below.
Examples
A simple example of training a predictive model.
```python
datamodule = load_data('adult')

params, opt_state = train_model(
    PredictiveTrainingModule({'sizes': [64, 32, 16], 'lr': 0.003}),
    datamodule,
)
```

```
/home/birk/code/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.
  warnings.warn(
Epoch 0: 100%|██████████| 191/191 [00:01<00:00, 106.57batch/s, train/train_loss=0.08575804]
```

The UserWarning is expected here: monitor_metrics is not specified, so no checkpoints are stored during training.
The non-legacy MLModule offers an equivalent high-level interface:

```python
from relax.ml_model import MLModule

model = MLModule()
model.train(datamodule, batch_size=128, epochs=1)
```

```
191/191 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6769 - loss: 0.6131
<relax.ml_model.MLModule>
```
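As a follow-up, here is a hedged sketch of querying the trained model. It assumes the MLModule exposes a pred_fn method for the forward pass and that the data module exposes its feature matrix as xs; both names are assumptions for illustration, not guaranteed API.

```python
import jax.numpy as jnp

# Sketch only: `datamodule.xs` and `model.pred_fn` are assumed attribute names.
xs = datamodule.xs[:5]           # assumed: first five rows of the feature matrix
probs = model.pred_fn(xs)        # assumed: forward pass returning class probabilities
preds = jnp.argmax(probs, axis=-1)
print(preds)
```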