spikingjelly
STDP learning rule
Hi, spikingjelly is really good and easy to use. One suggestion: would you consider integrating brain-like properties such as STDP into this library? It could be a good way to add to spikingjelly's advantages.
We have a very preliminary implementation of STDP:
https://github.com/fangwei123456/spikingjelly/blob/4b9f796da343f8da2d2a5870c817b6cea60d0650/spikingjelly/clock_driven/layer.py#L872
We are trying to write a better version, but progress is slow: STDP is an unsupervised method, so it is hard to evaluate whether our implementation is "correct". @Grasshlw do you have interest in it?
Hello, even though I don't have much knowledge about STDP, I want to contribute to this project. If you can help me a little or show me the way, I can start to implement STDP @fangwei123456
@lvntky Thanks! Here are some references:
If the pre neuron fires a spike before the post neuron fires one, the synaptic weight increases; conversely, if the post neuron fires before the pre neuron, the synaptic weight decreases. The biological experimental data are shown in the figure below: the horizontal axis is the time difference $t_{post} - t_{pre}$ between a pair of spikes fired by the pre and post neurons, and the vertical axis is the percentage change in synaptic strength.
This change in synaptic strength depends on spike timing and can be fitted by the following formula:
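(The formula was attached as an image in the original post; a standard exponential STDP window of the kind used for such fits, with amplitudes $A_{\pm}$ and time constants $\tau_{\pm}$, is:)

$$
\Delta w =
\begin{cases}
A_{+}\, e^{-\Delta t/\tau_{+}}, & \Delta t > 0\\
-A_{-}\, e^{\Delta t/\tau_{-}}, & \Delta t < 0
\end{cases}
\qquad \Delta t = t_{post} - t_{pre}
$$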
According to "Morrison A, Diesmann M, Gerstner W. Phenomenological models of synaptic plasticity based on spike timing. Biological Cybernetics, 2008, 98(6): 459-478", we can use traces, $x_j$ for the pre neuron and $y_i$ for the post neuron, to implement STDP:
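(The trace equations were also attached as images; reconstructed from the surrounding text, they read:)

$$
\frac{dx_j}{dt} = -\frac{x_j}{\tau_{pre}} + \sum_{t_j^f} \delta(t - t_j^f),
\qquad
\frac{dy_i}{dt} = -\frac{y_i}{\tau_{post}} + \sum_{t_i^f} \delta(t - t_i^f)
$$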
where $\delta(t) = 1$ only when $t = 0$, and $\delta(t) = 0$ otherwise; $t_j^f$ is the firing time of a spike from pre neuron $j$, and $t_i^f$ is the firing time of a spike from post neuron $i$. The weight is changed when a spike is fired by the pre or the post neuron:
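(Again reconstructed from the description: a common trace-based form of the weight update, with the sign of each term folded into $F$, is:)

$$
\frac{dw_{ij}}{dt} = F_{pre}(w_{ij})\, y_i(t) \sum_{t_j^f} \delta(t - t_j^f) + F_{post}(w_{ij})\, x_j(t) \sum_{t_i^f} \delta(t - t_i^f)
$$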
where $F$ is a function that controls how the weight changes.
The following code implements the trace method: https://github.com/fangwei123456/spikingjelly/blob/4b9f796da343f8da2d2a5870c817b6cea60d0650/spikingjelly/clock_driven/layer.py#L872
class STDPLearner(nn.Module):
    def __init__(self,
                 tau_pre: float, tau_post: float,
                 f_pre, f_post
                 ) -> None:
        super().__init__()
        self.tau_pre = tau_pre
        self.tau_post = tau_post
        self.trace_pre = 0
        self.trace_post = 0
        self.f_pre = f_pre
        self.f_post = f_post

    def reset(self):
        self.trace_pre = 0
        self.trace_post = 0

    @torch.no_grad()
    def stdp(self, s_pre: torch.Tensor, s_post: torch.Tensor, module: nn.Module, learning_rate: float):
        if isinstance(module, nn.Linear):
            # update the traces: exponential decay plus the incoming spikes
            self.trace_pre += - self.trace_pre / self.tau_pre + s_pre      # shape [N_in]
            self.trace_post += - self.trace_post / self.tau_post + s_post  # shape [N_out]
            # update the weight (shape [N_out, N_in]) on pre and post spikes
            delta_w_pre = self.f_pre(module.weight) * s_pre
            delta_w_post = self.f_post(module.weight) * s_post.unsqueeze(1)
            module.weight += (delta_w_pre + delta_w_post) * learning_rate
        else:
            raise NotImplementedError
You can see that we only implement STDP for nn.Linear; the implementation for nn.Conv2d will be more difficult.
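As a rough illustration of one possible route (this is not library code; everything below is a hypothetical sketch), torch.nn.functional.unfold can rewrite a Conv2d as a matrix multiplication, so each column of the unfolded input plays the role of the pre-neuron spike vector of a Linear layer. The extra difficulty is that all output positions share the same weight, so their per-position updates must be accumulated back into one weight tensor:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch, not part of spikingjelly: view Conv2d as a matrix product.
conv = nn.Conv2d(2, 4, kernel_size=3, padding=1, bias=False)
s_pre = (torch.rand(1, 2, 8, 8) > 0.8).float()   # pre-synaptic spikes

# unfold to [batch, C_in * k * k, L], where L is the number of output positions;
# each column is the "pre-neuron" spike vector seen by one output position
cols = F.unfold(s_pre, kernel_size=3, padding=1)
w = conv.weight.view(conv.out_channels, -1)      # [C_out, C_in * k * k]
out = (w @ cols).view(1, 4, 8, 8)                # same result as conv(s_pre)
assert torch.allclose(out, conv(s_pre), atol=1e-5)

# A trace per column could then be kept as in the Linear case, but the weight
# updates computed for the L positions must be summed into the shared weight.
```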
Here is an example:
import torch
import torch.nn as nn
from spikingjelly.clock_driven import layer, neuron, functional
from matplotlib import pyplot as plt
import numpy as np
def f_pre(x):
    return x.abs() + 0.1

def f_post(x):
    return - f_pre(x)
fc = nn.Linear(1, 1, bias=False)
stdp_learner = layer.STDPLearner(100., 100., f_pre, f_post)
trace_pre = []
trace_post = []
w = []
T = 256
s_pre = torch.zeros([T, 1])
s_post = torch.zeros([T, 1])
s_pre[0: T // 2] = (torch.rand_like(s_pre[0: T // 2]) > 0.95).float()
s_post[0: T // 2] = (torch.rand_like(s_post[0: T // 2]) > 0.9).float()
s_pre[T // 2:] = (torch.rand_like(s_pre[T // 2:]) > 0.8).float()
s_post[T // 2:] = (torch.rand_like(s_post[T // 2:]) > 0.95).float()
for t in range(T):
    stdp_learner.stdp(s_pre[t], s_post[t], fc, 1e-2)
    trace_pre.append(stdp_learner.trace_pre.item())
    trace_post.append(stdp_learner.trace_post.item())
    w.append(fc.weight.item())
plt.style.use('science')  # requires the SciencePlots package
fig = plt.figure(figsize=(10, 6))
s_pre = s_pre[:, 0].numpy()
s_post = s_post[:, 0].numpy()
t = np.arange(0, T)
plt.subplot(5, 1, 1)
plt.eventplot((t * s_pre)[s_pre == 1.], lineoffsets=0, colors='r')
plt.yticks([])
plt.ylabel('$S_{pre}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 2)
plt.plot(t, trace_pre)
plt.ylabel('$tr_{pre}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 3)
plt.eventplot((t * s_post)[s_post == 1.], lineoffsets=0, colors='r')
plt.yticks([])
plt.ylabel('$S_{post}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 4)
plt.plot(t, trace_post)
plt.ylabel('$tr_{post}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 5)
plt.plot(t, w)
plt.ylabel('$w$', rotation=0, labelpad=10)
plt.xlim(0, T)
plt.show()
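In this toy example the post neuron fires more often than the pre neuron during the first half of the run and less often during the second half, so with the positive f_pre and negative f_post defined above, w drifts downward and then upward, while both traces jump at each spike and decay exponentially between spikes.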
Okay @fangwei123456, thank you very much for these great references. I forked the repo and have started improving it as best I can; I will open a PR if I can make anything valuable. Also, is there a Discord channel or an email where I can talk with you about the progress of the code, or should I just write here?
@lvntky Thanks! I suggest you just write here, which makes it convenient for other developers to join in!
@lvntky @jcmharry Hi, STDP is provided now:
https://spikingjelly.readthedocs.io/zh_CN/latest/activation_based_en/stdp.html
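For reference, a condensed sketch based on that tutorial (parameter names follow the tutorial as I understand it; check the linked page for the authoritative, current API):

```python
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, neuron, learning

def f_weight(x):
    return torch.clamp(x, -1., 1.)

net = nn.Sequential(layer.Linear(8, 4, bias=False), neuron.IFNode())
stdp_learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                                    tau_pre=2., tau_post=2.,
                                    f_pre=f_weight, f_post=f_weight)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.)

for t in range(16):
    optimizer.zero_grad()
    net((torch.rand(1, 8) > 0.7).float())  # run one time step
    stdp_learner.step(on_grad=True)        # write the negated update into .grad
    optimizer.step()                       # apply the STDP weight update
```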