Open-Sora-Plan
Sequence Parallelism Aware Training Loss
Hi there! When training with sequence-parallel attention, each GPU holds only a subset of the total sequence, so a naive loss computation only sees its local shard. I was wondering whether you scale/aggregate the loss function properly in this case?
See the code in Megatron-Deepspeed here:
https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/sequence_parallel/cross_entropy.py#L26-L30
```python
loss_all = torch.empty(ctx.seqlen, batch_size,
                       dtype=vocab_seq_parallel_logits.dtype,
                       device=vocab_seq_parallel_logits.device)
if version.parse(torch.__version__) >= version.parse('1.13'):
    torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group())
else:
    torch.distributed._all_gather_base(loss_all, loss, group=get_sequence_parallel_group())
```
It seems that the authors gather the per-token losses from every rank in the sequence parallel group before computing the final loss.
I was wondering if you do the same?
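To illustrate why the gather matters, here is a minimal NumPy sketch (my own illustration, not code from either repository): when ranks end up holding unequal numbers of valid tokens (e.g. due to padding or masking), averaging the per-rank mean losses differs from averaging over the full gathered sequence.

```python
import numpy as np

# Hypothetical example: a sequence of 10 per-token losses split
# unevenly across 2 simulated sequence-parallel ranks.
token_losses = np.arange(10, dtype=np.float64)
shards = [token_losses[:6], token_losses[6:]]  # rank 0: 6 tokens, rank 1: 4 tokens

# Naive: each rank averages its own shard, then the per-rank means
# are averaged. This weights rank 1's tokens more heavily.
naive = np.mean([s.mean() for s in shards])        # -> 5.0

# Gather-then-average: concatenate all per-token losses first, then
# average over the full sequence (what the Megatron-DeepSpeed snippet
# above enables via all_gather).
gathered = np.concatenate(shards).mean()           # -> 4.5

print(naive, gathered)
```

With equal shard sizes and no masking the two agree, but the gathered version stays correct in general.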