st-moe-pytorch

Question on increasing batch size and sequence length

Open · jambo6 opened this issue 1 year ago · 1 comment

Do you know how this giant all-gather works for large architectures across hundreds of workers?

Specifically, I'm interested in this bit of code:

if is_distributed:
    ...

    # gather and concat across batches, accounting for variable batch sizes
    x, batch_sizes = self.all_gather(x)

Using the standard notation from the code, the x returned here has shape [b * w, e, c, d], where:

  • b - batch size
  • w - world size
  • e - num experts
  • c - expert capacity (which we can say is something like n / e where n is seq len)
  • d - hidden dim

This means the overall tensor has shape approximately [b * w, n, d], which is the same as holding all the workers' batches in memory on each individual device. I.e. at the per-device level we've gone from [b, n, d] -> [b * w, n, d]. I don't see how this can reasonably scale with w.

I'm currently at a loss to understand how this doesn't prevent training at reasonable sizes.

E.g. in Mixtral they have n = 32k, and with a large number of workers (even at batch size 1) this is not going to fit in memory.
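For concreteness, here is a rough back-of-the-envelope sketch; the dimensions below are my own assumptions for a Mixtral-like setup (not values taken from this repo), and I ignore the capacity factor:

# hypothetical Mixtral-like dimensions (my assumptions, not from this repo)
b, w, n, d, e = 1, 64, 32_768, 4096, 8
c = n // e                  # expert capacity ~ n / e, capacity factor ignored
bytes_per_el = 2            # bf16

# per-device activations before the gather: [b, e, c, d] ~ [b, n, d]
local_gib = b * e * c * d * bytes_per_el / 2**30          # ~0.25 GiB

# after self.all_gather, every device holds [b * w, e, c, d]
gathered_gib = b * w * e * c * d * bytes_per_el / 2**30   # ~16 GiB, linear in w

print(local_gib, gathered_gib)

So under these assumptions the gathered activations alone grow linearly with the number of workers, which is the scaling concern above.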

Just wondering if I'm missing something or whether this is just a bottleneck inherent in MoE models.

Thanks very much for this codebase; I found going through it highly informative!

jambo6 avatar Mar 12 '24 16:03 jambo6

Similarly with seq_len, we have e.g. this code:

combine_tensor = reduce(gates * mask_flat * one_hot_gate_indices * safe_one_hot_gates, "k ... -> ...", "sum")

However, this intermediate

gates * mask_flat * one_hot_gate_indices * safe_one_hot_gates

has shape [k, b, n, e, c]. Again, since c is roughly n / e, the element count scales like k * b * n ** 2, which can easily hit CUDA OOMs for small batch sizes, 8 experts, and a 32k sequence length (again as stated for Mixtral).
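As a sanity check on that claim, a rough element count for the intermediate, again with assumed Mixtral-like numbers (top-2 routing, batch size 1) and the capacity factor ignored:

# hypothetical dimensions (my assumptions): top-2 routing, batch size 1
k, b, n, e = 2, 1, 32_768, 8
c = n // e                        # e * c == n, so the product below ~ k * b * n**2
bytes_per_el = 2                  # bf16

elems = k * b * n * e * c         # elements in the [k, b, n, e, c] intermediate
print(elems, elems * bytes_per_el / 2**30)   # ~2.1e9 elements, ~4 GiB at batch size 1

That is roughly 4 GiB for a single intermediate at batch size 1 under these assumptions, before accounting for gradients or larger batches.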

Again, just wondering whether this is some fundamental bottleneck with MoE models or there is some finer detail I am missing.

I have not seen it mentioned in papers (though admittedly I have not read that many related papers).

jambo6 avatar Mar 12 '24 16:03 jambo6