Trainer
- class pytorch_tao.trainer.Trainer(device: torch.device = device(type='cpu'), model: Optional[torch.nn.modules.module.Module] = None, optimizer: Optional[torch.optim.optimizer.Optimizer] = None, train_loader: Optional[Iterable] = None, val_loader: Optional[Iterable] = None, train_func: Callable = <function Trainer.<lambda>>, val_func: Callable = <function Trainer.<lambda>>, val_event: Callable = <function Trainer.<lambda>>)
Trainer is basically a wrapper around two ignite.engine.engine.Engine instances: train_engine and val_engine.
- Parameters
device – torch.device, to train on CUDA or CPU.
model – a PyTorch model; optional, but highly recommended to pass in so that all features work.
optimizer – a PyTorch optimizer; also optional but highly recommended to pass in.
train_loader – dataloader for the training stage.
val_loader – dataloader for the evaluation stage; if None, no evaluation will be done.
train_func – training forward function; it is recommended to use the train() decorator after the trainer is initialized instead.
val_func – evaluation forward function; it is recommended to use the eval() decorator after the trainer is initialized instead.
val_stride – epoch stride at which to do evaluation.
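A minimal construction sketch (not from the original docs; the linear model and random dataset here are placeholders, assuming only the constructor signature above):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_tao as tao

# placeholder model and data, purely for illustration
model = nn.Linear(16, 2)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))

trainer = tao.Trainer(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    model=model,
    optimizer=torch.optim.Adam(model.parameters()),
    train_loader=DataLoader(dataset, batch_size=8),
    val_loader=DataLoader(dataset, batch_size=8),
)
```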
- eval(fields: Optional[List[str]] = None, amp: bool = False)
Decorator that defines the evaluation process. Its parameters have the same functionality as those of train().
- Parameters
fields – fields to select from the batch yielded by the dataloader. If None, all yielded data will be passed as the decorated function's parameters: a dict batch is passed as keyword arguments, a tuple batch as positional arguments.
amp – whether to use amp.
Ideally, the decorated function should do the following:
- its parameters should match the fields argument.
- do the model forward pass and return the raw output of the model.
This function is mainly used for computing metrics. Because different metrics can have different input requirements (ROC AUC requires probabilities, while F-score requires hard 0 or 1 predictions), this method should only return the raw output of the model and let each metric do its own transformation. That way we can have multiple metrics while forwarding the input only once. Luckily, every metric from pytorch ignite has an output_transform argument to do exactly this.

```python
import torch
import pytorch_tao as tao
from pytorch_tao.plugins import Metric
from ignite.metrics import Accuracy, ROC_AUC

trainer = tao.Trainer()

# assume `model` is defined elsewhere
@trainer.eval()
def _eval(images, targets):
    # return the raw logits; each metric transforms them as needed
    return model(images), targets

# hard 0/1 predictions for accuracy, probabilities for ROC AUC
trainer.use(Metric("accuracy", Accuracy(output_transform=lambda out: ((out[0] > 0).long(), out[1]))))
trainer.use(Metric("roc_auc", ROC_AUC(output_transform=lambda out: (torch.sigmoid(out[0]), out[1]))))
```
- fit(*, max_epochs: int)
Start the training and evaluation loop.
- Parameters
max_epochs – how many epochs to train.
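Continuing the construction sketch above, starting the loop is one keyword-only call:

```python
# max_epochs must be passed as a keyword argument
trainer.fit(max_epochs=10)
```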
- to(device: torch.device)
Move the trainer to the given device.
- Parameters
device – torch.device
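For example, a sketch assuming a CUDA device is available:

```python
import torch

trainer.to(torch.device("cuda"))
```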
- train(optimizer: Optional[torch.optim.optimizer.Optimizer] = None, fields: Optional[List[str]] = None, amp: bool = False, grad: bool = True, accumulate: int = 1, scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None, adversarial: Optional[pytorch_tao.adversarial.Adversarial] = None, adversarial_enabled: Callable[[ignite.engine.engine.Engine], bool] = <function Trainer.<lambda>>)
Decorator that defines the training process.
- Parameters
optimizer – torch optimizer; if given, it overrides the one passed to the constructor.
fields – fields to select from the batch yielded by the dataloader. If None, all yielded data will be passed as the decorated function's parameters: a dict batch is passed as keyword arguments, a tuple batch as positional arguments.
amp – whether to use amp.
grad – whether to enable grad.
accumulate – gradient accumulation steps.
scaler – if amp is enabled, which scaler to use.
Ideally, the decorated function should do the following:
- its parameters should match the fields argument.
- do the model forward pass and compute the loss; if there are multiple losses, a single scalar loss should be made by weighted sum.
- return at least a scalar representing the final loss: if the return type is a dict, it should contain a "loss" key; if it is a tuple, the first element will be treated as the total loss.
See the following code examples for details.
```python
import torch.nn.functional as F

import pytorch_tao as tao

trainer = tao.Trainer()

# assume that the dataloader yields a tuple of (images, targets)
# and that `model` is defined elsewhere
@trainer.train()
def _train(images, targets):
    logits = model(images)
    loss = F.cross_entropy(logits, targets)
    return {"loss": loss}
```
```python
import torch.nn.functional as F

import pytorch_tao as tao

trainer = tao.Trainer()

# assume that the dataloader yields a tuple of (images, targets, metadata);
# fields=(0, 1) selects only the first two elements for the function
@trainer.train(fields=(0, 1))
def _train(images, targets):
    logits = model(images)
    loss = F.cross_entropy(logits, targets)
    return {"loss": loss}
```
```python
import torch.nn.functional as F

import pytorch_tao as tao

trainer = tao.Trainer()

# assume that the dataloader yields a tuple of (images, targets, metadata);
# without fields, all three elements are passed positionally
@trainer.train()
def _train(images, targets, metadata):
    logits = model(images)
    loss = F.cross_entropy(logits, targets)
    return {"loss": loss}
```
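The options above can be combined. A sketch of mixed-precision training with gradient accumulation (assumptions: `model` is defined, a CUDA device is available, and torch.cuda.amp.GradScaler matches the documented scaler type):

```python
import torch.nn.functional as F
from torch.cuda.amp import GradScaler

import pytorch_tao as tao

trainer = tao.Trainer()

# forward passes run under autocast; the optimizer steps once
# every 4 batches with scaled gradients
@trainer.train(amp=True, accumulate=4, scaler=GradScaler())
def _train(images, targets):
    logits = model(images)
    loss = F.cross_entropy(logits, targets)
    return {"loss": loss}
```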
- use(plugin: BasePlugin, at: Optional[str] = None)
Use a plugin.
If a pytorch_tao.plugins.TrainPlugin is used, it will be attached to the train_engine by default, and a pytorch_tao.plugins.ValPlugin will be attached to the val_engine. If the plugin inherits from pytorch_tao.plugins.BasePlugin directly, then the at argument should be passed as either "train" or "val".
- Parameters
plugin – the plugin to use.
at – which engine to attach the plugin to, either "train" or "val".
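A sketch of the at argument with a plugin that inherits directly from BasePlugin (MyPlugin here is hypothetical; a real plugin would override the hooks BasePlugin provides):

```python
import pytorch_tao as tao
from pytorch_tao.plugins import BasePlugin

# hypothetical no-op plugin; since it inherits BasePlugin directly,
# it declares no default engine and `at` must be given
class MyPlugin(BasePlugin):
    pass

trainer = tao.Trainer()
trainer.use(MyPlugin(), at="train")
```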