pytorch_tao.plugins.OutputRecorder

class pytorch_tao.plugins.OutputRecorder(*fields: List[str])

Record selected fields of each iteration's output using the tracker. Note that this class inherits from BasePlugin, so the `at` argument must be passed to trainer.use() to choose which engine the plugin is attached to.

Parameters

fields – the names of the output fields to track

Hooks

Hook Point             Logic
ITERATION_COMPLETED    call tracker.add_points() with the outputs
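The ITERATION_COMPLETED logic amounts to picking the requested fields out of one iteration's output and handing them to the tracker. The sketch below is self-contained and illustrative only: RecordingTracker and record_outputs are hypothetical names, and the assumption that add_points() takes a dict of field names to values is not taken from the pytorch_tao source.

```python
from typing import Any, Dict, List, Sequence


class RecordingTracker:
    """Hypothetical stand-in for the tracker: it just stores every call."""

    def __init__(self) -> None:
        self.points: List[Dict[str, Any]] = []

    def add_points(self, points: Dict[str, Any]) -> None:
        # Assumed signature: a mapping from field name to value.
        self.points.append(points)


def record_outputs(output: Dict[str, Any], fields: Sequence[str], tracker: RecordingTracker) -> None:
    """Forward only the requested fields of one iteration's output to the tracker.

    A field missing from the output raises KeyError (a choice made for this sketch).
    """
    tracker.add_points({field: output[field] for field in fields})


tracker = RecordingTracker()
# Two simulated iteration outputs; only "loss" is requested, so "acc" is dropped.
for step_output in [{"loss": 0.9, "acc": 0.5}, {"loss": 0.7, "acc": 0.6}]:
    record_outputs(step_output, ("loss",), tracker)

print(tracker.points)  # [{'loss': 0.9}, {'loss': 0.7}]
```

Raising on a missing field, rather than silently skipping it, surfaces a typo in a field name on the first iteration; the real plugin may behave differently.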

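import torch.nn.functional as F

import pytorch_tao as tao
from pytorch_tao.plugins import OutputRecorder

model = ...      # your torch.nn.Module
optimizer = ...  # your torch.optim.Optimizer

trainer = tao.Trainer()

@trainer.train()
def _train(images, targets):
    logits = model(images)
    loss = F.cross_entropy(logits, targets)
    # the returned dict is the iteration output that OutputRecorder reads fields from
    return {"loss": loss}

# record the "loss" field of every training iteration
trainer.use(OutputRecorder("loss"), at="train")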
__init__(*fields: List[str])

Methods

__init__(*fields)

after_use()

attach(engine)

set_engine(engine)
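As a rough illustration of how attach(engine) could connect the recorder to the ITERATION_COMPLETED hook, the toy sketch below registers a handler that fires after every iteration. It follows the pytorch-ignite convention in which attach() registers event handlers on an engine, but ToyEngine and ToyOutputRecorder are hypothetical classes, not pytorch_tao's or ignite's actual implementations, and the roles of after_use() and set_engine() are not modeled.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List


class ToyEngine:
    """Hypothetical minimal event engine (not pytorch_tao's or ignite's Engine)."""

    def __init__(self) -> None:
        self.handlers: Dict[str, List[Callable[["ToyEngine"], None]]] = defaultdict(list)
        self.output: Any = None

    def add_event_handler(self, event: str, handler: Callable[["ToyEngine"], None]) -> None:
        self.handlers[event].append(handler)

    def run(self, outputs: List[Dict[str, Any]]) -> None:
        for out in outputs:
            self.output = out  # result of one simulated training step
            for handler in self.handlers["ITERATION_COMPLETED"]:
                handler(self)


class ToyOutputRecorder:
    """Hypothetical recorder: keeps the chosen fields of each iteration's output."""

    def __init__(self, *fields: str) -> None:
        self.fields = fields
        self.recorded: List[Dict[str, Any]] = []

    def attach(self, engine: ToyEngine) -> None:
        # Register the recording logic on the iteration-completed event.
        engine.add_event_handler("ITERATION_COMPLETED", self._record)

    def _record(self, engine: ToyEngine) -> None:
        self.recorded.append({f: engine.output[f] for f in self.fields})


engine = ToyEngine()
recorder = ToyOutputRecorder("loss", "acc")
recorder.attach(engine)
engine.run([{"loss": 1.0, "acc": 0.25, "lr": 0.1}, {"loss": 0.5, "acc": 0.75, "lr": 0.1}])
print(recorder.recorded)  # [{'loss': 1.0, 'acc': 0.25}, {'loss': 0.5, 'acc': 0.75}]
```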

Attributes