spikingjelly
spikingjelly copied to clipboard
脉冲梯度的处理
二值脉冲本应用bool类型表示,能够极大的节省显存。但由于PyTorch的bool类型不支持携带梯度,需要设法重写梯度的机制。
需要先测试一下是不是真的节省了显存。。
在这片文章“LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition”的源码里面,他们的LIFnode仅仅写了梯度前向和后向机制,速度的确快了很多,我觉得这个应该是可以加快速度的。
另外求问,单层全连接层+LIF效果的确不错,但是两层以上(比如conv_fashion_mnist.py)的话,貌似由于各层脉冲频率发射的问题,梯度更新没有效果,训练准确率不到20%,请问这里的超参是怎么调的呢?尝试使用了默认值发现没有效果。
@ZulunZhu 他们的实现方式和我们的类似,实测速度更快吗 https://github.com/Delver-of-Squeakrets/LISNN/blob/a099a9c0d7f20073cde4b368cd3de4a4d2df30ad/LISNN.py#L7
除开他的contribution之一,加了个lateral interaction(邻居信息),他们的SNN跟spikingjelly是一样的(他们少了个reset过程),fire的过程以及不可导的问题分别用forward和backward实现,可能是少了LIF layer层,少了很多变量的缘故,确实要快一些
框架里的Python版本神经元用的是nn.Module而不是LISNN代码里的torch.autograd.Function来写神经元的状态更新,这两个似乎性能区别不大?我记得早期SpikingFlow版本是类似forward和backward分开写的。 @fangwei123456
from spikingjelly.clock_driven import surrogate, neuron
import torch
from spikingjelly.cext import cal_fun_t
device = 'cuda:0'
T = 64
N = 128
x = torch.rand([T, N], requires_grad=True, device=device)
thresh, lens, decay = (1.0, 0.5, 0.5)
class ActFun(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.gt(thresh).float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
temp = abs(input - thresh) < lens
return grad_input * temp.float()
act_fun = ActFun.apply
def mem_update(x, mem, spike):
mem = mem * decay * (1. - spike) + x
spike = act_fun(mem)
return mem, spike
x = x.detach()
x.requires_grad_(True)
def lisnn_lif_fpbp(x):
mem = 0.
spike = 0.
y = 0.
for t in range(T):
mem, spike = mem_update(x[t], mem, spike)
y += spike
y.sum().backward()
print('LISNN', cal_fun_t(32, device, lisnn_lif_fpbp, x))
class FuseLIFNode(neuron.LIFNode):
def reset(self):
super().reset()
self.spike = 0.
def forward(self, x: torch.Tensor):
self.v = self.v / self.tau * self.spike + x
self.spike = self.surrogate_function(self.v)
return self.spike
x = x.detach()
x.requires_grad_(True)
lif = FuseLIFNode(tau=2.0, surrogate_function=act_fun)
lif.to(device)
def sj_lif_fpbp(lif, x):
y = 0.
for t in range(x.shape[0]):
y += lif(x[t])
y.sum().backward()
lif.reset()
print('SJ FuseLIF', cal_fun_t(32, device, sj_lif_fpbp, lif, x))
x = x.detach()
x.requires_grad_(True)
class SJLIFNode(neuron.LIFNode):
def neuronal_fire(self):
self.spike = self.surrogate_function(self.v)
# 阈值被包含在act_fun中了
lif = SJLIFNode(tau=2.0, surrogate_function=act_fun)
lif.to(device)
print('SJ LIF', cal_fun_t(32, device, sj_lif_fpbp, lif, x))
输出为
LISNN 0.02243775312500003
SJ FuseLIF 0.020295768750000026
SJ LIF 0.032104990624999996
@ZulunZhu @Yanqi-Chen SJ的LIF比LISNN的LIF慢一些,原因是SJ用充电放电重置三个过程描述神经元的行为。LISNN的代码,和上门代码中的FuseLIF把充电和重置合并了,操作更少,因而速度更快。
FuseLIF和LISNN的充电重置方程包含3个操作,SJ中的LIF充电+重置包含7个操作
python代码的每一个操作都会重新调用一次CUDA内核,所以操作越多速度越慢。因而后来用CUDA重写了一些神经元,所有操作都放到了一个CUDA内核。但实测单步和python的提升不大,多步的提升巨大,因而下一次更新的时候就只保留多步cuda神经元,单步的cuda神经元删除,用python足够了。
另外求问,单层全连接层+LIF效果的确不错,但是两层以上(比如conv_fashion_mnist.py)的话,貌似由于各层脉冲频率发射的问题,梯度更新没有效果,训练准确率不到20%,请问这里的超参是怎么调的呢?尝试使用了默认值发现没有效果。
网络结构是什么样的
就是运行conv_fashion_mnist.py,原来是我没有pull,更新之后正常了,而且注意到了前面的层只能用IF,用LIF的话会梯度消失
一直有个疑惑想请问作者,泊松编码那部分,明明是根据数值大小来产生0和1,是个二值伯努利分布,为什么都叫做泊松编码呢?
一直有个疑惑想请问作者,泊松编码那部分,明明是根据数值大小来产生0和1,是个二值伯努利分布,为什么都叫做泊松编码呢?
看这个issue https://github.com/fangwei123456/spikingjelly/issues/81
您好!我现在遇到了一个问题,当我自定义的神经元不继承 MemoryModule类,直接继承nn.Module时,在运行下述代码时
总是出现 错误:RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward. 改到崩溃也无法解决,这个问题该怎么解决呢?
@AiHaiHai 在第n次训练时,需要对神经元进行reset,清除n-1次的状态。给你的神经元加一个名称为reset的函数,每次优化器setp后对整个网络reset:
https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/functional.py#L7
https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/examples/lif_fc_mnist.py#L138
@AiHaiHai在第n次训练时,需要对神经元进行重置,清除给n-1次的状态。
https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/functional.py#L7
https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/examples/lif_fc_mnist.py#L138
我试了一下,的确实没问题了,非常感谢。这里主要调用了该base类, https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/base.py#L54 amazing,这里您说的,清除状态是清除什么状态呢,优化器step后对整个网络reset,具体是执行了什么操作,为什么调用该类就可以正常运行,不执行就会报错呢?
对整个网络reset,只需要放在下一次迭代之前就行。因为训练时,一般优化器step后就是下一次迭代,所以通常放到优化器step后。测试的时候不需要优化器,就直接放到统计完指标后:
https://github.com/fangwei123456/spikingjelly/blob/01e285fe907c22ee310449663c5e7ae24f41fce8/spikingjelly/clock_driven/examples/lif_fc_mnist.py#L167
具体是执行了什么操作,为什么调用该类就可以正常运行,不执行就会报错呢?
隐藏状态,例如神经元的v,保存在神经元内部。reset会将隐藏状态重置为默认值(例如0)。如果不reset,第n次训练的时候就会用n-1次训练完后的神经元的v作为初始电压。这就会让第n次训练和第n-1次训练的计算图连在了一起,而第n-1次计算图在backward后就被销毁了;此外这2个计算图实际上没有关系,也不应该连在一起,因为第n次训练,神经元的初始状态不应该是n-1次训练神经元的最终状态。