Add DDP token averaging for equivalent non-parallel training similar to #34191

Open (opened by sbwww)

Feature request

Token averaging in gradient accumulation was fixed in #34191, but token averaging in DDP seems to have the same issue.


Expected behavior

Counting every token that contributes to the loss in one optimizer step (on each GPU, in each gradient accumulation step, and in each microbatch), the total is:

$$ntokens=\sum\limits_{GPUs} \sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

I believe the loss should be averaged over all of these tokens at once to be equivalent to non-parallel training.
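A minimal pure-Python sketch of the count above, assuming (hypothetically) that labels are nested as GPUs, then gradient accumulation steps, then microbatches, with per-token losses laid out the same way. `count_tokens` and `global_token_mean` are illustrative names, not Trainer APIs. Summing every unmasked per-token loss and dividing once by this count gives the same value as one large non-parallel batch.

```python
IGNORE_INDEX = -100  # label value excluded from the loss


def count_tokens(labels_per_gpu):
    """n_tokens = sum over GPUs, gas steps, and microbatches of (label != -100)."""
    return sum(
        label != IGNORE_INDEX
        for gpu in labels_per_gpu
        for gas_step in gpu
        for microbatch in gas_step
        for label in microbatch
    )


def global_token_mean(labels_per_gpu, losses_per_gpu):
    """Sum every unmasked per-token loss, then divide once by the global count."""
    total = sum(
        loss
        for gpu_l, gpu_x in zip(labels_per_gpu, losses_per_gpu)
        for step_l, step_x in zip(gpu_l, gpu_x)
        for mb_l, mb_x in zip(step_l, step_x)
        for label, loss in zip(mb_l, mb_x)
        if label != IGNORE_INDEX
    )
    return total / count_tokens(labels_per_gpu)


# 2 GPUs x 2 gas steps x 1 microbatch each; -100 marks ignored positions
labels = [
    [[[5, 7, -100]], [[3, -100, -100]]],
    [[[1, 2, 4]], [[9, 8, 6]]],
]
losses = [
    [[[1.0, 2.0, 9.9]], [[3.0, 9.9, 9.9]]],
    [[[4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0]]],
]
print(count_tokens(labels))               # 9
print(global_token_mean(labels, losses))  # 45 / 9 = 5.0
```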


Current issue

Prior to #34191, the loss/gradients were averaged over $\sum\limits_{GPUs}$, $\sum\limits_{gas}$, and $\sum\limits_{microb}$ separately. The num_items_in_batch introduced in #34191 corresponds to:

$$ntokens=\sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

So the loss/gradients are now averaged over $\sum\limits_{GPUs}$ and $\left(\sum\limits_{gas}\sum\limits_{microb}\right)$ separately. However, this still does not seem equivalent to non-parallel training.
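A small numeric sketch of the mismatch, with made-up per-token losses that are already flattened over gas steps and microbatches on each GPU (the helper names are hypothetical). The first function normalizes by each GPU's own count and then averages over GPUs, which is what DDP's gradient averaging effectively does to the loss; the second normalizes once by the count summed over all GPUs.

```python
def current_ddp_loss(per_gpu_token_losses):
    """Normalize by each GPU's own token count, then average over GPUs (DDP-style)."""
    local_means = [sum(x) / len(x) for x in per_gpu_token_losses]
    return sum(local_means) / len(local_means)


def global_loss(per_gpu_token_losses):
    """Normalize once by the token count summed over all GPUs (non-parallel equivalent)."""
    total = sum(sum(x) for x in per_gpu_token_losses)
    count = sum(len(x) for x in per_gpu_token_losses)
    return total / count


# Unmasked per-token losses on each GPU (gas steps and microbatches already flattened)
gpu0 = [1.0, 2.0, 3.0]                  # 3 tokens
gpu1 = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]   # 6 tokens
print(current_ddp_loss([gpu0, gpu1]))   # (2.0 + 6.5) / 2 = 4.25
print(global_loss([gpu0, gpu1]))        # 45 / 9 = 5.0
```

The two agree only when every GPU sees the same number of unmasked tokens.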

Can we also incorporate $\sum\limits_{GPUs}$ when determining num_items_in_batch? Something like all_reduce(num_items_in_batch)?

Motivation

DDP training does not seem fully equivalent to non-parallel training.

Related comment: https://github.com/huggingface/transformers/pull/34191#issuecomment-2421777304

Your contribution

I found a fairseq implementation of this feature:

https://github.com/facebookresearch/fairseq/blob/018621f3cca02ca9de945dc082c3fb1a7f9f2deb/fairseq/trainer.py#L932-L949
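Roughly, as far as I can tell from the linked code, fairseq backpropagates the summed (unnormalized) loss on each rank, sums the sample size across ranks together with the other logging stats after backward, and then multiplies the gradients by world_size / sample_size to undo DDP's division by the world size. A pure-Python sketch of that arithmetic, with scalars standing in for gradients (`fairseq_style_step` and `ddp_allreduce_mean` are hypothetical helpers, not fairseq's API):

```python
def ddp_allreduce_mean(values):
    """Stand-in for DDP's gradient all-reduce, which averages over ranks."""
    return sum(values) / len(values)


def fairseq_style_step(per_rank_grad_sums, per_rank_token_counts):
    """Normalize after backward: scale the averaged gradient by world_size / sample_size."""
    world_size = len(per_rank_grad_sums)
    # Each rank backpropagated its summed loss, so its gradient is a sum over its tokens
    averaged = ddp_allreduce_mean(per_rank_grad_sums)
    # The token counts are summed across ranks after backward
    sample_size = sum(per_rank_token_counts)
    # (mean of per-rank sums) * world_size / sample_size == total sum / total tokens
    return averaged * world_size / sample_size


# Per-token gradients (scalars for illustration) on two ranks with unequal token counts
rank0 = [1.0, 2.0, 3.0]
rank1 = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
print(fairseq_style_step([sum(rank0), sum(rank1)], [len(rank0), len(rank1)]))  # 5.0
```

Deferring the normalization this way means the global token count does not have to be known before the forward pass.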

sbwww, Oct 18 '24

I observed this as well when I was running some experiments (things were close post-fix, but not exact). Would you like to take a stab at a PR? :)

muellerzr, Oct 18 '24

A simple implementation may be:

  1. add all_reduce(num_items_in_batch, op=SUM) after: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2416
  2. add loss *= get_world_size() after: https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L26
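A pure-Python check of why these two steps cancel out, assuming DDP averages gradients over world_size ranks and using scalar losses in place of gradients (`simulate_steps` is a hypothetical helper):

```python
def simulate_steps(per_rank_token_losses):
    """Apply the two steps above, then DDP's averaging, to scalar per-token losses."""
    world_size = len(per_rank_token_losses)
    # Step 1: all_reduce(num_items_in_batch, op=SUM) gives every rank the global count
    num_items_in_batch = sum(len(x) for x in per_rank_token_losses)
    local_losses = [sum(x) / num_items_in_batch for x in per_rank_token_losses]
    # Step 2: loss *= world_size on every rank
    local_losses = [loss * world_size for loss in local_losses]
    # DDP then averages over ranks (applied to the losses here since everything is linear)
    return sum(local_losses) / world_size


rank0 = [1.0, 2.0, 3.0]
rank1 = [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
print(round(simulate_steps([rank0, rank1]), 6))  # 5.0, i.e. 45 total loss / 9 tokens
```

Skipping step 2 would leave the result world_size times too small (2.5 instead of 5.0 here).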

techkang, Oct 22 '24

Although this issue has little impact on the training results, it significantly affects the reproducibility of experiments across different hardware configurations. I hope it can be resolved alongside the gradient accumulation fix.

I attempted to use all_reduce during training, but it slowed down the process. Is it possible to calculate the total number of tokens per batch across devices when initializing the DataLoader with accelerate (without compromising compatibility with the existing code)?
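For a map-style dataset with a fixed, seeded sampling order, the global token count of every optimizer step could in principle be computed up front without communication, as in this hypothetical sketch (`precompute_step_counts` and its parameters are illustrative, not an accelerate or Trainer API). It depends on random access to every example and on knowing the batch order in advance, which an IterableDataset does not provide.

```python
import random

IGNORE_INDEX = -100  # label value excluded from the loss


def precompute_step_counts(dataset_labels, batch_size, world_size, grad_accum, seed=0):
    """Global number of unmasked tokens for each optimizer step, computed before training."""
    order = list(range(len(dataset_labels)))
    random.Random(seed).shuffle(order)  # every rank would reproduce this seeded order
    examples_per_step = batch_size * world_size * grad_accum
    counts = []
    # Incomplete trailing steps are dropped, like drop_last=True
    for start in range(0, len(order) - examples_per_step + 1, examples_per_step):
        step = order[start:start + examples_per_step]
        counts.append(sum(
            sum(label != IGNORE_INDEX for label in dataset_labels[i]) for i in step
        ))
    return counts


labels = [[1, 2, -100], [3, -100, -100], [4, 5, 6], [7, -100, -100]]
counts = precompute_step_counts(labels, batch_size=1, world_size=2, grad_accum=1)
print(counts)  # two steps whose counts sum to 7; the pairing depends on the shuffle
```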

TechxGenus, Oct 22 '24

That is the issue with it, and why I'm not the biggest fan of that particular solution.

We can't, because there are situations, such as IterableDatasets, where that just isn't possible.

The fairseq solution may be the way to go.

muellerzr, Oct 22 '24

[W&B chart, 10/22/2024 10:20:21 AM]

Can confirm the fairseq solution works great, it'll be part of https://github.com/huggingface/transformers/pull/34283

muellerzr, Oct 22 '24

However, neither the current fix nor these ones makes a noticeable impact as we scale. [image]

This might be problem-specific; however, I did find that the fix helped a little.

muellerzr, Oct 22 '24

I'll leave this open for now. I didn't see significant discrepancies between DDP and non-DDP training, but if users have stories or can show where it goes wrong, please post them here for us to dig into.

muellerzr, Oct 23 '24

I tested multi-GPU training with and without all_reduce. First, with num_items_in_batch = all_reduce(num_items_in_batch), all loss curves matched exactly. [image: loss curves]

Then, compared with the version without all_reduce, the loss curves did not match: [image: loss curves]

However, the loss was only slightly larger at the beginning, so introducing the DDP all_reduce sync may only improve reproducibility.

Here is my code in case anyone needs it:

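```python
import torch
import torch.distributed as dist
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Use the process-group world size rather than torch.cuda.device_count()
        world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
        sync = world_size > 1 and num_items_in_batch is not None
        if sync:
            # Sum per-rank token counts so every rank divides by the global count
            count = torch.as_tensor(num_items_in_batch, device=self.args.device).reshape(1)
            dist.all_reduce(count, op=dist.ReduceOp.SUM)
            num_items_in_batch = int(count.item())
        outputs = super().compute_loss(model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch)
        if sync:
            # DDP averages gradients over ranks, so scale by world_size to undo that division
            if return_outputs:
                return outputs[0] * world_size, outputs[1]
            outputs = outputs * world_size
        return outputs
```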

techkang, Oct 24 '24

What we can do, then, is add it to TrainingArguments under a flag that is disabled by default (average_tokens_across_devices). @techkang, would you like to take a stab at a PR?

muellerzr, Oct 24 '24

Thanks, I've already created the pull request and tested it on my machine. [image]

techkang, Oct 24 '24

Replying to techkang's two-step implementation above: you're the best, thanks!

hedes1992, Oct 28 '24

This looks nice, but why is it (average_tokens_across_devices) disabled by default?

It seems the advantage of this PR is reproducibility, and there is no disadvantage.

We can also construct a case where this feature is strictly needed:

  • Assume that on GPU 0 the labels are [1,-100,-100,....-100]
  • and on GPU 1 the labels are [1,1,1,...,1]

Then, since the loss calculation gives the two GPUs equal weight, the unstable loss signal from GPU 0 (it comes from a single token) accounts for 50% of the total, far more than it should.
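A quick calculation of the weight each supervised token gets in this example, assuming a sequence length of 8 for illustration (`per_token_weight` is a hypothetical helper; its flag only mirrors the idea behind average_tokens_across_devices):

```python
IGNORE_INDEX = -100  # label value excluded from the loss


def per_token_weight(labels_per_gpu, average_across_devices):
    """Weight that one supervised token on each GPU receives in the final loss."""
    counts = [sum(label != IGNORE_INDEX for label in labels) for labels in labels_per_gpu]
    if average_across_devices:
        # One global mean: every supervised token weighs 1 / total tokens
        return [1 / sum(counts) for _ in counts]
    # Per-GPU mean, then DDP averages the GPUs: 1 / (world_size * local count)
    return [1 / (len(counts) * c) for c in counts]


seq_len = 8  # assumed sequence length for illustration
gpu0 = [1] + [IGNORE_INDEX] * (seq_len - 1)  # a single supervised token
gpu1 = [1] * seq_len                          # every token supervised
print(per_token_weight([gpu0, gpu1], average_across_devices=False))  # [0.5, 0.0625]
print(per_token_weight([gpu0, gpu1], average_across_devices=True))   # both 1/9
```

With per-device averaging, GPU 0's single token weighs 8 times as much as any token on GPU 1; with averaging across devices, all 9 tokens weigh 1/9.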

MilkClouds, Jul 14 '25