pytorch_tao.plugins.OutputRecorder
- class pytorch_tao.plugins.OutputRecorder(*fields: List[str])
Records any output field using the tracker. Note that this class inherits from BasePlugin, so one must pass the at argument when registering it with the trainer (for example, trainer.use(..., at="train")).
- Parameters
fields – the names of the output fields to track
Hooks
Hook Point: ITERATION_COMPLETED
Logic: call tracker.add_points() with the outputs

    import torch.nn.functional as F

    import pytorch_tao as tao
    from pytorch_tao.plugins import OutputRecorder

    model = ...
    optimizer = ...

    trainer = tao.Trainer()

    @trainer.train()
    def _train(images, targets):
        logits = model(images)
        loss = F.cross_entropy(logits, targets)
        # The returned dict is the output; its keys are the fields that can be tracked.
        return {"loss": loss}

    # Track the "loss" field of the train function's output.
    trainer.use(OutputRecorder("loss"), at="train")
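To make the hook logic concrete, here is a minimal, self-contained sketch of the idea: after each iteration, pick the requested fields out of the output dict and hand them to a tracker. It is not the pytorch_tao implementation. ToyTracker, ToyOutputRecorder, on_iteration_completed, and the dict-based add_points signature are all hypothetical stand-ins chosen for illustration; the real tracker API may differ.

```python
from typing import Any, Dict, List


class ToyTracker:
    """Hypothetical stand-in for a tracker; not the pytorch_tao tracker."""

    def __init__(self) -> None:
        self.points: List[Dict[str, float]] = []

    def add_points(self, points: Dict[str, float]) -> None:
        # Assumed signature: a mapping from field name to scalar value.
        self.points.append(dict(points))


class ToyOutputRecorder:
    """Illustrative recorder that forwards selected output fields."""

    def __init__(self, *fields: str) -> None:
        self.fields = fields

    def on_iteration_completed(
        self, output: Dict[str, Any], tracker: ToyTracker
    ) -> None:
        # Keep only the requested fields and convert each value to a float.
        selected = {name: float(output[name]) for name in self.fields}
        tracker.add_points(selected)


tracker = ToyTracker()
recorder = ToyOutputRecorder("loss")

# Simulate three training iterations, each returning a dict of outputs.
outputs = (
    {"loss": 0.9, "acc": 0.5},
    {"loss": 0.7, "acc": 0.6},
    {"loss": 0.4, "acc": 0.8},
)
for step_output in outputs:
    recorder.on_iteration_completed(step_output, tracker)

print(tracker.points)
```

Only the "loss" field reaches the toy tracker; "acc" is ignored because it was not listed when the recorder was constructed. This mirrors how the fields argument selects what gets recorded.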
- __init__(*fields: List[str])
Methods
__init__(*fields)
after_use()
attach(engine)
set_engine(engine)
Attributes