Open-Sora-Plan
[Bug] fix sp bug when global_bs = 1
The original code has a bug: it uses reshape(-1) to infer the data layout automatically. This PR switches it to a fixed, explicit data layout.
Failing case: train_batch_size = 1, sp_size = 4, train_sp_batch_size = 1.
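A generic illustration of the pitfall, not a reproduction of the exact failure above: reshape with -1 only checks that the element count divides evenly, so it silently succeeds even when the axes are still in the wrong order. An explicit axis permutation (which is what einops.rearrange does with named axes) fixes the layout instead. This is a sketch in NumPy with a hypothetical toy array; NumPy's reshape uses the same row-major ordering as torch.reshape.

```python
import numpy as np

# Hypothetical toy array: 2 batch items x 3 sequence positions; value = 10 * b + s.
b, s = 2, 3
x_bs = np.array([[10 * bi + si for si in range(s)] for bi in range(b)])  # shape [b, s]

# Goal: an [s, b] layout. reshape(-1, b) only checks the element count, so it
# "succeeds" even though the data is still stored in [b, s] order.
wrong = x_bs.reshape(-1, b)   # shape [3, 2], but rows mix batch items
right = x_bs.transpose(1, 0)  # explicit permutation: row i holds position i for every batch item

print(wrong.tolist())  # [[0, 1], [2, 10], [11, 12]]
print(right.tolist())  # [[0, 10], [1, 11], [2, 12]]
```

Both results have the same shape, so nothing downstream raises an error; only the explicit permutation puts each value where the [s, b] layout expects it.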
Original code:
```python
hidden_states = rearrange(hidden_states, 'b h s d -> s b h d')
hidden_states = hidden_states.reshape(-1, attn.heads // sp_size, head_dim)
# [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1).reshape(-1, batch_size, h_size)
```
This PR changes it to:
```python
# [b, h // sp, s, d] -> [s, h // sp, b, d]
hidden_states = rearrange(hidden_states, 'b h s d -> s h b d').contiguous()
# [s, h // sp, b, d] -> [s // sp, h, b, d]: scatter the sequence dim, gather the head dim
hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1)
# [s // sp, h, b, d] -> [s // sp, b, h * d]
hidden_states = rearrange(hidden_states, 's h b d -> s b (h d)').contiguous()
```
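The fixed-layout path can be checked on a single process. The sketch below uses NumPy in place of PyTorch and einops, and assumes all_to_all_SBH follows torch.distributed.all_to_all semantics: each rank splits its input into sp_size chunks along scatter_dim, rank r receives chunk r from every rank, and the chunks are concatenated in rank order along gather_dim. all_to_all_sim and fixed_layout_sp are hypothetical helpers, not Open-Sora-Plan APIs, and the toy sizes are made up. The result on each rank is compared against slicing the full, non-parallel output directly.

```python
import numpy as np

def all_to_all_sim(per_rank, scatter_dim, gather_dim):
    """Hypothetical single-process stand-in for an all-to-all collective.

    per_rank[i] is rank i's input. Each input is split into len(per_rank)
    chunks along scatter_dim; rank r receives chunk r from every rank and
    concatenates them in rank order along gather_dim.
    """
    world = len(per_rank)
    chunks = [np.split(t, world, axis=scatter_dim) for t in per_rank]
    return [np.concatenate([chunks[src][dst] for src in range(world)], axis=gather_dim)
            for dst in range(world)]

def fixed_layout_sp(o_full, sp):
    """Hypothetical helper: run the fixed-layout path for every simulated rank."""
    b, h, s, d = o_full.shape
    hl = h // sp
    # Each rank starts with the full sequence but only its h // sp heads: [b, h // sp, s, d].
    local = [o_full[:, r * hl:(r + 1) * hl] for r in range(sp)]
    # 'b h s d -> s h b d' as an explicit permutation (the copy mirrors .contiguous()).
    local = [np.ascontiguousarray(t.transpose(2, 1, 0, 3)) for t in local]
    # [s, h // sp, b, d] -> [s // sp, h, b, d]: scatter the sequence dim, gather the head dim.
    out = all_to_all_sim(local, scatter_dim=0, gather_dim=1)
    # 's h b d -> s b (h d)': [s // sp, h, b, d] -> [s // sp, b, h * d].
    return [t.transpose(0, 2, 1, 3).reshape(s // sp, b, h * d) for t in out]

# Failing configuration from this PR (batch 1, sp_size 4); h, s, d are toy values.
b, h, s, d, sp = 1, 8, 16, 2, 4
o_full = np.arange(b * h * s * d, dtype=np.float32).reshape(b, h, s, d)
ref = o_full.transpose(2, 0, 1, 3).reshape(s, b, h * d)  # full [s, b, h * d] output
for r, t in enumerate(fixed_layout_sp(o_full, sp)):
    assert t.shape == (s // sp, b, h * d)
    assert np.array_equal(t, ref[r * (s // sp):(r + 1) * (s // sp)])
```

Because every axis is named in each permutation, the result does not depend on inferring a batch or sequence size from the element count, so the same check also holds for batch sizes other than 1.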
@LinB203 @apprivoiser