How do PP and ZeRO stage 2+ work together?

hyunwoongko opened this issue 3 years ago · 6 comments

At least as far as I know, ZeRO2 partitions gradients while PP accumulates gradients across micro-batches, so there's no real performance boost from using these two mechanisms together.

related issues

  • https://github.com/EleutherAI/gpt-neox/issues/67
  • https://github.com/EleutherAI/gpt-neox/issues/62
  • https://github.com/microsoft/DeepSpeed/issues/1110

But https://github.com/hpcaitech/ColossalAI/pull/477 makes PP and ZeRO stage 2+ work together. How do they work together?

hyunwoongko · Apr 06 '22 15:04

We currently just want to support more parallel training methods. As ZeRO and PP are both important and useful parallel training methods, we think users may want to use them together in some special cases. For example, when a layer cannot fit in memory, as with GPT-3, we must partition the layer. Of course, we can use TP+PP. But if GPU memory is still not enough with TP, ZeRO+PP may be another option. As for efficiency, we can optimize it in the future.

ver217 · Apr 07 '22 03:04

Here is an example of training GPT, and this is an example config file that uses ZeRO, TP and PP together.
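
For readers who do not open the links, the sketch below shows roughly the shape such a config file took at the time. It is an illustration only: the key names (`parallel`, `pipeline`, `tensor`, `zero`) follow my recollection of the 2022-era ColossalAI config style and are assumptions here, not copied from the linked file, which remains the authoritative reference.

```python
# Illustrative pseudo-config only: key names are assumptions based on the
# 2022-era ColossalAI config style, not copied from the linked example file.
parallel = dict(
    pipeline=2,                      # 2 pipeline stages (PP)
    tensor=dict(size=2, mode='1d'),  # 2-way 1D tensor parallelism (TP)
)
zero = dict(
    # ZeRO sharding settings for the remaining data-parallel dimension,
    # e.g. shard strategy, offloading and fp16/optimizer options.
    model_config=dict(),
    optimizer_config=dict(),
)
```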

ver217 · Apr 07 '22 03:04

Combining ZeRO2 with PP is not mechanistically efficient. ZeRO2 has to split the gradients, but PP has to accumulate the gradients, so there's no real performance boost; it's actually slower and less memory efficient than plain ZeRO2 (without PP). And if the user runs out of memory, they can use ZeRO3 instead of ZeRO2 + PP.

I'm saying it's mechanistically inefficient to use these two algorithms together; I've been testing this for a year. It won't be easy to improve the efficiency of the combination. To do that, one of the two mechanisms would have to be redesigned in a completely new way.

That's why DeepSpeed prevents PP and ZeRO2 from being used together: not because they can't implement it, but because it's ineffective to do so. https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/engine.py#L72 If you want to allow ZeRO + PP, I think it is reasonable to limit it to stage 1.
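
To make the conflict concrete, here is a back-of-the-envelope sketch (my own illustration, not code from either project) of the two ways gradients can be handled when PP micro-batching meets ZeRO2-style gradient partitioning:

```python
# Back-of-the-envelope sketch, not project code: compare per-step gradient
# communication volume and peak local gradient memory for the two policies
# available when PP micro-batching is combined with ZeRO-2 gradient partitioning.
def grad_cost(n_params, dp_size, micro_batches, shard_every_microbatch):
    """Return (gradient reduce volume per optimizer step, peak local gradient
    memory), both in numbers of gradient elements."""
    if shard_every_microbatch:
        # ZeRO-2 behaviour: reduce-scatter after every micro-batch so each rank
        # keeps only its 1/dp_size shard -- low memory, but the reduction now
        # runs micro_batches times per step instead of once.
        return n_params * micro_batches, n_params // dp_size
    else:
        # PP behaviour: accumulate full local gradients across micro-batches and
        # reduce once at the end -- minimal communication, but ZeRO-2's
        # gradient-memory saving is lost until the final reduction.
        return n_params, n_params

# Example: 1B gradient elements, 8-way data parallel, 16 PP micro-batches.
print(grad_cost(10**9, 8, 16, shard_every_microbatch=True))   # (16_000_000_000, 125_000_000)
print(grad_cost(10**9, 8, 16, shard_every_microbatch=False))  # (1_000_000_000, 1_000_000_000)
```

Either choice gives up one of the two benefits, which is the trade-off described above.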

hyunwoongko · Apr 07 '22 04:04

I also don't think it makes sense for Colossal AI to use the name ShardedModel, because for ZeRO1 and ZeRO2 we don't actually split the model; the name only seems reasonable for ZeRO3. What do you think? Of course, all of this is just advice, so feel free to do whatever you guys want :)

hyunwoongko · Apr 07 '22 04:04

Thanks for your advice; maybe ShardableModel is more appropriate.

feifeibear · Apr 07 '22 04:04

Combining ZeRO2 with PP is not mechanistically efficient. ZeRO2 has to split the gradients, but PP has to accumulate the gradients, so there's no real performance boost; it's actually slower and less memory efficient than plain ZeRO2 (without PP).

@hyunwoongko Hi, I was wondering if it's possible to implement PP with ZeRO-2/3.

Say we have a model of 48 layers (48L) to train with 2DP x 4PP, which means each PP stage gets 2 devices, with its 12L replicated across those 2 devices. Now I want to use torch FSDP (which, as I understand it, is equivalent to ZeRO-3) to implement the 2DP by splitting these 12L into 3 FSDP units (each wrapping 4L).

With the placement above, when dev0 runs the forward pass, it will all-gather the parameters of the 1st FSDP unit (4L), do the compute, discard those parameters, and then move on to the 2nd FSDP unit (4L). So, at any point in time, dev0 only materializes parameters/grads for 4L instead of 12L, thus reducing peak memory usage.

Also, for every micro-batch computed (fwd & bwd) by dev0, it will all-gather the weights and all-reduce the grads, which leads to many repeated communications (once per micro-batch of a batch in an iteration). But all of these communications can be overlapped with computation if we schedule them carefully.

Of course, the scheduling algorithm would be much more complicated for pipeline parallelism with FSDP (ZeRO-3), but as a result of using FSDP x PP, we reduce memory usage without communication overhead (the communication can be overlapped), compared with TP x PP (where the communication cannot be overlapped).
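
A minimal sketch of the wrapping described above (my own illustration, assuming plain PyTorch FSDP and a `dp_group` process group for the stage's two data-parallel ranks, set up elsewhere):

```python
# Minimal sketch of the layout described above, not the commenter's actual code.
# Wrap one pipeline stage's 12 layers into 3 FSDP units of 4 layers each, so that
# only about 4 layers' worth of parameters is materialized at any point in time.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_fsdp_stage(stage_layers, dp_group, layers_per_unit=4):
    """stage_layers: the 12 nn.Modules owned by this pipeline stage.
    dp_group: the 2-rank data-parallel process group of this stage."""
    units = []
    for i in range(0, len(stage_layers), layers_per_unit):
        # Each FSDP unit shards its 4 layers across the 2 DP ranks (ZeRO-3 style);
        # their parameters are all-gathered just before the unit's compute and
        # freed right after, which caps peak memory at roughly one unit.
        unit = FSDP(nn.Sequential(*stage_layers[i:i + layers_per_unit]),
                    process_group=dp_group)
        units.append(unit)
    return nn.Sequential(*units)
```

The per-micro-batch all-gathers and reductions this creates are exactly the communications the comment proposes hiding behind compute with careful scheduling.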

Dounm · Jul 27 '22 09:07

We have updated a lot. This issue was closed due to inactivity. Thanks.

binmakeswell · Apr 13 '23 03:04