Bug when resuming training on a single GPU and on multiple GPUs
System Info
accelerate: 0.12.0
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
When a checkpoint is loaded with accelerator.load_state() to resume training, the DataLoader samples data from the beginning, i.e., from epoch 0, step 0. In an extreme case with https://github.com/huggingface/accelerate/blob/main/examples/complete_nlp_example.py, the script ends up training on only a small part of the data when it is resumed many times.
In that example, at https://github.com/huggingface/accelerate/blob/main/examples/complete_nlp_example.py#L185, training is in fact resumed from epoch 0, not from the starting epoch.
import torch
import torch.utils.data as Data

from accelerate import Accelerator
from accelerate.utils import set_seed

accelerator = Accelerator()
BATCH_SIZE = 3
set_seed(42)

# Toy dataset: 12 (x, y) pairs
x = torch.linspace(1, 12, 12)
print(x)
y = torch.linspace(12, 1, 12)
print(y)
torch_dataset = Data.TensorDataset(x, y)
output_path = "save_path"

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
loader = accelerator.prepare(loader)

start_epoch = 0
# When start_epoch = 2, how can the loop reach epoch 2, step 1?
for epoch in range(start_epoch, 5):
    for step, (batch_x, batch_y) in enumerate(loader):
        print(batch_x, batch_y)
        if epoch == 2 and step == 1:
            # Checkpoint in the middle of epoch 2
            accelerator.save_state(output_path)
After accelerator.load_state(output_path), how can the loop reach epoch 2, step 1? I tested with one GPU and with multiple GPUs, and in both cases the loader samples data from epoch 0, step 0. Something like a set_epoch function would be expected here.
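For comparison, the behavior being asked for is roughly what plain PyTorch offers through DistributedSampler.set_epoch: the shuffle order becomes a function of the epoch number, so restarting the loop at start_epoch reproduces the same order. A minimal sketch (num_replicas and rank are hard-coded here so it runs without initializing a process group):

import torch
import torch.utils.data as Data
from torch.utils.data.distributed import DistributedSampler

torch_dataset = Data.TensorDataset(torch.linspace(1, 12, 12), torch.linspace(12, 1, 12))

# Hard-coded num_replicas/rank so this runs without torch.distributed;
# in a real multi-GPU job they come from the process group.
sampler = DistributedSampler(torch_dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
loader = Data.DataLoader(torch_dataset, batch_size=3, sampler=sampler)

start_epoch = 2
for epoch in range(start_epoch, 5):
    sampler.set_epoch(epoch)  # shuffle order now depends only on the seed and the epoch number
    for step, (batch_x, batch_y) in enumerate(loader):
        print(epoch, step, batch_x, batch_y)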
Expected behavior
Training should resume from the data at epoch 2, step 1.
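One workaround, along the lines of the skip logic in complete_nlp_example.py, is to record the epoch and step when saving and, after accelerator.load_state(), iterate the dataloader again while skipping the batches that were already consumed. Whether the replayed order exactly matches the original run depends on the RNG/sampler state at resume time (hence the note below about the distributed setup), so treat this as a sketch; resume_epoch and resume_step are illustrative names:

accelerator.load_state(output_path)
# In practice these would be stored next to the checkpoint; hard-coded here.
resume_epoch, resume_step = 2, 1

for epoch in range(0, 5):
    for step, (batch_x, batch_y) in enumerate(loader):
        # Fast-forward past batches that were consumed before the checkpoint.
        if epoch < resume_epoch or (epoch == resume_epoch and step <= resume_step):
            continue
        print(epoch, step, batch_x, batch_y)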
This is not a bug. You need to resume your training in the exact same distributed setup.
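For anyone hitting this on newer releases: Accelerate later added a skip_first_batches helper that fast-forwards a prepared dataloader for the resumed epoch, so only the number of already-consumed batches needs to be tracked. A sketch, assuming a release that ships it (the value of num_batches_seen is illustrative):

from accelerate import skip_first_batches

accelerator.load_state(output_path)
num_batches_seen = 2  # steps 0 and 1 of the checkpointed epoch were already consumed

# Use the skipped dataloader only for the resumed epoch; later epochs iterate normally.
skipped_loader = skip_first_batches(loader, num_batches_seen)
for step, (batch_x, batch_y) in enumerate(skipped_loader, start=num_batches_seen):
    print(step, batch_x, batch_y)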
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.