spikingjelly
spikingjelly copied to clipboard
How to obtain the surrogate gradient of loss function w.r.t. inputs in SNNs?
Issue type
- [ ] Bug Report
- [ ] Feature Request
- [x] Help wanted
- [ ] Other
Description
I would like to calculate the gradient of the loss function with respect to the inputs, using the surrogate gradient method. However, I encountered the following error: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`.
My SNN model is a reproduction of the [Convolutional SNN to classify FMNIST], which uses direct coding.
I have verified that the same code successfully calculates the input gradient for a structurally equivalent ANN.
I would greatly appreciate it if you could give me some advice or reference code for calculating input gradients in SNNs.
Thank you in advance!
Minimal code to reproduce the error/bug
When I call `loss.backward()` in the following code, I get the error `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`:
class CSNN(nn.Module):
    """Convolutional spiking network fed by repeating the input at every time step.

    The network is two Conv-BatchNorm-IFNode-MaxPool stages followed by two
    Linear-IFNode stages, all run in multi-step mode.

    Args:
        T: Number of time steps; the input is repeated ``T`` times.
        channels: Number of output channels of both convolution layers.
        use_cupy: When true, the backend of the module is switched to
            ``'cupy'`` through ``functional.set_backend``.
    """

    def __init__(self, T: int, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        def spiking():
            # Every spiking layer uses the same IF neuron with an ATan surrogate.
            return neuron.IFNode(surrogate_function=surrogate.ATan())

        # Modules are created in the same order as a literal nn.Sequential(...)
        # listing would create them, so the indices (and therefore the
        # state_dict keys "conv_fc.<idx>...") are unchanged.
        modules = []
        for in_channels in (1, channels):
            # Each stage halves the spatial size (28 -> 14 -> 7 for 28 x 28
            # inputs, as noted by the original "14 * 14" / "7 * 7" comments).
            modules += [
                layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                layer.BatchNorm2d(channels),
                spiking(),
                layer.MaxPool2d(2, 2),
            ]

        hidden_features = channels * 4 * 4
        modules += [
            layer.Flatten(),
            layer.Linear(channels * 7 * 7, hidden_features, bias=False),
            spiking(),
            layer.Linear(hidden_features, 10, bias=False),
            spiking(),
        ]
        self.conv_fc = nn.Sequential(*modules)

        functional.set_step_mode(self, step_mode='m')
        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        """Run the network for ``self.T`` steps and average the output over time.

        Args:
            x: Input batch; the original comment gives its shape as [N, C, H, W].

        Returns:
            The mean over dimension 0 (the repeated/time dimension) of the
            output sequence (named ``fr`` in the original, presumably the
            firing rate).
        """
        # [N, C, H, W] -> [T, N, C, H, W]: the same frame is presented at every step.
        repeated = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        out_seq = self.conv_fc(repeated)
        return out_seq.mean(0)
# --- SNN experiment setup: build the CSNN, restore its weights, create the test loader ---
print("Start loading network.")
device = 'cuda:0'
model = CSNN(T=4, channels=128, use_cupy=True)
model.to(device)
checkpoint_path = "/root/InputGrad/logs/T4_b128_sgd_lr0.1_c128_amp_cupy/checkpoint_max.pth"
# The checkpoint is deserialized onto the CPU; the weights are stored under the 'net' key.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])
# NOTE: per the maintainer's reply in this thread, some spiking neurons use a
# different forward function in eval mode (sometimes run in no-grad mode).
# The reporter confirmed that replacing this call with model.train() resolves
# the RuntimeError raised by loss.backward() further down.
model.eval()
print("Finish loading network.")
# FashionMNIST test split, converted to tensors by ToTensor().
test_set = torchvision.datasets.FashionMNIST(
    root="/root/InputGrad/",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True)
# One image per batch, in dataset order, so each gradient belongs to a single sample.
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True
)
# Try to compute the gradient of the loss w.r.t. each input image of the SNN.
for img, label in test_data_loader:
    img = img.to(device)
    label = label.to(device)
    # Ask autograd to track the input so that img.grad is populated by backward().
    img.requires_grad = True
    # The optimizer is only used to clear img's gradient; step() is never called.
    opt = torch.optim.SGD([img], lr=1e-3)
    opt.zero_grad()
    loss = nn.CrossEntropyLoss()(model(img), label)
    # The printed loss carries no grad_fn, i.e. no autograd graph was recorded
    # during the forward pass; this is why backward() fails on the next line.
    print("loss", loss) # loss tensor(1.4612, device='cuda:0')
    loss.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    # NOTE(review): assumed to run once per sample (inside the loop); the
    # original indentation of this line is not recoverable from the paste.
    functional.reset_net(model)
With the following code, the gradient w.r.t. the input is calculated successfully for a structurally equivalent ANN.
class CNN(nn.Module):
    """ANN counterpart of the CSNN: same layout, with ReLU instead of spiking neurons.

    Two Conv-BatchNorm-ReLU-MaxPool stages are followed by
    Flatten-Linear-ReLU-Linear, producing 10 outputs per sample.

    Args:
        channels: Number of output channels of both convolution layers.
    """

    def __init__(self, channels: int):
        super().__init__()

        # Modules are created in the same order as a literal nn.Sequential(...)
        # listing would create them, so the indices (and therefore the
        # state_dict keys "conv_fc.<idx>...") are unchanged.
        modules = []
        for in_channels in (1, channels):
            # Each stage halves the spatial size (28 -> 14 -> 7 for 28 x 28
            # inputs, as noted by the original "14 * 14" / "7 * 7" comments).
            modules += [
                nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]

        hidden_features = channels * 4 * 4
        modules += [
            nn.Flatten(),
            nn.Linear(channels * 7 * 7, hidden_features, bias=False),
            nn.ReLU(),
            nn.Linear(hidden_features, 10, bias=False),
        ]
        self.conv_fc = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor):
        """Apply the convolutional and fully connected stack to ``x`` and return the result."""
        return self.conv_fc(x)
# --- ANN experiment setup: build the CNN and restore its trained weights ---
print("Start loading network.")
device = 'cuda:0'
model = CNN(channels=128)
model.to(device)
checkpoint_path = "/root/InputGrad/logs/b128_sgd_lr0.1_c128/checkpoint_max.pth"
# The checkpoint is deserialized onto the CPU; the weights are stored under the 'net' key.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])
# Eval mode is used here as well; the loss printed in the loop below still has a grad_fn.
model.eval()
print("Finish loading network.")
# Compute the gradient of the loss w.r.t. the input image for the ANN (first sample only).
# test_data_loader is not defined in this snippet; it is the loader built in the SNN snippet.
for img, label in test_data_loader:
    img = img.to(device)
    label = label.to(device)
    # Ask autograd to track the input so that img.grad is populated by backward().
    img.requires_grad = True
    # The optimizer is only used to clear img's gradient; step() is never called.
    opt = torch.optim.SGD([img], lr=1e-3)
    opt.zero_grad()
    loss = nn.CrossEntropyLoss()(model(img), label)
    # Here the loss has a grad_fn, so the autograd graph reaches back to img.
    print("loss", loss) # tensor(1.5077, device='cuda:0', grad_fn=<NllLossBackward0>)
    loss.backward()
    # The gradient has the same shape as the input image batch.
    print(img.grad.data.shape) # torch.Size([1, 1, 28, 28])
    # Stop after the first test sample.
    break
Try changing `model.eval()` to `model.train()`. Some spiking neurons have different forward functions in eval mode (sometimes they are run in no-grad mode).
Thank you for your quick reply!
The problem is solved with `model.train()`.