Open-Sora-Plan

[Bug] fix sp bug when global_bs = 1

Open foreverpiano opened this issue 5 months ago • 0 comments

The original code has a bug when it uses reshape(-1, ...) to infer the data layout automatically. This change replaces it with an explicit, fixed data layout.

Failing case: train_batch_size = 1, sp_size = 4, train_sp_batch_size = 1.

hidden_states = rearrange(hidden_states, 'b h s d -> s b h d')
hidden_states = hidden_states.reshape(-1, attn.heads // sp_size, head_dim)
# [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1).reshape(-1, batch_size, h_size)

Change to:

hidden_states = rearrange(hidden_states, 'b h s d -> s h b d').contiguous()
hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1)
# [s, h // sp, b, d] -> [s // sp, h, b, d] -> [s // sp, b, h * d]
hidden_states = rearrange(hidden_states, 's h b d -> s b (h d)').contiguous()
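The shape flow of the fixed-layout version can be checked without GPUs. Below is a minimal single-process sketch, assuming NumPy as a stand-in for torch/einops and assuming that all_to_all_SBH splits its input into equal chunks along scatter_dim and concatenates the received chunks in rank order along gather_dim. The helpers simulated_all_to_all, fixed_layout_path and expected are hypothetical names for illustration and are not part of Open-Sora-Plan.

```python
import numpy as np


def simulated_all_to_all(per_rank, scatter_dim, gather_dim):
    """Single-process stand-in for an all-to-all over len(per_rank) ranks.

    Assumption: rank r splits its tensor into equal chunks along
    scatter_dim and sends chunk j to rank j; every rank concatenates the
    chunks it receives, in source-rank order, along gather_dim.
    """
    world = len(per_rank)
    chunks = [np.split(x, world, axis=scatter_dim) for x in per_rank]
    return [
        np.concatenate([chunks[src][dst] for src in range(world)], axis=gather_dim)
        for dst in range(world)
    ]


def fixed_layout_path(full, sp_size):
    """Run the fixed-layout steps on a full [b, h, s, d] tensor."""
    b, h, s, d = full.shape
    # Before the all-to-all each rank holds the whole sequence
    # but only h // sp_size of the heads.
    local = np.split(full, sp_size, axis=1)
    # 'b h s d -> s h b d'
    local = [np.ascontiguousarray(x.transpose(2, 1, 0, 3)) for x in local]
    # [s, h // sp, b, d] -> [s // sp, h, b, d]
    local = simulated_all_to_all(local, scatter_dim=0, gather_dim=1)
    # 's h b d -> s b (h d)'
    return [x.transpose(0, 2, 1, 3).reshape(s // sp_size, b, h * d) for x in local]


def expected(full, sp_size):
    """Reference result: rank j owns sequence chunk j with all heads."""
    b, h, s, d = full.shape
    seq_chunks = np.split(full, sp_size, axis=2)  # each [b, h, s // sp, d]
    return [c.transpose(2, 0, 1, 3).reshape(s // sp_size, b, h * d) for c in seq_chunks]


if __name__ == "__main__":
    # Mirrors the failing configuration: batch size 1, sp_size 4.
    full = np.arange(1 * 8 * 8 * 2, dtype=np.float32).reshape(1, 8, 8, 2)
    for got, want in zip(fixed_layout_path(full, 4), expected(full, 4)):
        assert got.shape == (2, 1, 16)
        assert np.array_equal(got, want)
```

Because every axis stays named through both rearrange steps, no dimension has to be inferred with -1, so a wrong batch_size or h_size cannot be silently absorbed into another axis.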

@LinB203 @apprivoiser

foreverpiano, Sep 08 '24 12:09