pytorch_tao.plugins.BasePlugin

class pytorch_tao.plugins.BasePlugin(attach_to: Optional[str] = None)

The base class for all other plugins.

A plugin is a set of functions that can be hooked into the training or evaluation process. It is essentially the same concept as ignite.handlers.

Plugins use the same events as ignite.engine.events.Events.
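The hook mechanism can be pictured with a small, self-contained model. This is not pytorch_tao's implementation: `on`, `ToyEngine`, `ToyPlugin`, and the `_tao_event` attribute are hypothetical names, and plain strings stand in for `ignite.engine.events.Events`. The decorator only tags a method with an event; attaching the plugin then registers each tagged bound method as a handler on the engine, much as ordinary handlers are registered with ignite's `Engine.add_event_handler`.

```python
from collections import defaultdict
from typing import Callable, Dict, List


def on(event: str) -> Callable:
    """Hypothetical decorator: tag a method with the event it should run on."""
    def decorator(fn: Callable) -> Callable:
        fn._tao_event = event
        return fn
    return decorator


class ToyEngine:
    """Hypothetical stand-in for an ignite Engine: maps events to handlers."""

    def __init__(self) -> None:
        self.handlers: Dict[str, List[Callable]] = defaultdict(list)

    def add_event_handler(self, event: str, handler: Callable) -> None:
        self.handlers[event].append(handler)

    def fire(self, event: str) -> None:
        # Call every handler registered for this event, passing the engine.
        for handler in self.handlers[event]:
            handler(self)


class ToyPlugin:
    """Hypothetical base class: registers every tagged method on attach."""

    def attach(self, engine: ToyEngine) -> None:
        for name in dir(self):
            method = getattr(self, name)
            # Bound methods forward attribute lookups to the underlying function.
            event = getattr(method, "_tao_event", None)
            if event is not None:
                engine.add_event_handler(event, method)


class CountingPlugin(ToyPlugin):
    def __init__(self) -> None:
        self.count = 0

    @on("ITERATION_COMPLETED")
    def _on_iteration_completed(self, engine: ToyEngine) -> None:
        self.count += 1


engine = ToyEngine()
plugin = CountingPlugin()
plugin.attach(engine)
for _ in range(5):
    engine.fire("ITERATION_COMPLETED")
print(plugin.count)  # 5
```

Keeping the decorator free of side effects means the class body only records intent; nothing is bound to an engine until the plugin is actually attached.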

A plugin that inherits directly from BasePlugin should always pass the at argument when calling trainer.use, as in the custom plugin example below.
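The need for `at` can be illustrated with a small routing model: the trainer has to know which phase (training or evaluation) a plugin belongs to, and a plugin derived directly from BasePlugin carries no default. The sketch below is an assumption about how such routing could work, not pytorch_tao's code. `ToyTrainer`, its `engines` dict, the phase name `"val"`, and the rule that an explicit `at` overrides the plugin's default are all hypothetical; only the optional `attach_to` string comes from the documented signature.

```python
from typing import Dict, List, Optional


class ToyPlugin:
    """Hypothetical plugin carrying an optional default phase."""

    def __init__(self, attach_to: Optional[str] = None) -> None:
        self.attach_to = attach_to


class ToyTrainer:
    """Hypothetical trainer holding one plugin list per phase."""

    def __init__(self) -> None:
        self.engines: Dict[str, List[ToyPlugin]] = {"train": [], "val": []}

    def use(self, plugin: ToyPlugin, at: Optional[str] = None) -> None:
        # Assumed rule: an explicit `at` wins, else use the plugin's default.
        phase = at if at is not None else plugin.attach_to
        if phase is None:
            raise ValueError("plugin has no default phase; pass at='train' or at='val'")
        if phase not in self.engines:
            raise ValueError(f"unknown phase: {phase!r}")
        self.engines[phase].append(plugin)


trainer = ToyTrainer()
trainer.use(ToyPlugin(), at="train")  # works: phase given explicitly
try:
    trainer.use(ToyPlugin())  # fails: no `at` and no default phase
except ValueError as err:
    print(err)
```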

Example of creating a custom plugin:

import pytorch_tao as tao
from ignite.engine import Engine, Events
from pytorch_tao.plugins import BasePlugin

class CustomPlugin(BasePlugin):

    def __init__(self):
        # let BasePlugin initialise its own state (attach_to defaults to None)
        super().__init__()
        self.interval = 3

    @tao.on(Events.ITERATION_COMPLETED)
    def _on_iteration_completed(self, engine: Engine):
        # do something when an iteration has completed
        pass

    @tao.on(lambda self: Events.ITERATION_COMPLETED(every=self.interval))
    def _on_every_3_iterations_completed(self, engine: Engine):
        # do something every self.interval (here 3) completed iterations
        pass

trainer = tao.Trainer()
trainer.use(CustomPlugin(), at="train")
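The second decorator above receives a lambda rather than an event, so the event can depend on instance state (`self.interval`) that does not yet exist when the class body runs. The following pure-Python sketch shows one way such deferred resolution and an `every=N` filter could work. It is a hedged model, not the pytorch_tao or ignite implementation: `on`, `EveryEvent`, `ToyEngine`, and `ToyPlugin` are hypothetical, and the `every` filtering that ignite provides through `Events.ITERATION_COMPLETED(every=3)` is reimplemented by hand here.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Union


@dataclass(frozen=True)
class EveryEvent:
    """Hypothetical filtered event: run the handler every `every` firings."""
    name: str
    every: int = 1


EventSpec = Union[EveryEvent, Callable[[object], EveryEvent]]


def on(spec: EventSpec) -> Callable:
    """Hypothetical decorator: store an event, or a function of `self` returning one."""
    def decorator(fn: Callable) -> Callable:
        fn._tao_spec = spec
        return fn
    return decorator


class ToyEngine:
    def __init__(self) -> None:
        self.handlers: Dict[str, List[Callable]] = defaultdict(list)
        self.counts: Dict[str, int] = defaultdict(int)

    def add_event_handler(self, event: EveryEvent, handler: Callable) -> None:
        # Wrap the handler so it only runs on every `event.every`-th firing.
        def filtered(engine: "ToyEngine") -> None:
            if engine.counts[event.name] % event.every == 0:
                handler(engine)
        self.handlers[event.name].append(filtered)

    def fire(self, name: str) -> None:
        self.counts[name] += 1
        for handler in self.handlers[name]:
            handler(self)


class ToyPlugin:
    def attach(self, engine: ToyEngine) -> None:
        for name in dir(self):
            method = getattr(self, name)
            spec = getattr(method, "_tao_spec", None)
            if spec is None:
                continue
            # A callable spec is evaluated with the instance at attach time,
            # so it can read attributes set in __init__.
            event = spec if isinstance(spec, EveryEvent) else spec(self)
            engine.add_event_handler(event, method)


class EveryNPlugin(ToyPlugin):
    def __init__(self, interval: int) -> None:
        self.interval = interval
        self.calls: List[int] = []

    @on(lambda self: EveryEvent("ITERATION_COMPLETED", every=self.interval))
    def _on_every_n(self, engine: ToyEngine) -> None:
        self.calls.append(engine.counts["ITERATION_COMPLETED"])


engine = ToyEngine()
plugin = EveryNPlugin(interval=3)
plugin.attach(engine)
for _ in range(7):
    engine.fire("ITERATION_COMPLETED")
print(plugin.calls)  # [3, 6]
```

Deferring evaluation until attach time is what lets two instances of the same plugin class use different intervals.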

In most cases, a plugin is phase-specific: it is used either in the training phase or in the evaluation phase. For example, the Scheduler plugin is only for training, while the Metric plugin is only for evaluation.

It is therefore common for a custom plugin to inherit from TrainPlugin or ValPlugin.
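Phase-specific base classes can be modeled by fixing `attach_to` in the subclass constructor, so users of such plugins no longer need to pass `at`. This is a hedged sketch: the real TrainPlugin and ValPlugin may be implemented differently, and `ToyBasePlugin`, `ToyTrainPlugin`, `ToyValPlugin`, `resolve_phase`, `LrLogger`, and the phase name `"val"` are hypothetical.

```python
from typing import Optional


class ToyBasePlugin:
    def __init__(self, attach_to: Optional[str] = None) -> None:
        self.attach_to = attach_to


class ToyTrainPlugin(ToyBasePlugin):
    """Hypothetical: always targets the training phase."""

    def __init__(self) -> None:
        super().__init__(attach_to="train")


class ToyValPlugin(ToyBasePlugin):
    """Hypothetical: always targets the evaluation phase."""

    def __init__(self) -> None:
        super().__init__(attach_to="val")


def resolve_phase(plugin: ToyBasePlugin, at: Optional[str] = None) -> str:
    """Hypothetical helper mirroring how a trainer might pick a phase."""
    phase = at or plugin.attach_to
    if phase is None:
        raise ValueError("pass `at` for plugins without a default phase")
    return phase


class LrLogger(ToyTrainPlugin):
    """A hypothetical custom plugin that only makes sense during training."""


print(resolve_phase(LrLogger()))      # train
print(resolve_phase(ToyValPlugin()))  # val
```

Putting the phase in the base class keeps the decision next to the plugin's logic, instead of repeating it at every trainer.use call.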

__init__(attach_to: Optional[str] = None)

Methods

__init__([attach_to])

after_use()

attach(engine)

set_engine(engine)

Attributes

engine

trainer