MONAI
Enable setting of training iteration in Trainers
Is your feature request related to a problem? Please describe.
Currently, `SupervisedTrainer` supports controlling the number of iterations by adjusting `epoch_length` and `max_epochs`. It would be nice to be able to set the number of iterations to be executed directly.
Describe the solution you'd like
Add an `n_iterations` argument (or similar) that allows overriding the epoch-based definition of the number of training steps to be executed.
Note, this should allow the training to resume from the final iteration if `n_iterations` is reached. Related to #4554.
Describe alternatives you've considered
Live with adjusting `epoch_length` and `max_epochs`, but that seems confusing.
@holgerroth Does `epoch_length` correspond to `minibatch_size` or something else?
`epoch_length` corresponds to the number of iterations needed to iterate once over the data (i.e., one epoch). It defaults to `len(train_data_loader)`.
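To make the default concrete, here is a small illustration with made-up numbers (the sample count and batch size below are assumptions, not from the thread):

```python
# Hypothetical setup: 100 training samples, batch size 4.
# One pass over the data then takes 25 iterations, which is what
# epoch_length defaults to, mirroring len(train_data_loader).
num_samples, batch_size = 100, 4
default_epoch_length = num_samples // batch_size  # 25 iterations per epoch
print(default_epoch_length)
```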
Hi @holgerroth ,
I got confused by the feature request: is this `n_iterations` the same as setting `max_epochs=1` together with `epoch_length`?
Thanks.
Yes, I would propose that if `n_iterations` is provided, it should be translated to `max_epochs=1` and `epoch_length=n_iterations`.
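A minimal sketch of that proposed translation (the helper name is hypothetical, not an existing MONAI API):

```python
def epoch_params_from_iterations(n_iterations):
    # Hypothetical mapping: run exactly one "epoch" whose length is the
    # requested iteration count, so n_iterations steps execute in total.
    return {"max_epochs": 1, "epoch_length": n_iterations}

print(epoch_params_from_iterations(500))
# {'max_epochs': 1, 'epoch_length': 500}
```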
Hi @vfdev-5 ,
I think there is a typical use case: users just set the total number of iterations to run with a random sampler, even when `n_iterations` is bigger than the dataset length. How can we support that with the ignite engine?
CC @holgerroth @wyli
Thanks.
@holgerroth we have on `master` and in the nightly releases a `max_iters` arg for `Engine.run()`: https://pytorch.org/ignite/master/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.run
It probably works as you asked for:
assert len(data) == 100
s = trainer.run(data, max_iters=123)
assert s.iteration == 123
However, we haven't yet released that in stable as there can be issues with how this is saved/loaded in checkpoints etc.
A workaround for the stable release can be:
from ignite.engine import Events

max_iters = 1234
epoch_length = len(data)
max_epochs = max_iters // epoch_length + 1  # enough whole epochs to reach max_iters

@trainer.on(Events.ITERATION_COMPLETED(once=max_iters))
def stop():
    trainer.terminate()

trainer.run(data, max_epochs=max_epochs)
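The arithmetic behind this workaround can be sanity-checked with a plain-Python simulation, no ignite required (the loader length of 100 is an assumed value for illustration):

```python
# Assumed setup: a loader of length 100 and a target of 1234 iterations.
max_iters = 1234
epoch_length = 100
max_epochs = max_iters // epoch_length + 1  # 13 epochs cover up to 1300 iterations

iteration = 0
stopped = False
for _epoch in range(max_epochs):
    for _ in range(epoch_length):
        iteration += 1
        # Mirrors the Events.ITERATION_COMPLETED(once=max_iters) handler
        # that calls trainer.terminate().
        if iteration == max_iters:
            stopped = True
            break
    if stopped:
        break

print(iteration)  # training stops exactly at 1234
```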