Dataloader yields wrong sequence when resuming training
System Info
- `Accelerate` version: 0.29.3
- Platform: Linux-5.15.0-101-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /(...)/.venv/bin/accelerate
- Python version: 3.10.11
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.2+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 377.53 GB
- GPU type: Quadro RTX 8000
- `Accelerate` default config:
Not found
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
This is about training with a dataloader that reshuffles at every epoch. For reproducibility, when resuming training, the dataloader should yield the same order as the epoch during which training was interrupted. However, calling `train_dataloader.set_epoch(epoch)` has no effect: the yielded sequence does not change no matter which epoch value is passed.
After resuming, the dataloader actually yields the sequence of epoch n+1 if training was interrupted during epoch n.
Here is a minimal example of outputs for a `DataLoader(list(range(10)), shuffle=True, batch_size=4)`.
Without resuming:
Epoch: 0 [6, 7, 1, 4] [2, 0, 9, 8] [3, 5]
Epoch: 1 [2, 4, 7, 0] [8, 9, 5, 3] [6, 1]
Epoch: 2 [7, 4, 5, 1] [9, 3, 8, 2] [0, 6]
With resuming after two steps:
Epoch: 0 [6, 7, 1, 4] [2, 0, 9, 8] *interruption
*resuming [6, 1]
Epoch: 1 [7, 4, 5, 1] [9, 3, 8, 2] [0, 6]
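For context, here is a minimal plain-PyTorch sketch (no Accelerate, not part of the original report) of the property being relied on here: with a seeded `torch.Generator`, the whole sequence of shuffled epochs is a deterministic function of the seed, which is what should make it possible to land back on the interrupted epoch's order when resuming. The helper name `epoch_batches` is purely illustrative.

```python
import torch
from torch.utils.data import DataLoader


def epoch_batches(seed: int, n_epochs: int):
    # A seeded generator makes the whole sequence of shuffled epochs deterministic.
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(list(range(10)), shuffle=True, batch_size=4, generator=g)
    # Iterating the same dataloader repeatedly gives a new permutation per epoch,
    # but the sequence of permutations depends only on the seed.
    return [[torch.flatten(b).tolist() for b in loader] for _ in range(n_epochs)]


# Replaying from scratch yields the same order every time, epoch by epoch,
# which is what resuming is expected to recover.
assert epoch_batches(0, 3) == epoch_batches(0, 3)
```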
Code to reproduce
```python
import os
import re
import hydra
from omegaconf import DictConfig
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from collections import defaultdict


class Metrics(defaultdict):
    def __init__(self):
        super().__init__(int)

    def state_dict(self):
        return dict(self)

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            self[k] = v

    def log(self, accelerator: Accelerator):
        # Build the metrics to log
        metrics_log = dict()
        metrics_log["train/epochs"] = self["train/epochs"]
        metrics_log["train/steps"] = self["train/steps"]
        # Log the metrics
        accelerator.log(metrics_log)


def write_to_file(path, accelerator, obj):
    if accelerator.is_main_process:
        with open(path, "a") as f:
            f.write(obj)
            f.write("\n")


@hydra.main(version_base=None, config_path="../conf", config_name="config")
def train(cfg: DictConfig):
    # Get the last checkpoint id
    checkpoint_dir = os.path.join(cfg.trainer.dir, "checkpoints")
    iteration = 0
    if cfg.trainer.resume and os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
        folders = os.listdir(checkpoint_dir)
        iteration = max(int(re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)[0]) for folder in folders) + 1

    # Accelerator object
    project_config = ProjectConfiguration(
        cfg.trainer.dir,
        automatic_checkpoint_naming=True,
        total_limit=50,
        iteration=iteration,
    )
    accelerator = Accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=1,
        project_config=project_config,
    )

    # File to log outputs
    path = "log.txt"

    # Set the seed
    set_seed(cfg.seed)

    # Local and global counters
    metrics = Metrics()
    accelerator.register_for_checkpointing(metrics)

    train_dataloader = DataLoader(list(range(10)), shuffle=True, batch_size=4)

    # Accelerate
    train_dataloader = accelerator.prepare(train_dataloader)

    # Resume from the latest checkpoint
    skipped_train_dataloader = None
    if cfg.trainer.resume and os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
        accelerator.load_state()
        if accelerator.is_main_process:
            write_to_file(path, accelerator, "\nResuming in epoch: " + str(metrics["train/epochs"]))
        train_dataloader.set_epoch(metrics["train/epochs"])
        skipped_train_dataloader = accelerator.skip_first_batches(
            train_dataloader, metrics["train/batches"] % len(train_dataloader)
        )

    while cfg.trainer.max_steps > metrics["train/steps"]:
        # Use skipped_train_dataloader the first epoch after resuming
        dataloader = train_dataloader if skipped_train_dataloader is None else skipped_train_dataloader
        write_to_file(path, accelerator, "\nEpoch: " + str(metrics["train/epochs"]))
        for batch in dataloader:
            # Update number of batches
            metrics["train/batches"] += 1
            write_to_file(path, accelerator, "\nSteps: " + str(metrics["train/steps"]))
            write_to_file(path, accelerator, str(torch.flatten(batch).tolist()))
            metrics["train/steps"] += 1
            accelerator.save_state()
            if metrics["train/steps"] >= cfg.trainer.max_steps:
                break
        # Log metrics
        metrics["train/epochs"] += 1
        # "Remove" the skipped dataloader once exhausted
        skipped_train_dataloader = None

    # Make sure that the wandb tracker finishes correctly and close the progress bar
    accelerator.end_training()


if __name__ == "__main__":
    train()
```
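For completeness, the script reads its settings from a Hydra config at `../conf/config.yaml`, which is not included in the report. A minimal stand-in, with key names inferred from the `cfg.*` accesses above and placeholder values, would look roughly like this:

```python
# Hypothetical stand-in for the Hydra config the script expects.
# Key names are taken from the cfg.* accesses above; the values are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "seed": 42,
    "trainer": {
        "dir": "./outputs",   # checkpoints end up under <dir>/checkpoints
        "resume": True,       # resume from the latest checkpoint if one exists
        "max_steps": 20,      # stop after this many steps
    },
})
```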
Expected behavior
The dataloader should yield identical sequences with or without resuming.
Can you try updating your `accelerate` version to see whether we have already fixed this in a more recent release?
Hi! Thank you for your answer. It gives the same results with accelerate 0.30.1.
Hello! I am gently bumping this issue to ask whether you have had a chance to look into it.
Sorry for the delay. Zach is currently out of office but I'm sure he'll look into it when he's back.
Also seeing a symptom of this issue (or at least something that seems very related)! I'm seeing a weird occurrence where resuming training causes the master rank to perform gradient accumulation for only e.g. 2 passes when it should be 8 (as it correctly does on the other ranks). Subsequent steps are performed correctly; only the first one is flawed, which eventually causes a deadlock once the dataloader is exhausted.
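(Not part of the original comment, just a hypothetical way to narrow this down: count how many batches each rank actually yields on the first pass after resuming. `accelerator` and `skipped_train_dataloader` are assumed to come from `accelerator.prepare(...)` / `accelerator.skip_first_batches(...)` as in the reproduction script above.)

```python
# Hypothetical per-rank diagnostic (relies on objects from the script above):
# count how many batches each rank yields on the first pass after resuming;
# a mismatch across ranks would explain the eventual deadlock.
count = sum(1 for _ in skipped_train_dataloader)
print(f"rank {accelerator.process_index}: {count} batches on the first pass after resume")
```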
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@muellerzr Could you check on this issue again?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.