`Timer`'s misleading behaviour when epoch completion time calculated

Open Priyansi opened this issue 4 years ago • 7 comments

When I tried to measure the time taken to complete a single epoch with Timer, the handlers attached to the trainer before the Timer were executed first, so their running time was also recorded by the Timer for that epoch. As a result, the true epoch completion time, given by trainer.state.times, is less than what the Timer reports. This can be misleading. More clarification in the docs on how this actually works would be appreciated. Alternatively, the Timer could be enhanced so that it steps before all other handlers attached to an event. A notebook to quickly verify this is here.
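To illustrate the ordering effect described above, here is a minimal toy model. It is not ignite's actual Engine or Timer and is not taken from the notebook: `FakeClock`, `ToyEngine`, `ToyTimer` and `measure` are hypothetical names, and the sketch assumes that handlers fire in registration order and that the engine records its own epoch time before firing them.

```python
# Toy model only: NOT ignite's Engine/Timer, just an illustration of handler ordering.

class FakeClock:
    """Deterministic clock so the example does not depend on real time."""

    def __init__(self):
        self.now = 0.0

    def advance(self, seconds):
        self.now += seconds


class ToyEngine:
    """Fires EPOCH_COMPLETED-style handlers in the order they were registered."""

    def __init__(self, clock):
        self.clock = clock
        self.handlers = []
        self.epoch_time = None

    def on_epoch_completed(self, handler):
        self.handlers.append(handler)

    def run_epoch(self, train_seconds):
        start = self.clock.now
        self.clock.advance(train_seconds)  # pretend to train
        # Assumption: the epoch time is recorded before any handler runs.
        self.epoch_time = self.clock.now - start
        for handler in self.handlers:
            handler(self)


class ToyTimer:
    """Starts at construction; pause() records the time elapsed so far."""

    def __init__(self, clock):
        self.clock = clock
        self.started_at = clock.now
        self.elapsed = None

    def pause(self, engine):
        self.elapsed = self.clock.now - self.started_at


def measure(timer_first):
    clock = FakeClock()
    engine = ToyEngine(clock)
    timer = ToyTimer(clock)

    def slow_validation(engine):
        clock.advance(25.0)  # pretend evaluation takes 25 seconds

    if timer_first:
        engine.on_epoch_completed(timer.pause)
        engine.on_epoch_completed(slow_validation)
    else:
        engine.on_epoch_completed(slow_validation)
        engine.on_epoch_completed(timer.pause)

    engine.run_epoch(train_seconds=170.0)
    return timer.elapsed, engine.epoch_time


print(measure(timer_first=False))  # timer also counts the validation handler
print(measure(timer_first=True))   # timer and engine agree
```

In this toy model, registering the timer's pause handler after the slow handler inflates the timer's reading by the slow handler's duration, while registering it first makes both measurements agree.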

Priyansi avatar Aug 11 '21 11:08 Priyansi

The logic is here https://github.com/pytorch/ignite/blob/38f30c41076ebe5bb954228729b17b9334e8e7eb/ignite/engine/engine.py#L745

The state timers should return the times for specific events. However, if a handler tries to access the state timer of an event while that event is being handled, the time has not yet been computed. The value in the timer is then the time spent in the previous events rather than an undefined value.

I agree that it should be explained or modified. For instance, the state timer of an event could be lazily updated each time the user accesses it during that event. It could be done, but I'm not sure it is really worth…
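One possible reading of the "lazily updated" idea, as a rough sketch only: a mapping that returns the running elapsed time while an event is still in progress and the final value once it has stopped. `LazyTimes`, `start` and `stop` are hypothetical names for illustration and do not exist in ignite; the injectable `clock` argument is likewise an assumption made to keep the example deterministic.

```python
import time


class LazyTimes(dict):
    """Hypothetical sketch: times are computed on access while an event is running."""

    def __init__(self, clock=time.perf_counter):
        super().__init__()
        self._clock = clock
        self._running = {}  # event name -> start timestamp

    def start(self, name):
        self._running[name] = self._clock()

    def stop(self, name):
        # Freeze the final duration once the event is over.
        self[name] = self._clock() - self._running.pop(name)

    def __getitem__(self, name):
        if name in self._running:
            # Event still in progress: report the time elapsed so far.
            return self._clock() - self._running[name]
        return super().__getitem__(name)


# Usage with a fake clock so the values are predictable.
fake_now = [0.0]
times = LazyTimes(clock=lambda: fake_now[0])

times.start("EPOCH_COMPLETED")
fake_now[0] = 5.0
during = times["EPOCH_COMPLETED"]  # read from inside a handler

fake_now[0] = 8.0
times.stop("EPOCH_COMPLETED")
after = times["EPOCH_COMPLETED"]   # read after the event finished
print(during, after)
```

Only item access is lazy in this sketch; `in` checks and `.get()` would still see nothing until `stop` is called, which is one reason such a change may not be worth the added complexity.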

sdesrozis avatar Aug 16 '21 22:08 sdesrozis

Hi, I would like to work on this issue, please assign this to me. Also, please provide examples to understand the issue better.

FarehaNousheen avatar Oct 16 '21 16:10 FarehaNousheen

Also, please provide examples to understand the issue better.

@FarehaNousheen please read https://github.com/pytorch/ignite/issues/2157#issue-966515984 attentively. There is a notebook with a concrete example provided by Priyansi.

vfdev-5 avatar Oct 18 '21 08:10 vfdev-5

Things to do for this issue:

  • [ ] Reproduce the results from https://colab.research.google.com/drive/1xX7slM9BiDSWYGId5zB_8ISneJ0wezAM
    • rerun the notebook and confirm that there is a difference of 15-20 seconds between time measurements for the single epoch. For example:
Time taken by a single epoch calculated by Timer: 197.58
Time Taken for single epoch calculated by State of engine : 170.89924907684326
  • [ ] Modify the notebook so that the timer is attached before the handlers that run the validation. Current notebook:
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)}, device=device
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print(
        f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )
    print(
        f"Time taken by a single epoch calculated by Timer: {timer.value():.2f}"
    )
    print(
        f"Time Taken for single epoch calculated by State of engine : {trainer.state.times['EPOCH_COMPLETED']}"
    )


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )
    print(
        f"Time taken by a single epoch calculated by Timer: {timer.value():.2f}"
    )
    print(
        f"Time Taken for single epoch calculated by State of engine : {trainer.state.times['EPOCH_COMPLETED']}"
    )

timer = Timer(average=True)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.EPOCH_STARTED,
             pause=Events.EPOCH_COMPLETED,
             step=Events.EPOCH_COMPLETED)

Expected modification:

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)}, device=device
)

timer = Timer(average=True)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.EPOCH_STARTED,
             pause=Events.EPOCH_COMPLETED,
             step=Events.EPOCH_COMPLETED)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print(
        f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )
    print(
        f"Time taken by a single epoch calculated by Timer: {timer.value():.2f}"
    )
    print(
        f"Time Taken for single epoch calculated by State of engine : {trainer.state.times['EPOCH_COMPLETED']}"
    )


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )
    print(
        f"Time taken by a single epoch calculated by Timer: {timer.value():.2f}"
    )
    print(
        f"Time Taken for single epoch calculated by State of engine : {trainer.state.times['EPOCH_COMPLETED']}"
    )
  • [ ] Re-run the notebook and report here the results for the time measurements for a single epoch
    • put the logs here
    • put the link on the colab here and make sure that colab is accessible to everyone

vfdev-5 avatar Nov 08 '21 12:11 vfdev-5

That's a really detailed description. Thank you for sharing it. I'll work on it tomorrow.

FarehaNousheen avatar Nov 15 '21 18:11 FarehaNousheen

Hi all, I have modified the notebook so that the timer is attached before running the validation. Here is the link to the Google Colab: https://colab.research.google.com/drive/1SJYqgmzYvGivruq9OAS0q6MADjBFqI13

FarehaNousheen avatar Nov 24 '21 18:11 FarehaNousheen

I ran the notebook with the Timer attached before the second code snippet and removed the cell where the Timer was attached after validation. The results are in the copy of the Colab shared above; in short:

Training Results - Epoch[1] Avg accuracy: 0.95 Avg loss: 0.16
Time taken by a single epoch calculated by Timer: 73.73
Time Taken for single epoch calculated by State of engine : 73.72484469413757
Validation Results - Epoch[1] Avg accuracy: 0.95 Avg loss: 0.15
Time taken by a single epoch calculated by Timer: 73.73
Time Taken for single epoch calculated by State of engine : 73.72484469413757
Training Results - Epoch[2] Avg accuracy: 0.97 Avg loss: 0.13
Time taken by a single epoch calculated by Timer: 73.51
Time Taken for single epoch calculated by State of engine : 73.5108630657196
Validation Results - Epoch[2] Avg accuracy: 0.96 Avg loss: 0.13
Time taken by a single epoch calculated by Timer: 73.51
Time Taken for single epoch calculated by State of engine : 73.5108630657196

FarehaNousheen avatar Dec 07 '21 15:12 FarehaNousheen