
Gradient Accumulation Error with Ulysses Sequence Parallel Causes Inconsistent Loss

Open GoneZ5 opened this issue 5 months ago • 15 comments

First of all, thank you for your great work on this repository! I encountered a loss inconsistency issue when using Ulysses Sequence Parallel for SFT (commit id: 2c85b432996ad28ed9756ae745e7fcb7ec6eee10). In the verl framework, enabling Ulysses SP automatically normalizes the train batch size so that gradient accumulation keeps the global batch size unchanged (I think this is a great design).

def _normalize_config_bsz(self):
    # dp_size excludes the Ulysses SP dimension when SP is enabled
    dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
    if self.device_mesh.get_rank() == 0:
        print(f"Normalize batch size by dp {dp_size}")

    assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"

    # per-DP-group batch size; gradient accumulation over micro-batches makes up the difference
    self.config.data.train_batch_size //= dp_size

    assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
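For concreteness, here is how that normalization plays out for the SP=2 run below (64 GPUs, global train batch size 64); micro_batch_size_per_gpu = 1 is an assumed value for illustration, not taken from my config:

# Illustrative arithmetic only; micro_batch_size_per_gpu = 1 is an assumption.
world_size = 64
sp_size = 2
global_train_batch_size = 64
micro_batch_size_per_gpu = 1                                  # assumed

dp_size = world_size // sp_size                               # 32 data-parallel groups
per_dp_batch = global_train_batch_size // dp_size             # 64 // 32 = 2 sequences per DP group
grad_accum_steps = per_dp_batch // micro_batch_size_per_gpu   # 2 micro-batches accumulated per optimizer step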

However, when I ran SFT experiments with the same data and SP set to 1, 2, and 8, I observed large differences in loss among the runs. The loss and grad_norm curves are as follows:

  • Purple: sft-qwen3-32b-lr5e-5-32k-gpu64-bsz64-sp1
  • Green: sft-qwen3-32b-lr5e-5-32k-gpu64-bsz64-sp2
  • Blue: sft-qwen3-32b-lr5e-5-32k-gpu64-bsz64-sp8
[Images: loss and grad_norm curves]

Additionally, I modified the batch size normalization logic so that when SP is enabled, the global batch size is directly reduced by a factor of SP, and no gradient accumulation is performed.

self.config.data.train_batch_size //= (dp_size * getattr(self.config, "ulysses_sequence_parallel_size", 1))
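With the same illustrative numbers as above (64 GPUs, SP=2, global train batch size 64; again an assumption for illustration), the effect of this change is:

# Illustrative arithmetic only (assumed setup: 64 GPUs, SP=2, train_batch_size 64).
dp_size, sp_size, global_train_batch_size = 32, 2, 64
per_dp_batch = global_train_batch_size // (dp_size * sp_size)   # 1 -> no gradient accumulation
effective_global_batch = per_dp_batch * dp_size                 # 32: the global batch shrinks by the SP factor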

I then re-ran experiments with SP=1 and SP=2. In order to keep the global batch size consistent, I used twice as many GPUs for SP=2. In this case, the loss curves were consistent across both experiments. The new loss and grad_norm curves are as follows:

  • Black: sft-qwen3-32b-lr5e-5-32k-gpu32-bsz32-sp1
  • Pink: sft-qwen3-32b-lr5e-5-32k-gpu64-bsz32-sp2
[Images: loss and grad_norm curves]

My questions are:

  • Is the loss inconsistency caused by an incompatibility between SP > 1 and gradient accumulation, or is it a problem with gradient accumulation itself?

  • If it is indeed an incompatibility between SP > 1 and gradient accumulation, will this also affect RL experiments such as GRPO or DAPO?

I look forward to your reply, and thank you again for your excellent work on this repository!

GoneZ5 avatar Aug 01 '25 02:08 GoneZ5

I suggested a fix in an issue I raised: https://github.com/volcengine/verl/issues/2919. Maybe that would fix your issue.

puneeshkhanna avatar Aug 05 '25 10:08 puneeshkhanna

#2919 - I suggested a fix in an issue I raised. Maybe that would fix your issue.

Thanks for your reply! I tried passing n_micro_batches as a parameter to the _compute_loss_and_backward method and performing loss /= n_micro_batches before loss.backward(), but the loss still differs significantly between runs. The loss and grad_norm curves are as follows:

  • Black: sft-qwen3-32b-lr5e-5-32k-gpu32-bsz32-sp1
  • Pink: sft-qwen3-32b-lr5e-5-32k-gpu64-bsz32-sp2
  • Green: sft-qwen3-32b-lr5e-5-32k-gpu32-bsz32-sp2-ga2 (original batch size normalization code)
[Images: loss and grad_norm curves]
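For reference, the modification I tried is roughly the sketch below; token_mean_loss is a hypothetical stand-in for verl's actual per-micro-batch loss computation (Ulysses gather, unpadding, loss masking):

# Simplified sketch of the attempted fix, not verl's actual implementation.
def compute_loss_and_backward(model, micro_batch, n_micro_batches):
    loss = token_mean_loss(model, micro_batch)   # hypothetical helper: token-mean CE loss of this micro-batch
    loss = loss / n_micro_batches                # scale before backward() so the accumulated gradient is an average
    loss.backward()
    return loss.detach()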

GoneZ5 avatar Aug 05 '25 11:08 GoneZ5

I hope you also removed the division by n_micro_batches in the line below (used for step-loss logging), so it is not applied twice:

for micro_batch in micro_batches:
    loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches)  # / n_micro_batches
    step_loss += loss.item()

puneeshkhanna avatar Aug 05 '25 12:08 puneeshkhanna

I hope you also removed the division by n_micro_batches in the line below (used for step-loss logging), so it is not applied twice:

for micro_batch in micro_batches:
    loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches)  # / n_micro_batches
    step_loss += loss.item()

Yes, I removed it.

GoneZ5 avatar Aug 05 '25 12:08 GoneZ5

Also, I think for this line we should pass grad_scaler=False, although your run is already showing quite a different loss with grad accumulation:

loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size, grad_scaler=False)

puneeshkhanna avatar Aug 05 '25 13:08 puneeshkhanna

Maybe try the above too; it might help.

puneeshkhanna avatar Aug 05 '25 13:08 puneeshkhanna

Yup, I noticed this before and tried it, but the loss was still inconsistent. However, I didn't try it together with loss /= n_micro_batches before loss.backward(). I'll try it again. Thank you!

GoneZ5 avatar Aug 05 '25 14:08 GoneZ5

But I think we should keep grad_scaler=True, because the loss of experiment sft-qwen3-32b-lr5e-5-32k-gpu64-bsz32-sp2 is correct.

GoneZ5 avatar Aug 05 '25 15:08 GoneZ5

Hi, thanks for sharing your experiment results! We are also running SFT with sp and gradient accumulation, and the trained model shows some strange behaviors. Have you figured out the root cause of this inconsistent loss?

Ber666 avatar Aug 18 '25 20:08 Ber666

Hi, thanks for the follow-up. After the ablation runs, I’m afraid I haven’t been able to pinpoint the root cause for the inconsistent loss.

GoneZ5 avatar Aug 19 '25 02:08 GoneZ5

Hi, thanks for the follow-up. After the ablation runs, I'm afraid I haven't been able to pinpoint the root cause for the inconsistent loss.

Regarding the loss calculation during gradient accumulation: if the loss is computed using a token-level strategy where the number of valid tokens can vary across different micro-batches, wouldn't scaling the loss simply by 1 / num_micro_batches be inaccurate?
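To make this concrete, here is a minimal standalone sketch with made-up token counts, showing that averaging per-micro-batch token-mean losses (which is what the 1 / num_micro_batches scaling optimizes) does not equal the token-mean loss over the whole batch when the valid-token counts differ:

import torch

# Two micro-batches with different numbers of valid (non-masked) tokens; numbers are made up.
per_token_losses = [torch.tensor([1.0, 2.0, 3.0, 4.0]),   # 4 valid tokens
                    torch.tensor([10.0])]                  # 1 valid token

# What 1 / num_micro_batches scaling effectively optimizes: the mean of per-micro-batch means.
naive = sum(l.mean() for l in per_token_losses) / len(per_token_losses)                     # (2.5 + 10.0) / 2 = 6.25

# Token-level loss over the whole batch: total loss divided by total valid tokens.
exact = sum(l.sum() for l in per_token_losses) / sum(l.numel() for l in per_token_losses)   # 20.0 / 5 = 4.0

print(naive.item(), exact.item())   # 6.25 vs 4.0 -> the two objectives differ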

Schilings avatar Sep 09 '25 12:09 Schilings

Hi, thanks for the follow-up. After the ablation runs, I'm afraid I haven't been able to pinpoint the root cause for the inconsistent loss.

Regarding the loss calculation during gradient accumulation: if the loss is computed using a token-level strategy where the number of valid tokens can vary across different micro-batches, wouldn't scaling the loss simply by 1 / num_micro_batches be inaccurate?

I think so. Not only SFT but also the PPO-style training methods implemented in verl face this problem; see https://huggingface.co/blog/gradient_accumulation

exiarepairii avatar Sep 16 '25 12:09 exiarepairii

Not only SFT but also the PPO-style training methods implemented in verl face this problem; see https://huggingface.co/blog/gradient_accumulation

I am testing with DAPO, and it seems that the behavior also changes in DAPO.

[Images: DAPO training curves]

Taishi-N324 avatar Sep 23 '25 01:09 Taishi-N324

With the balance-dp-token option and the divide-by-n-micro-batches fix, the SFT implementation is the same as recommended in https://huggingface.co/docs/accelerate/en/usage_guides/gradient_accumulation#gradient-accumulation-on-training-samples-of-variable-size and https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py
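For anyone landing here later, the recipe those Accelerate references describe is roughly the sketch below: sum the loss over each micro-batch's valid tokens and normalize by the total valid-token count of the whole accumulated (and DP-wide) batch. The loss_mask key and per_token_cross_entropy helper are placeholders for illustration; this is not the exact code of the verl fix.

import torch.distributed as dist

def sft_step_with_token_normalization(model, optimizer, micro_batches):
    # Count valid tokens across all micro-batches on this rank, then across DP ranks,
    # so every rank divides by the same global normalizer.
    total_tokens = sum(mb["loss_mask"].sum() for mb in micro_batches).float()
    if dist.is_initialized():
        dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)

    step_loss = 0.0
    for mb in micro_batches:
        per_token_loss = per_token_cross_entropy(model, mb)            # hypothetical helper, one value per token
        # Sum over this micro-batch's valid tokens, normalized by the GLOBAL token count,
        # instead of taking a micro-batch mean and dividing by n_micro_batches.
        loss = (per_token_loss * mb["loss_mask"]).sum() / total_tokens
        loss.backward()   # if the DP wrapper averages gradients across ranks, an extra world-size factor is needed
        step_loss += loss.item()

    optimizer.step()
    optimizer.zero_grad()
    return step_loss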

puneeshkhanna avatar Sep 23 '25 03:09 puneeshkhanna

This PR from @puneeshkhanna fixed the problem: https://github.com/volcengine/verl/commit/e160d3b2e0e0ed7607b6def2c7cac4ddba6b8d6f

We might close this issue.

longxudou avatar Nov 12 '25 09:11 longxudou