MONAI

Enable setting of training iteration in Trainers

Status: Open. holgerroth opened this issue 2 years ago (6 comments).

Is your feature request related to a problem? Please describe.
Currently, SupervisedTrainer supports controlling the number of iterations only by adjusting epoch_length and max_epochs. It would be nice to be able to set the number of iterations to execute directly.

Describe the solution you'd like
Add an n_iterations argument (or similar) that overrides the epoch-based definition of the number of training steps to execute. Note that this should allow training to resume from the final iteration once n_iterations is reached. Related to #4554.
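A minimal sketch of the requested semantics, in plain Python rather than MONAI or ignite code: a loop that counts global iterations across calls, so a second call with a larger target continues from where the first one stopped. LoopState and run_iterations are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass
class LoopState:
    # Global iteration counter; keeping it between calls is what enables resuming
    iteration: int = 0


def run_iterations(
    state: LoopState,
    data: Sequence[Any],
    n_iterations: int,
    step: Callable[[Any], None],
) -> LoopState:
    """Hypothetical helper: run until state.iteration reaches n_iterations, cycling over data."""
    if len(data) == 0:
        raise ValueError("data must not be empty")
    while state.iteration < n_iterations:
        # Simplification: each pass (including a resumed one) starts at the first batch
        for batch in data:
            if state.iteration >= n_iterations:
                break
            step(batch)
            state.iteration += 1
    return state
```

Calling run_iterations(state, data, 250, step) on a state that has already reached iteration 100 runs only the remaining 150 steps. A real trainer would also need to restore the position within the epoch from a checkpoint, which this sketch deliberately ignores.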

Describe alternatives you've considered
Keep adjusting epoch_length and max_epochs, but that seems confusing.


holgerroth, Jun 22, 2022

@holgerroth Does epoch_length correspond to minibatch_size or something else?

drmhrehman, Jun 22, 2022

epoch_length corresponds to the number of iterations needed to iterate once over the data (i.e., one epoch). It defaults to len(train_data_loader).
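To make the relationship concrete: for a map-style dataset with PyTorch's default batch sampler, len(train_data_loader) is the number of samples divided by the batch size, rounded up (or rounded down when drop_last=True). The helper name below is hypothetical; it only reproduces that arithmetic.

```python
def default_epoch_length(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style DataLoader using its default batch sampler."""
    if drop_last:
        # The incomplete final batch is discarded
        return num_samples // batch_size
    # The final partial batch is kept (integer ceiling division)
    return (num_samples + batch_size - 1) // batch_size
```

For example, 1000 samples with batch_size=32 give 32 iterations per epoch, or 31 with drop_last=True.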

holgerroth, Jun 22, 2022

Hi @holgerroth,

I'm a bit confused by the feature request: is n_iterations the same as setting max_epochs=1 together with epoch_length?

Thanks.

Nic-Ma, Aug 10, 2022

Yes. I would propose that if n_iterations is provided, the trainer should set max_epochs=1 and epoch_length=n_iterations.
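The proposed rule can be sketched as a small resolution step that runs before the engine starts. resolve_schedule is a hypothetical name; it is not part of MONAI's SupervisedTrainer, and the validation shown is an assumption.

```python
from typing import Optional, Tuple


def resolve_schedule(
    max_epochs: int,
    epoch_length: Optional[int],
    n_iterations: Optional[int],
) -> Tuple[int, Optional[int]]:
    """Hypothetical: return (max_epochs, epoch_length) after applying an n_iterations override."""
    if n_iterations is not None:
        if n_iterations <= 0:
            raise ValueError("n_iterations must be positive")
        # Proposed rule: a single "epoch" that is exactly n_iterations long
        return 1, n_iterations
    # No override: keep the epoch-based settings unchanged
    return max_epochs, epoch_length
```

The resulting pair would then be passed on as max_epochs and epoch_length.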

holgerroth, Aug 10, 2022

Hi @vfdev-5,

I think there is a typical use case: users simply set the total number of iterations to run with a random sampler, even when n_iterations is larger than the dataset length. How can we support that with the ignite Engine? CC @holgerroth @wyli

Thanks.
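One way to decouple the iteration count from the dataset length, assuming sampling with replacement is acceptable, is to draw indices from an endless random stream and stop after n_iterations. PyTorch's RandomSampler(data_source, replacement=True, num_samples=...) offers comparable behavior. The sketch below is pure Python, uses one index per iteration (i.e., batch_size=1) for simplicity, and random_index_stream is a hypothetical name.

```python
import itertools
import random
from typing import Iterator


def random_index_stream(dataset_len: int, seed: int = 0) -> Iterator[int]:
    """Endless stream of dataset indices drawn uniformly with replacement."""
    rng = random.Random(seed)
    while True:
        yield rng.randrange(dataset_len)


# 250 iterations over a 100-item dataset: the count is independent of the dataset size
n_iterations = 250
indices = list(itertools.islice(random_index_stream(100, seed=42), n_iterations))
```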

Nic-Ma, Aug 17, 2022

@holgerroth On master and in the nightly releases, Engine.run() has a max_iters argument (see https://pytorch.org/ignite/master/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.run). It probably works as you asked:

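```python
# trainer is an ignite Engine; data yields 100 batches per pass
assert len(data) == 100
s = trainer.run(data, max_iters=123)  # runs past one epoch, stops at iteration 123
assert s.iteration == 123
```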

However, we haven't released it in a stable version yet, as there may be issues with how this is saved to and loaded from checkpoints, etc.

A workaround for the stable release could be:

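```python
from ignite.engine import Events

max_iters = 1234
epoch_length = len(data)
# Enough epochs to cover max_iters; the handler below stops the run early
max_epochs = max_iters // epoch_length + 1

@trainer.on(Events.ITERATION_COMPLETED(once=max_iters))
def stop():
    # Fires exactly once, when the global iteration count reaches max_iters
    trainer.terminate()

trainer.run(data, max_epochs=max_epochs)
```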
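To see why max_iters // epoch_length + 1 epochs are always enough, note that (max_iters // epoch_length + 1) * epoch_length is strictly greater than max_iters, so the handler is guaranteed to fire. The pure-Python loop below mimics the epoch loop plus the terminate-at-max_iters handler without ignite; simulate_workaround is a hypothetical name.

```python
from typing import Tuple


def simulate_workaround(max_iters: int, epoch_length: int) -> Tuple[int, int]:
    """Return (final_iteration, final_epoch) for the epoch loop with an early stop."""
    max_epochs = max_iters // epoch_length + 1  # same over-provisioning as above
    iteration = 0
    for epoch in range(1, max_epochs + 1):
        for _ in range(epoch_length):
            iteration += 1
            # Plays the role of ITERATION_COMPLETED(once=max_iters) + terminate()
            if iteration == max_iters:
                return iteration, epoch
    return iteration, max_epochs
```

With max_iters=1234 and epoch_length=100, the run stops at iteration 1234 partway through epoch 13. When max_iters is an exact multiple of epoch_length (e.g. 1200), it stops at the end of epoch 12 and the extra epoch never starts.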

vfdev-5, Aug 17, 2022