resume training from the last+1 batch of data
This is the same bug as in TinyLlama.
Hey @peiji1981
Can you explain this change? Our pretraining script saves and loads the iteration number in the state directly as iter_num:
https://github.com/Lightning-AI/litgpt/blob/502a9817bfbd025627c34c229b8c3eb2eca51150/litgpt/pretrain.py#L183-L195
So the iteration is already correct and updated at that point. A counter in the loop is not needed.
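For context, the pattern looks roughly like this (a simplified sketch, not the exact code behind the permalink; details in the real script may differ):

```python
# Simplified sketch of the resume logic in litgpt/pretrain.py.
state = {
    "model": model,
    "optimizer": optimizer,
    "iter_num": 0,
    "step_count": 0,
}

if resume:
    # Fabric restores the objects in `state` in-place, including plain
    # integers, so state["iter_num"] already holds the value from the
    # checkpoint when the training loop starts.
    fabric.load(resume, state)

while state["iter_num"] < max_iters:
    state["iter_num"] += 1
    ...
```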
When resuming training, you should skip over the data batches that were already trained on.
@awaelchli
@peiji1981 The important detail here is that LitGPT uses the Streaming DataLoader from LitData for pretraining. It is stateful, meaning it remembers at which sample training was stopped or interrupted when the checkpoint was saved. We save and load the dataloader state here:
https://github.com/Lightning-AI/litgpt/blob/5895df1004c3d05ef69f7aeffcfee757dbc42d58/litgpt/pretrain.py#L195
https://github.com/Lightning-AI/litgpt/blob/5895df1004c3d05ef69f7aeffcfee757dbc42d58/litgpt/pretrain.py#L204
Your idea would interfere with this and make the resuming incorrect.
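In simplified form, the stateful resume works roughly like this (a sketch using LitData's `StreamingDataset`/`StreamingDataLoader`; the dataset path and checkpoint handling here are illustrative only, not the actual LitGPT code):

```python
import torch
from litdata import StreamingDataset, StreamingDataLoader

# Illustrative dataset location, not a real path.
dataset = StreamingDataset("data/pretrain_shards")
train_dataloader = StreamingDataLoader(dataset, batch_size=8)

# Saving: the dataloader exposes a state_dict() that records how many
# samples have already been consumed.
torch.save({"train_dataloader": train_dataloader.state_dict()}, "checkpoint.pt")

# Resuming: restoring that state makes iteration continue from the first
# batch that was NOT seen before the interruption.
checkpoint = torch.load("checkpoint.pt")
train_dataloader.load_state_dict(checkpoint["train_dataloader"])

for batch in train_dataloader:
    ...  # continues where the previous run left off
```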
Aside from that, your proposed change wouldn't achieve what you want even if we used a regular dataloader. Passing an initial value to enumerate() as the second argument doesn't skip elements of the iterable; it only offsets the counter of the for-loop.
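A quick way to see this in plain Python, independent of LitGPT:

```python
data = ["a", "b", "c", "d"]

# Passing a start value to enumerate() only shifts the counter;
# every element is still yielded.
for i, x in enumerate(data, 2):
    print(i, x)
# -> 2 a / 3 b / 4 c / 5 d

# Actually skipping already-consumed elements of a plain iterable would
# require something like itertools.islice instead.
from itertools import islice
for x in islice(data, 2, None):
    print(x)
# -> c / d
```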