pytorch_tao.plugins.Checkpoint

class pytorch_tao.plugins.Checkpoint(metric_name: str, objects: Dict, score_sign: int = 1, n_saved=3)

Save models or any other state. This is a thin wrapper around ignite.handlers.checkpoint.Checkpoint.

Parameters
  • metric_name – the name of the metric used to decide which models are the best to save.

  • objects – objects to save.

  • score_sign – 1 if a higher metric value is better, -1 if a lower value is better.

  • n_saved – how many checkpoint files to keep.

Hooks

Hook Point         Logic
EPOCH_COMPLETED    saving the states

import pytorch_tao as tao
from pytorch_tao.plugins import Checkpoint

model = ...
trainer = tao.Trainer()

# keep the top-3 checkpoints ranked by the "accuracy" metric
trainer.use(Checkpoint("accuracy", {"model": model}))
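
Since score_sign defaults to 1 (higher is better), metrics where lower is better, such as a validation loss, need score_sign=-1. Below is a minimal sketch of that case; the metric name "val_loss" and the optimizer object are illustrative assumptions, not part of the API documented above.

import pytorch_tao as tao
from pytorch_tao.plugins import Checkpoint

model = ...
optimizer = ...  # illustrative: any extra state worth checkpointing
trainer = tao.Trainer()

# keep the 5 checkpoints with the lowest "val_loss";
# saving the optimizer state alongside the model allows training to resume
trainer.use(
    Checkpoint(
        "val_loss",
        {"model": model, "optimizer": optimizer},
        score_sign=-1,
        n_saved=5,
    )
)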
__init__(metric_name: str, objects: Dict, score_sign: int = 1, n_saved=3)

Methods

__init__(metric_name, objects[, score_sign, ...])

after_use()

attach(engine)

set_engine(engine)

Attributes