wetectron
Sequential backprop impl sketch
Should something like the code below work for wrapping ResNet's last layer (the Neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)
import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # forward the full batch once; the custom Function below redoes the
        # backward pass in mini-batches of size batch_size
        y = self.module(x.detach())
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # recompute the forward pass chunk by chunk and backpropagate each chunk:
            # parameter gradients accumulate in ctx.module, input gradients are collected
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()
                    x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                    torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)

    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)
    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)
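One way to sanity-check the wrapper (not part of the original gist, just a sketch assuming the toy model and the SequentialBackprop class defined above) is to compare the gradients it produces against a plain backward pass through the same layers:

# Hypothetical sanity check, not from the original gist: compare gradients
# accumulated in the wrapped neck against a plain forward/backward
# over the same module instances.
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone, neck, head = nn.Linear(3, 6), nn.Linear(6, 12), nn.Linear(12, 1)
x = torch.rand(512, 3)

# reference: ordinary backward through the unwrapped model
nn.Sequential(backbone, neck, head)(x).sum().backward()
grad_ref = neck.weight.grad.clone()

# reset accumulated gradients, then backward through the wrapped neck
for p in [*backbone.parameters(), *neck.parameters(), *head.parameters()]:
    p.grad = None
nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)(x).sum().backward()

# the two gradients should agree up to floating-point accumulation order
print('max abs diff:', (neck.weight.grad - grad_ref).abs().max().item())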
Hello vadimkantorov! I've recently been trying to implement this module and was wondering whether your SBP code works as-is, or whether it needs further modification. I would be grateful for any help!