[BUG] Zero3: Post backward hook is not triggered for submodules whose inputs have .requires_grad=False

Open deepcharm opened this issue 1 year ago • 4 comments

Describe the bug

The pre-backward and post-backward hooks work by applying a custom autograd function class to tensors, which are either the inputs to the module (for post-backward) or the outputs of the module (for pre-backward).

When the forward method of the post-backward function is invoked, it saves the module and counts the number of input tensors.

Later, when its backward method is invoked, the counter is decremented for each tensor, and once it reaches zero, the actual post-backward processing routine is invoked. The main purpose of that routine is to release the previously materialized module parameters.
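The counting scheme described above can be sketched in plain PyTorch roughly as follows. This is a simplified illustration, not the DeepSpeed code: the helper `wrap_inputs_with_post_backward`, the attribute `grads_remaining`, and the `on_post_backward` callback are hypothetical names chosen for this example.

```python
import torch


def wrap_inputs_with_post_backward(module, inputs, on_post_backward):
    """Route each input through an identity autograd function.

    forward counts the inputs; backward counts down and calls
    on_post_backward(module) once the counter reaches zero.
    All names here are hypothetical, not DeepSpeed's.
    """

    class PostBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            ctx.module = module
            module.grads_remaining += 1  # one more gradient to wait for
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            ctx.module.grads_remaining -= 1
            if ctx.module.grads_remaining == 0:
                on_post_backward(ctx.module)  # e.g. release parameters
            return grad

    module.grads_remaining = 0
    return tuple(PostBackward.apply(t) for t in inputs)


released = []
layer = torch.nn.Linear(4, 2)
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, 4, requires_grad=True)

wa, wb = wrap_inputs_with_post_backward(layer, (a, b), released.append)
(layer(wa) + layer(wb)).sum().backward()

# Both inputs require grad, so backward ran twice and the callback fired once.
print(len(released), layer.grads_remaining)
```

In this sketch the callback fires only because every wrapped input requires a gradient; the counter can reach zero only if backward is called for each counted tensor.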

The above mechanism works for all the modules in a model, except for those whose inputs have .requires_grad set to False. Typically, these are the very first modules in the model.

Since no gradient calculation is required for such inputs, the backward method of the above custom autograd function is NOT called.

As a result, release_submodule is not called for those modules, so their memory is not released (and the params state may not be cleaned up correctly).

For example, the BERT model has 3 Embedding modules of significant size (> GB of memory) that receive their inputs directly from a dataloader. In the current design, release_submodule is not called for these modules, which causes a memory peak.

The same happens for ANY module whose inputs have .requires_grad set to False, not only for the very first modules.

To Reproduce

This can be easily reproduced on any model, such as the one below. The submodules linear0_0 and linear0_1 of the model MyModel receive the model inputs directly. The last submodule, linear1, receives its input from the first 2 layers.

import torch

class MyModel(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    super().__init__()
    # These two layers consume the raw model input (requires_grad is False).
    self.linear0_0 = torch.nn.Linear(D_in, H)
    self.linear0_1 = torch.nn.Linear(D_in, H)
    # This layer consumes an intermediate activation (requires_grad is True).
    self.linear1 = torch.nn.Linear(H, D_out)

  def forward(self, x):
    y = torch.add(self.linear0_0(x), self.linear0_1(x)).clamp(min=0)
    y = self.linear1(y)
    return y

One can observe (by adding appropriate debug prints) that in the backward pass release_submodule is not invoked for the submodules linear0_0 and linear0_1, while it is invoked as expected for the submodule linear1.
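The same observation can be approximated without DeepSpeed. The sketch below is an assumption-laden stand-in: instead of release_submodule, it attaches an identity autograd function (the hypothetical `Probe`) to every submodule's inputs through a forward pre-hook and records which submodules see their backward called.

```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear0_0 = torch.nn.Linear(D_in, H)
        self.linear0_1 = torch.nn.Linear(D_in, H)
        self.linear1 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        y = torch.add(self.linear0_0(x), self.linear0_1(x)).clamp(min=0)
        return self.linear1(y)


fired = []  # names of submodules whose input probe ran backward


def make_pre_hook(name):
    # Hypothetical stand-in for the post-backward function in the issue.
    class Probe(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad):
            fired.append(name)
            return grad

    def pre_hook(module, args):
        # Replace each positional input with its probed version.
        return tuple(Probe.apply(a) for a in args)

    return pre_hook


model = MyModel(8, 16, 4)
for name, sub in model.named_children():
    sub.register_forward_pre_hook(make_pre_hook(name))

x = torch.randn(5, 8)  # like dataloader output: requires_grad is False
model(x).sum().backward()

print(fired)  # only the submodule fed by an activation shows up
```

Only linear1 is recorded here, because its input is an activation that requires a gradient; the probes on linear0_0 and linear0_1 never enter the autograd graph.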

deepcharm avatar May 12 '24 12:05 deepcharm

A brute-force workaround is to force .requires_grad to True for the model input tensors:

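        class PostBackwardFunctionModule(torch.autograd.Function):

            @staticmethod
            def forward(ctx, output):
                # `module` comes from the enclosing scope.
                ctx.module = module

                # Force the tensor to require grad so that backward is called,
                # and mark it so the original value can be restored later.
                if not output.requires_grad:
                    output.requires_grad_(requires_grad=True)
                    output.mark_as_no_grad = True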

The original .requires_grad value can then be restored in PostBackwardFunctionModule::backward. This method works, but it is hacky and may introduce unexpected changes in the torch autograd mechanism.

deepcharm avatar May 12 '24 12:05 deepcharm

@tjruwase are you familiar with this issue?

nelyahu avatar Jul 25 '24 12:07 nelyahu

@nelyahu, I was unaware, so thanks for bringing this to my attention.

tjruwase avatar Jul 25 '24 13:07 tjruwase

I have the same issue, but @deepcharm's workaround doesn't work for me if the inputs are IntTensor, which is required for Embedding modules. Do you have any better workaround?
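A minimal check of why integer inputs are a problem for that workaround, assuming a recent PyTorch release: PyTorch refuses to enable gradients on tensors with an integer dtype, so calling requires_grad_ on token ids raises an error. The variable names are illustrative only.

```python
import torch

# Token ids as an Embedding module would receive them from a dataloader.
token_ids = torch.tensor([[1, 2, 3]], dtype=torch.int64)

try:
    token_ids.requires_grad_(True)
    error_message = None
except RuntimeError as exc:
    # PyTorch only allows floating point (and complex) tensors to require grad.
    error_message = str(exc)

print(error_message)
print(token_ids.requires_grad)
```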

mksit avatar Oct 05 '24 16:10 mksit