wetectron icon indicating copy to clipboard operation
wetectron copied to clipboard

Sequential backprop impl sketch

Open vadimkantorov opened this issue 4 years ago • 1 comment

Should something like the code below work for wrapping ResNet's last layer (the Neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)

import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    """Wrap ``module`` so that its backward pass is computed chunk by chunk.

    The forward pass evaluates ``module`` on the whole batch *without*
    recording an autograd graph, so no intermediate activations of ``module``
    are kept alive between forward and backward.  During backward the input is
    split along dim 0 into chunks of ``batch_size``; for every chunk the module
    is re-run with grad enabled and back-propagated immediately, which
    accumulates into the ``.grad`` of the module's parameters and yields the
    gradient w.r.t. that chunk of the input.

    Args:
        module: the sub-network to wrap.  It is called with a single tensor
            and its output is passed to ``torch.autograd.backward`` together
            with the matching slice of the incoming gradient.
        batch_size: forwarded to ``Tensor.split`` on dim 0 of both the saved
            input and the incoming gradient.

    Notes:
        * The module is evaluated twice per sample (once in forward, once in
          backward).  NOTE(review): modules whose output depends on the batch
          composition or on random state (e.g. batch-norm in training mode,
          dropout) will presumably not reproduce the forward result exactly
          during the chunked recomputation -- confirm before wrapping such
          layers.
        * Parameter gradients are accumulated directly into ``.grad`` by the
          inner ``torch.autograd.backward`` call; the custom function itself
          reports ``None`` for them.
    """

    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        """Return ``module(x)``; gradients are produced lazily in backward."""
        if not torch.is_grad_enabled():
            # Nothing will be back-propagated: no need for the custom node.
            return self.module(x)

        # Run the full-batch forward without building a graph.  Building it
        # here (and handing a grad-requiring ``y`` to ``Function.apply``) would
        # keep every activation of ``module`` alive until backward, which is
        # exactly what chunked back-propagation is meant to avoid.
        with torch.no_grad():
            y = self.module(x)

        # Pass the trainable parameters as extra inputs so the custom node is
        # attached to the graph even when ``x`` itself does not require grad.
        params = [p for p in self.module.parameters() if p.requires_grad]
        return self.Function.apply(x, y, self.batch_size, self.module, *params)

    class Function(torch.autograd.Function):
        """Identity on ``y`` in forward; chunked recomputation in backward."""

        @staticmethod
        def forward(ctx, x, y, batch_size, module, *params):
            # ``params`` are only used to hook this node into the graph; their
            # count is remembered so backward can return one slot per input.
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            ctx.num_params = len(params)
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            need_x_grad = ctx.needs_input_grad[0]
            grads = []
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                # Fresh leaf so the inner backward stops at this chunk.
                x_mini = x_mini.detach().requires_grad_(need_x_grad)
                # Custom-function backward may run with grad mode off.
                with torch.enable_grad():
                    y_mini = ctx.module(x_mini)
                # Accumulates into parameter ``.grad`` and into ``x_mini.grad``.
                torch.autograd.backward(y_mini, g_mini)
                if need_x_grad:
                    g_x = x_mini.grad
                    # ``grad`` stays None if the output does not depend on x.
                    grads.append(g_x if g_x is not None else torch.zeros_like(x_mini))
            grad_x = torch.cat(grads) if need_x_grad else None
            # One slot per forward input: x, y, batch_size, module, *params.
            return (grad_x, None, None, None) + (None,) * ctx.num_params

if __name__ == '__main__':
    # Smoke test: a three-stage linear network whose middle stage ("neck") is
    # wrapped in SequentialBackprop.  The neck's weight gradient is printed
    # before and after one backward pass over a random batch.
    backbone, neck, head = nn.Linear(3, 6), nn.Linear(6, 12), nn.Linear(12, 1)
    model = nn.Sequential(
        backbone,
        SequentialBackprop(neck, batch_size=16),
        head,
    )

    print('before', neck.weight.grad)

    inputs = torch.rand(512, 3)
    loss = model(inputs).sum()
    loss.backward()
    print('after', neck.weight.grad)

vadimkantorov avatar Nov 10 '21 10:11 vadimkantorov

Hello vadimkantorov! I have recently been trying to implement this module and am wondering whether your SBP code works as-is, or whether it needs further modification. I would be grateful for any help!

ck6698000 avatar Apr 07 '22 02:04 ck6698000