
Step count is not reset when loading a checkpoint and resetting the epoch


The current behavior in the torch engine, when loading a checkpoint during training via `import_model_train_epoch1`, is to reset the epoch to 0 but keep the global train step count of the checkpoint (see https://github.com/rwth-i6/returnn/blob/13640bc0a19e245e79b1e6df1cd8ae6f40fbc005/returnn/torch/engine.py#L798). Is this the expected behavior, or should we reset the step count as well?

mmueller00 avatar Nov 26 '24 13:11 mmueller00

For reference on the code: in `get_epoch_model`, this is the relevant case when the task is training and `import_model_train_epoch1` is set:

```python
elif config.value("task", "train") == "train" and import_model_train_epoch1 and start_epoch in [None, 1]:
    epoch_model = (0, import_model_train_epoch1)
```

And then the start epoch in training is `last_epoch + 1`, i.e. 1 here.

However, there is no such logic for the global train step. It just takes over whatever it finds in the model checkpoint.
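To make the asymmetry concrete, here is a minimal sketch of what the current loading behavior amounts to. This is illustrative only, not the actual RETURNN engine code; the function name and the `"global_train_step"` dict key are assumptions for the example:

```python
# Illustrative sketch (NOT the actual RETURNN code): with
# import_model_train_epoch1, the epoch is reset but the global
# train step is taken over from the checkpoint unchanged.

def load_for_import_model_train_epoch1(checkpoint: dict) -> tuple:
    """Return (epoch, step) as the engine currently sets them."""
    epoch = 0  # training then starts at epoch last_epoch + 1 = 1
    step = checkpoint["global_train_step"]  # kept from the checkpoint
    return epoch, step


epoch, step = load_for_import_model_train_epoch1({"global_train_step": 12345})
assert epoch == 0       # epoch is reset
assert step == 12345    # but the step count is not
```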

(Note, this logic on epoch/step in this `_load_model` func is a bit strange, because for the `_create_model` call, it wants to use the right epoch/step of the model checkpoint.)

So, if we want to change that, and also start with step 0, it means:

- Instead of the `step -= 1`, do:

  ```python
  if epoch == 0:
      step = 0
  else:
      step -= 1
  ```

- Instead of the `step += 1`, do:

  ```python
  if epoch != 1:
      step += 1
  ```

(Or maybe use `start_epoch` instead of 1.)
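The two adjustments above could be sketched as small helpers. This is a hedged illustration only; the function names are hypothetical and the real adjustments live inline in the engine, but it shows the intended effect of resetting the step together with the epoch:

```python
# Hypothetical helpers (names invented for illustration) mirroring
# the two proposed changes to the step bookkeeping.

def adjust_step_on_load(epoch: int, step: int) -> int:
    """Replacement for the plain `step -= 1`: if the epoch was
    reset to 0 (import_model_train_epoch1 case), also reset the step."""
    if epoch == 0:
        return 0
    return step - 1


def adjust_step_on_save(epoch: int, step: int) -> int:
    """Replacement for the plain `step += 1`: skip the increment
    for the first epoch after a reset."""
    if epoch != 1:
        return step + 1
    return step


assert adjust_step_on_load(0, 12345) == 0  # imported checkpoint: step reset
assert adjust_step_on_load(5, 100) == 99   # normal resume: undo the +1
assert adjust_step_on_save(1, 0) == 0      # first epoch after reset: no +1
assert adjust_step_on_save(5, 99) == 100   # normal case: +1 as before
```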

So the main question is: should we just change this, and is that good for everyone? Or should we make it an option?

albertz avatar Nov 26 '24 14:11 albertz

Please vote here on this comment, use:

  • 👍: if you think we should just change this without an option (changing the current behavior), or
  • 👎: if we should add an option for it which is disabled by default (i.e. keeping the current behavior unless you enable the option).

albertz avatar Nov 26 '24 14:11 albertz