spikingjelly
spikingjelly copied to clipboard
STDP学习器step函数存在数值类型异常
STDP学习器step函数存在数值类型异常
- 在使用STDP学习器进行训练时,遇到了张量和None进行减法运算的异常,源码报错来自:
self.synapse.weight.grad = self.synapse.weight.grad - delta_w
,在以下完整代码中,length = self.in_spike_monitor.records.__len__()
值为0,代表本轮并未接收到脉冲信号,即不会进入for _ in range(length):
循环,此时delta_w
恒为None,但on_grad
且self.synapse.weight.grad
不为None是,就会出现数据类型异常,张量在和None做减法。
def step(self, on_grad: bool = True, scale: float = 1.):
length = self.in_spike_monitor.records.__len__()
delta_w = None
if self.step_mode == 's':
if isinstance(self.synapse, nn.Linear):
stdp_f = stdp_linear_single_step
elif isinstance(self.synapse, nn.Conv2d):
stdp_f = stdp_conv2d_single_step
elif isinstance(self.synapse, nn.Conv1d):
stdp_f = stdp_conv1d_single_step
else:
raise NotImplementedError(self.synapse)
elif self.step_mode == 'm':
if isinstance(self.synapse, (nn.Linear, nn.Conv1d, nn.Conv2d)):
stdp_f = stdp_multi_step
else:
raise NotImplementedError(self.synapse)
else:
raise ValueError(self.step_mode)
for _ in range(length):
in_spike = self.in_spike_monitor.records.pop(0) # [batch_size, N_in]
out_spike = self.out_spike_monitor.records.pop(0) # [batch_size, N_out]
self.trace_pre, self.trace_post, dw = stdp_f(
self.synapse, in_spike, out_spike,
self.trace_pre, self.trace_post,
self.tau_pre, self.tau_post,
self.f_pre, self.f_post
)
if scale != 1.:
dw *= scale
delta_w = dw if (delta_w is None) else (delta_w + dw)
if on_grad:
if self.synapse.weight.grad is None:
self.synapse.weight.grad = -delta_w
else:
self.synapse.weight.grad = self.synapse.weight.grad - delta_w
# if delta_w is not None:
# self.synapse.weight.grad = self.synapse.weight.grad - delta_w
else:
return delta_w
我的处理方式是加入:
if delta_w is not None:
self.synapse.weight.grad = self.synapse.weight.grad - delta_w
此时可以解决数据类型异常,但训练loss不变,请问如何解决?
这个问题比较奇怪,因为监视器实际上不会管数据是否为脉冲,都完整的记录下来。监视器没有记录到数据,可能是被监视的那个层实际上没有参与网络的计算?
另外训练loss不变是很正常的,stdp是非常弱的无监督学习器,如果很容易就能调出好的性能,那就是重大科学突破了