pytorch_tao.plugins.Checkpoint
class pytorch_tao.plugins.Checkpoint(metric_name: str, objects: Dict, score_sign: int = 1, n_saved=3)
Save the states of models or any other objects. This class is a wrapper around ignite.handlers.checkpoint.Checkpoint.
Parameters
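metric_name – the name of the metric used to decide which models are the best and should be kept.
objects – the objects whose states are saved, for example {"model": model}.
score_sign – 1 if a higher metric value is better, or -1 if a lower value is better.
n_saved – how many checkpoint files to keep.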
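The interaction of score_sign and n_saved can be illustrated with a small pure-Python sketch. It assumes that, as in ignite's Checkpoint, each candidate is ranked by score_sign * metric and only the n_saved highest-ranked entries are retained. TopKRetainer and its methods are hypothetical names for illustration only; they are not part of pytorch_tao or ignite.

```python
import heapq
from typing import List, Tuple


class TopKRetainer:
    """Hypothetical helper: keeps the n_saved best entries by score_sign * metric."""

    def __init__(self, score_sign: int = 1, n_saved: int = 3) -> None:
        # Assumption: only 1 (higher is better) or -1 (lower is better) are valid.
        if score_sign not in (1, -1):
            raise ValueError("score_sign must be 1 or -1")
        self.score_sign = score_sign
        self.n_saved = n_saved
        # Min-heap of (signed score, epoch); the root is the worst entry kept.
        self._heap: List[Tuple[float, int]] = []

    def offer(self, epoch: int, metric: float) -> bool:
        """Return True if the checkpoint for this epoch would be kept."""
        score = self.score_sign * metric
        if len(self._heap) < self.n_saved:
            heapq.heappush(self._heap, (score, epoch))
            return True
        if score > self._heap[0][0]:
            # Evict the current worst entry (a real checkpoint would delete its file).
            heapq.heapreplace(self._heap, (score, epoch))
            return True
        return False

    def kept_epochs(self) -> List[int]:
        """Epochs still kept, best first."""
        return [epoch for _, epoch in sorted(self._heap, reverse=True)]


# Higher accuracy is better: keep the best 2 epochs.
accuracy_keeper = TopKRetainer(score_sign=1, n_saved=2)
for epoch, acc in enumerate([0.70, 0.80, 0.75, 0.90], start=1):
    accuracy_keeper.offer(epoch, acc)
print(accuracy_keeper.kept_epochs())  # [4, 2]

# Lower loss is better: keep only the single best epoch.
loss_keeper = TopKRetainer(score_sign=-1, n_saved=1)
for epoch, loss in enumerate([0.5, 0.3, 0.4], start=1):
    loss_keeper.offer(epoch, loss)
print(loss_keeper.kept_epochs())  # [2]
```

Multiplying by score_sign lets a single "higher is better" comparison serve both maximized metrics (accuracy) and minimized ones (loss).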
Hooks
Hook Point        Logic
EPOCH_COMPLETED   save the states
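Saving on EPOCH_COMPLETED follows the usual event-handler pattern: a handler is registered for the event and, each time an epoch ends, reads the current metric value before saving. The sketch below uses a hypothetical MiniEngine and EpochCheckpointSketch (neither is ignite's Engine nor this plugin) and assumes metrics are exposed through engine.state.metrics, as ignite handlers commonly read them.

```python
from collections import defaultdict
from enum import Enum
from types import SimpleNamespace
from typing import Callable, Dict, List, Tuple


class Events(Enum):
    # Hypothetical stand-in for ignite.engine.Events.
    EPOCH_COMPLETED = "epoch_completed"


class MiniEngine:
    """Hypothetical engine: runs epochs and fires registered handlers."""

    def __init__(self) -> None:
        self._handlers: Dict[Events, List[Callable]] = defaultdict(list)
        self.state = SimpleNamespace(epoch=0, metrics={})

    def add_event_handler(self, event: Events, handler: Callable) -> None:
        self._handlers[event].append(handler)

    def run(self, per_epoch_metrics: List[Dict[str, float]]) -> None:
        for metrics in per_epoch_metrics:
            self.state.epoch += 1
            # In a real trainer, training and evaluation would produce these.
            self.state.metrics = dict(metrics)
            for handler in self._handlers[Events.EPOCH_COMPLETED]:
                handler(self)


class EpochCheckpointSketch:
    """Hypothetical handler: records what would be saved at each epoch end."""

    def __init__(self, metric_name: str) -> None:
        self.metric_name = metric_name
        self.saved: List[Tuple[int, float]] = []

    def attach(self, engine: MiniEngine) -> None:
        engine.add_event_handler(Events.EPOCH_COMPLETED, self)

    def __call__(self, engine: MiniEngine) -> None:
        score = engine.state.metrics[self.metric_name]
        # A real checkpoint would serialize each object's state_dict() here.
        self.saved.append((engine.state.epoch, score))


engine = MiniEngine()
handler = EpochCheckpointSketch("accuracy")
handler.attach(engine)
engine.run([{"accuracy": 0.7}, {"accuracy": 0.8}])
print(handler.saved)  # [(1, 0.7), (2, 0.8)]
```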
    import pytorch_tao as tao
    from pytorch_tao.plugins import Checkpoint

    model = ...
    trainer = tao.Trainer()
    # save the top-3 models by accuracy
    trainer.use(Checkpoint("accuracy", {"model": model}))
__init__(metric_name: str, objects: Dict, score_sign: int = 1, n_saved=3)
Methods
__init__(metric_name, objects[, score_sign, ...])
after_use()
attach(engine)
set_engine(engine)
Attributes