from relax.data import load_data
from relax.module import PredictiveTrainingModule

Checkpoint Manager
Manage the model and optimizer checkpoints
LOAD_CHECKPOINT
load_checkpoint(ckpt_dir)
SAVE_CHECKPOINT
save_checkpoint(state, ckpt_dir)
CHECKPOINTMANAGER
CLASS CheckpointManager(log_dir, monitor_metrics, max_n_checkpoints=3)
Initialize self. See help(type(self)) for accurate signature.
Example
key = hk.PRNGSequence(42)
ckpt_manager = CheckpointManager(
log_dir='log',
monitor_metrics='train/train_loss_1',
max_n_checkpoints=3
)
dm = load_data('adult')
module = PredictiveTrainingModule({'lr': 0.01, 'sizes': [50, 10, 50]})
params, opt_state = module.init_net_opt(dm, next(key))
logs = {'train/train_loss_1': 0.1}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=1)
logs = {'train/train_loss_1': 0.2}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=2)
logs = {'train/train_loss_1': 0.15}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=3)
logs = {'train/train_loss_1': 0.05}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=4)
logs = {'train/train_loss_1': 0.14}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=5)
assert ckpt_manager.n_checkpoints == len(ckpt_manager.checkpoints)
assert ckpt_manager.checkpoints.popitem(last=True)[0] == 0.14
shutil.rmtree(Path('log/epoch=1'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=2'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=3'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=4'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=5'), ignore_errors=True)

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)