import shutil
from pathlib import Path

from relax.data import load_data
from relax.module import PredictiveTrainingModule
Checkpoint Manager
Manage model and optimizer checkpoints.
load_checkpoint
load_checkpoint(ckpt_dir)
save_checkpoint
save_checkpoint(state, ckpt_dir)
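CheckpointManager
class CheckpointManager(log_dir, monitor_metrics, max_n_checkpoints=3)
Create a manager that saves checkpoints under `log_dir`, tracks the metric named by `monitor_metrics`, and keeps at most `max_n_checkpoints` checkpoints (default 3).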
Example
# Example: report a monitored training loss for five epochs and let the
# manager retain at most three checkpoints under ./log.
# NOTE(review): `hk` (haiku) and `CheckpointManager` are not imported in this
# snippet; they are assumed to be brought into scope elsewhere in the page.
key = hk.PRNGSequence(42)
ckpt_manager = CheckpointManager(
    log_dir='log',
    monitor_metrics='train/train_loss_1',
    max_n_checkpoints=3,
)
dm = load_data('adult')
module = PredictiveTrainingModule({'lr': 0.01, 'sizes': [50, 10, 50]})
params, opt_state = module.init_net_opt(dm, next(key))

# Value of the monitored metric reported at the end of epochs 1..5.
epoch_losses = [0.1, 0.2, 0.15, 0.05, 0.14]
try:
    for epoch, loss in enumerate(epoch_losses, start=1):
        logs = {'train/train_loss_1': loss}
        ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=epoch)

    assert ckpt_manager.n_checkpoints == len(ckpt_manager.checkpoints)
    assert ckpt_manager.checkpoints.popitem(last=True)[0] == 0.14
finally:
    # Remove every per-epoch directory the example may have written, even
    # when one of the assertions above fails, so reruns start clean.
    for epoch in range(1, len(epoch_losses) + 1):
        shutil.rmtree(Path('log') / f'epoch={epoch}', ignore_errors=True)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)