
How to obtain the surrogate gradient of loss function w.r.t. inputs in SNNs?

Open tianyunzhe opened this issue 1 year ago • 2 comments

Issue type

  • [ ] Bug Report
  • [ ] Feature Request
  • [x] Help wanted
  • [ ] Other

Description

I would like to calculate the gradient of the loss function with respect to the inputs using the surrogate gradient method. However, I encountered this error message: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.

My SNN model is a reproduction of the [Convolutional SNN to classify FMNIST], which uses direct coding.

I have tested that the same code can calculate the gradient in a structurally equivalent ANN.

I would greatly appreciate it if you could give me some advice or reference code on input gradient calculation in SNNs.

Thank you in advance!

Minimal code to reproduce the error/bug

When I call loss.backward() in the following code, I get the error RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

import torch
import torch.nn as nn
import torch.utils.data
import torchvision
import torchvision.transforms
from spikingjelly.activation_based import neuron, functional, surrogate, layer


class CSNN(nn.Module):
    def __init__(self, T: int, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 14 * 14

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 7 * 7

            layer.Flatten(),
            layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),

            layer.Linear(channels * 4 * 4, 10, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr

print("Start loading network.")
device = 'cuda:0'

model = CSNN(T=4, channels=128, use_cupy=True)
model.to(device)

checkpoint_path = "/root/InputGrad/logs/T4_b128_sgd_lr0.1_c128_amp_cupy/checkpoint_max.pth"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])

model.eval()
print("Finish loading network.")

test_set = torchvision.datasets.FashionMNIST(
        root="/root/InputGrad/",
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

test_data_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True
)

for img, label in test_data_loader:
    img = img.to(device)
    label = label.to(device)
    
    img.requires_grad = True

    opt = torch.optim.SGD([img], lr=1e-3)
    opt.zero_grad()

    loss = nn.CrossEntropyLoss()(model(img), label)
    print("loss", loss)  # loss tensor(1.4612, device='cuda:0')
    loss.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

    functional.reset_net(model)

With the following code, the gradient w.r.t. the input can be successfully calculated in a structurally equivalent ANN.

class CNN(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        
        self.conv_fc = nn.Sequential(
            nn.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 14 * 14

            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 7 * 7

            nn.Flatten(),
            nn.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
            nn.ReLU(),
            nn.Linear(channels * 4 * 4, 10, bias=False),
        )

    def forward(self, x: torch.Tensor):
        x = self.conv_fc(x)
        return x
    
print("Start loading network.")
device = 'cuda:0'

model = CNN(channels=128)
model.to(device)

checkpoint_path = "/root/InputGrad/logs/b128_sgd_lr0.1_c128/checkpoint_max.pth"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])

model.eval()
print("Finish loading network.")

for img, label in test_data_loader:
    img = img.to(device)
    label = label.to(device)
    
    img.requires_grad = True

    opt = torch.optim.SGD([img], lr=1e-3)
    opt.zero_grad()

    loss = nn.CrossEntropyLoss()(model(img), label)
    print("loss", loss)  # tensor(1.5077, device='cuda:0', grad_fn=<NllLossBackward0>)
    loss.backward()
    print(img.grad.data.shape)  # torch.Size([1, 1, 28, 28])
    break

tianyunzhe avatar Jun 06 '23 07:06 tianyunzhe

Try changing model.eval() to model.train(). Some spiking neurons use different forward functions in eval mode (sometimes they run in no-grad mode), so no autograd graph is built.
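
For reference, here is a minimal sketch of this workaround, assuming the CSNN model, device, and test_data_loader defined above. Keeping the BatchNorm layers in eval mode is my own addition (it relies on layer.BatchNorm2d subclassing nn.BatchNorm2d), so that the checkpoint's running statistics are not updated when only the input gradient is needed:

model.train()  # spiking neurons build the autograd graph only in train mode

# Assumption: freeze BatchNorm running statistics while the rest of the
# network stays in train mode.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

for img, label in test_data_loader:
    img = img.to(device)
    label = label.to(device)
    img.requires_grad_(True)

    loss = nn.CrossEntropyLoss()(model(img), label)
    loss.backward()
    print(img.grad.shape)  # expected: torch.Size([1, 1, 28, 28])

    functional.reset_net(model)
    break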

fangwei123456 avatar Jun 06 '23 07:06 fangwei123456

Thank you for your quick reply!

The problem is solved with model.train().

tianyunzhe avatar Jun 06 '23 07:06 tianyunzhe