GradientAccumulationPlugin(sync_with_dataloader=True) default behavior is bad for train/val dataloader setup
System Info
- `Accelerate` version: 0.33.0
- Platform: Linux-5.15.0-1067-azure-x86_64-with-glibc2.31
- `accelerate` bash location: /home/azureuser/miniconda3/envs/pytorch/bin/accelerate
- Python version: 3.10.13
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.4.0+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 866.06 GB
- GPU type: NVIDIA A100 80GB PCIe
- `Accelerate` default config:
- compute_environment: LOCAL_MACHINE
- distributed_type: MULTI_GPU
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- gpu_ids: all
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
I prepare my training dataloader and validation dataloader like this:
```python
self.denoiser.model, self.train_dl, self.valid_dl, self.optimizer = self.accelerator.prepare(
    self.denoiser.model, self.train_dl, self.valid_dl, self.optimizer
)
```
I create iterators from them like this:
```python
def cycle(self, dl, dl_name: Optional[str] = None):
    counter = 0
    while True:
        counter += 1
        self.logger.info(
            f"process {self.accelerator.process_index} starting {dl_name} cycle {counter}",
            main_process_only=False,
        )
        for data in dl:
            yield data

def create_train_iter(self):
    assert exists(self.train_dl), "training dataloader has not been registered with the trainer yet"
    if exists(self.train_dl_iter):
        return
    self.train_dl_iter = self.cycle(self.train_dl, dl_name="train")

def create_valid_iter(self):
    assert exists(self.valid_dl), "validation dataloader has not been registered with the trainer yet"
    if exists(self.valid_dl_iter):
        return
    self.valid_dl_iter = self.cycle(self.valid_dl, dl_name="validation")
```
And I am iterating over them like this:
```python
def train_step(self, *args, **kwargs):
    assert self.prepared, "You need to prepare the trainer before training"
    self.create_train_iter()
    sample = next(self.train_dl_iter)

    with self.accelerator.accumulate(self.denoiser.model):
        losses = self.denoiser.forward(*args, accelerator=self.accelerator, **{**kwargs, **sample})
        gathered_losses = self.accelerator.gather(losses)  # for logging
        self.accelerator.backward(losses.mean())
        self.update()

    return gathered_losses.mean().item()

@torch.no_grad()
def valid_step(self, **kwargs):
    assert self.prepared, "You need to prepare the trainer before validation"
    self.create_valid_iter()
    sample = next(self.valid_dl_iter)

    with denoiser_eval_model_switch(self.denoiser, ema_model=self.ema_model):
        losses = self.denoiser.forward(**{**kwargs, **sample})

    return self.accelerator.gather(losses).mean().item()
```
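For completeness, the outer loop interleaves these two steps roughly like this (an illustrative sketch, not the exact trainer code; `num_train_steps` and `valid_every` are placeholder names):

```python
# Illustrative outer loop: validation runs every `valid_every` training steps,
# so the validation dataloader gets exhausted and restarted many times during
# a run. The names below are placeholders, not the real trainer attributes.
for step in range(num_train_steps):
    train_loss = trainer.train_step()
    if step % valid_every == 0:
        valid_loss = trainer.valid_step()
```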
Expected behavior
For a few days I wrestled with a problematic pattern during DDP training: in a strangely periodic manner, the gradient norms would collapse, and for a stretch of steps every step performed a gradient sync instead of accumulating gradients. This did not crash training, but it messed up the LR scheduler and overall was a waste of expensive GPU time.
I narrowed it down to the collapse happening every time the validation dataloader finishes a cycle. At that point, because of `GradientState.sync_with_dataloader = True`, the `GradientState` enters a perpetual `sync_gradients = True` state until the next validation iteration, even though the active dataloader is now the train dataloader.
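A rough, self-contained sketch of the interleaving that seems to trigger it (toy model, optimizer and datasets; I have not verified that this exact snippet reproduces the collapse, it just mirrors the structure of the trainer above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Toy setup: accumulate over 4 steps with the default sync_with_dataloader=True.
accelerator = Accelerator(
    gradient_accumulation_plugin=GradientAccumulationPlugin(num_steps=4)
)
model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
train_dl = DataLoader(TensorDataset(torch.randn(64, 2)), batch_size=2)
valid_dl = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2)
model, train_dl, valid_dl, opt = accelerator.prepare(model, train_dl, valid_dl, opt)

# Long-lived training iterator, like the cycle() helper above.
train_iter = iter(train_dl)

# Take a few training steps first.
for _ in range(3):
    (batch,) = next(train_iter)
    with accelerator.accumulate(model):
        accelerator.backward(model(batch).mean())
        opt.step()
        opt.zero_grad()

# Exhaust the validation dataloader once, i.e. finish a validation cycle.
for _ in valid_dl:
    pass

# Continue training. With num_steps=4 I would expect sync_gradients to be True
# only on every 4th step, but after the validation dataloader has ended it
# stays True on every step (the behavior described above).
for i in range(8):
    (batch,) = next(train_iter)
    with accelerator.accumulate(model):
        accelerator.backward(model(batch).mean())
        print(i, accelerator.sync_gradients)
        opt.step()
        opt.zero_grad()
```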
I am not sure what the best solution would be; perhaps a way to designate whether a dataloader is meant for training or for validation (the latter having no need for gradient sync tracking).
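As a stopgap, I assume one could disable the dataloader-based tracking entirely and let the accumulation boundary be driven by the step count alone, something like this (untested sketch; the `num_steps` value is a placeholder):

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Possible workaround (untested): turn off sync_with_dataloader so that the end
# of the validation dataloader no longer forces sync_gradients = True, and rely
# purely on num_steps to decide when to sync.
plugin = GradientAccumulationPlugin(num_steps=8, sync_with_dataloader=False)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)
```

That said, a first-class way to mark a prepared dataloader as evaluation-only still seems like the cleaner fix.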