
Sequence Parallelism Aware Training Loss

ginward opened this issue 6 months ago · 1 comment

Hi there! When training with sequence-parallel attention, is the loss function scaled properly? Each GPU holds only a subset of the total sequence, so computing the loss naively over each rank's local tokens may not match the loss over the full sequence.

See the code in Megatron-Deepspeed here:

https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/sequence_parallel/cross_entropy.py#L26-L30

        loss_all = torch.empty(ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device)
        if version.parse(torch.__version__) >= version.parse('1.13'):
            torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group())
        else:
            torch.distributed._all_gather_base(loss_all, loss, group=get_sequence_parallel_group())

It seems the Megatron-DeepSpeed authors gather the per-token losses from every rank in the sequence-parallel group before reducing them to the final loss value.

Is Open-Sora-Plan doing the same?
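
For concreteness, here is a minimal sketch of the kind of scaling I have in mind. This is not code from this repo or from Megatron-DeepSpeed; `sp_aware_cross_entropy`, `sp_group`, `local_logits`, and `local_labels` are placeholder names. The idea is to sum the per-token losses and token counts on each shard, all-reduce both across the sequence-parallel group, and only then divide, so the result equals the mean over the full sequence even when shards contribute different numbers of valid tokens.

    # Sketch only: correct mean loss under sequence parallelism, assuming each
    # rank holds a contiguous shard of the sequence and `sp_group` is the
    # sequence-parallel process group.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def sp_aware_cross_entropy(local_logits, local_labels, sp_group, ignore_index=-100):
        # Per-token loss summed over this rank's sequence shard only.
        loss_sum = F.cross_entropy(
            local_logits.reshape(-1, local_logits.size(-1)),
            local_labels.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
        # Number of tokens that actually contribute on this rank.
        num_tokens = (local_labels != ignore_index).sum().to(loss_sum.dtype)

        # Aggregate both quantities across the sequence-parallel group so every
        # rank ends up with the mean over the *full* sequence, not just its shard.
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM, group=sp_group)
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM, group=sp_group)
        return loss_sum / num_tokens

A plain `.mean()` on each rank followed by the usual data-parallel gradient averaging would instead weight every shard equally regardless of how many tokens it holds, which is exactly the scaling issue I am asking about.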

ginward · Aug 01 '24 19:08