Checkpoint Manager

Manage the model and optimizer checkpoints

LOAD_CHECKPOINT

load_checkpoint(ckpt_dir)


SAVE_CHECKPOINT

save_checkpoint(state, ckpt_dir)


CHECKPOINTMANAGER

CLASS CheckpointManager(log_dir, monitor_metrics, max_n_checkpoints=3)

Initialize the checkpoint manager with a logging directory (`log_dir`), the metric name to monitor (`monitor_metrics`), and the maximum number of checkpoints to retain (`max_n_checkpoints`, default 3).

Example

import haiku as hk
from relax.data import load_data
from relax.module import PredictiveTrainingModule
key = hk.PRNGSequence(42)
ckpt_manager = CheckpointManager(
    log_dir='log', 
    monitor_metrics='train/train_loss_1',
    max_n_checkpoints=3
)
dm = load_data('adult')
module = PredictiveTrainingModule({'lr': 0.01, 'sizes': [50, 10, 50]})
params, opt_state = module.init_net_opt(dm, next(key))
logs = {'train/train_loss_1': 0.1}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=1)
logs = {'train/train_loss_1': 0.2}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=2)
logs = {'train/train_loss_1': 0.15}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=3)
logs = {'train/train_loss_1': 0.05}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=4)
logs = {'train/train_loss_1': 0.14}
ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=5)
assert ckpt_manager.n_checkpoints == len(ckpt_manager.checkpoints)
assert ckpt_manager.checkpoints.popitem(last=True)[0] == 0.14

import shutil
from pathlib import Path

shutil.rmtree(Path('log/epoch=1'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=2'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=3'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=4'), ignore_errors=True)
shutil.rmtree(Path('log/epoch=5'), ignore_errors=True)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)