
STDP learning rule

Open jcmharry opened this issue 3 years ago • 9 comments

Hi, SpikingJelly is really good and easy to use. One suggestion: would you consider integrating some brain-like properties, such as STDP, into this library? It could be a good way to add to SpikingJelly's advantages.

jcmharry, Jun 09 '21

We have a very primitive implementation of STDP:

https://github.com/fangwei123456/spikingjelly/blob/4b9f796da343f8da2d2a5870c817b6cea60d0650/spikingjelly/clock_driven/layer.py#L872

We are trying to write a better version, but progress is slow: STDP is an unsupervised method, so it is hard to evaluate whether our implementation is "correct". @Grasshlw, are you interested in it?

fangwei123456 avatar Jun 09 '21 06:06 fangwei123456

Hello @fangwei123456, even though I don't know much about STDP, I would like to contribute to this project. Could you help me a little or show me where to start implementing STDP?

lvntky avatar Jun 09 '21 23:06 lvntky

@lvntky Thanks! Here are some references:

If the pre neuron fires a spike before the post neuron does, the synaptic weight increases; conversely, if the post neuron fires before the pre neuron, the synaptic weight decreases. The biological experimental data are shown in the figure below: the horizontal axis is the time difference $t_{post} - t_{pre}$ between a pair of spikes fired by the pre and post neurons, and the vertical axis is the percentage change in synaptic strength.

[Figure: experimental STDP data, percentage change in synaptic strength vs. $t_{post} - t_{pre}$]

fangwei123456, Jun 10 '21

This change in synaptic strength depends on the spike timing and can be fitted with the following formula:

[equation image]
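A commonly used fit for this kind of curve is a pair of exponentials: $\Delta w = A_+ e^{-\Delta t/\tau_+}$ for $\Delta t > 0$ and $\Delta w = -A_- e^{\Delta t/\tau_-}$ for $\Delta t < 0$, with $\Delta t = t_{post} - t_{pre}$. The sketch below evaluates that assumed form; the function name `stdp_window` and its default parameters are hypothetical and not fitted to the data in the figure.

```python
import math


def stdp_window(dt: float, a_plus: float = 1.0, a_minus: float = 1.0,
                tau_plus: float = 20.0, tau_minus: float = 20.0) -> float:
    """Weight change for a single spike pair, with dt = t_post - t_pre.

    Assumed exponential fit (hypothetical parameters):
    pre before post (dt > 0) -> potentiation,
    post before pre (dt < 0) -> depression.
    """
    if dt > 0:
        return a_plus * math.exp(-dt / tau_plus)
    if dt < 0:
        return -a_minus * math.exp(dt / tau_minus)
    return 0.0  # coincident spikes: no change in this sketch


if __name__ == "__main__":
    for dt in (-40, -10, 10, 40):
        print(dt, round(stdp_window(dt), 4))
```

The magnitude of the change decays as the two spikes move further apart, and only the sign of $\Delta t$ decides between potentiation and depression.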

fangwei123456, Jun 10 '21

Following Morrison A, Diesmann M, Gerstner W. "Phenomenological models of synaptic plasticity based on spike timing." Biological Cybernetics, 2008, 98(6): 459-478, we can implement STDP with traces, $x_{j}$ for the pre neuron and $y_{i}$ for the post neuron:

[equation image]

where $\delta(t) = 1$ only when $t = 0$, and $\delta(t) = 0$ otherwise; $t_{j}^{f}$ is the firing time of a spike from pre neuron $j$, and $t_{i}^{f}$ is the firing time of a spike from post neuron $i$. The weight changes whenever the pre or the post neuron fires a spike:

[equation image]

where $F$ is a function that controls the weight change.

[image]
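Written out explicitly, the trace formulation of Morrison et al. (2008) takes the following form. This is a reconstruction in standard notation: the symbols $\tau_x$, $\tau_y$ (trace time constants) and $F_+$, $F_-$ (weight-dependence functions) are our labels and may not match the original images. The last line is the discrete-time Euler step with $\Delta t = 1$ (and likewise for $y_i$), which is the trace update used in the `stdp` method below.

```latex
\frac{\mathrm{d}x_j}{\mathrm{d}t} = -\frac{x_j}{\tau_x} + \sum_f \delta\left(t - t_j^f\right),
\qquad
\frac{\mathrm{d}y_i}{\mathrm{d}t} = -\frac{y_i}{\tau_y} + \sum_f \delta\left(t - t_i^f\right)

\Delta w_{ij} = F_+(w_{ij})\, x_j\left(t_i^f\right) \quad \text{at each post spike } t_i^f,
\qquad
\Delta w_{ij} = -F_-(w_{ij})\, y_i\left(t_j^f\right) \quad \text{at each pre spike } t_j^f

x_j[t] = x_j[t-1] - \frac{x_j[t-1]}{\tau_x} + s_j[t]
```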

fangwei123456, Jun 10 '21

The following code is implemented with the trace method: https://github.com/fangwei123456/spikingjelly/blob/4b9f796da343f8da2d2a5870c817b6cea60d0650/spikingjelly/clock_driven/layer.py#L872

```python
import torch
import torch.nn as nn


class STDPLearner(nn.Module):
    def __init__(self,
                 tau_pre: float, tau_post: float,
                 f_pre, f_post
                 ) -> None:
        super().__init__()
        self.tau_pre = tau_pre      # time constant of the pre-synaptic trace
        self.tau_post = tau_post    # time constant of the post-synaptic trace
        self.trace_pre = 0
        self.trace_post = 0
        self.f_pre = f_pre          # scales the weight change caused by pre spikes
        self.f_post = f_post        # scales the weight change caused by post spikes

    def reset(self):
        self.trace_pre = 0
        self.trace_post = 0

    @torch.no_grad()
    def stdp(self, s_pre: torch.Tensor, s_post: torch.Tensor, module: nn.Module, learning_rate: float):
        # s_pre: [N_in] spikes, s_post: [N_out] spikes, module.weight: [N_out, N_in]
        if isinstance(module, nn.Linear):
            # update trace
            self.trace_pre += - self.trace_pre / self.tau_pre + s_pre
            self.trace_post += - self.trace_post / self.tau_post + s_post

            # update weight
            delta_w_pre = self.f_pre(module.weight) * s_pre
            delta_w_post = self.f_post(module.weight) * s_post.unsqueeze(1)
            module.weight += (delta_w_pre + delta_w_post) * learning_rate
        else:
            raise NotImplementedError
```

As you can see, we have only implemented STDP for nn.Linear; an implementation for nn.Conv2d will be more difficult.

Here is an example:

```python
import torch
import torch.nn as nn
from spikingjelly.clock_driven import layer, neuron, functional
from matplotlib import pyplot as plt
import numpy as np


# f_pre is applied at pre spikes, f_post at post spikes
def f_pre(x):
    return x.abs() + 0.1

def f_post(x):
    return - f_pre(x)

fc = nn.Linear(1, 1, bias=False)

stdp_learner = layer.STDPLearner(100., 100., f_pre, f_post)
trace_pre = []
trace_post = []
w = []
T = 256
s_pre = torch.zeros([T, 1])
s_post = torch.zeros([T, 1])
# first half: post neuron fires more often than pre neuron
s_pre[0: T // 2] = (torch.rand_like(s_pre[0: T // 2]) > 0.95).float()
s_post[0: T // 2] = (torch.rand_like(s_post[0: T // 2]) > 0.9).float()

# second half: pre neuron fires more often than post neuron
s_pre[T // 2:] = (torch.rand_like(s_pre[T // 2:]) > 0.8).float()
s_post[T // 2:] = (torch.rand_like(s_post[T // 2:]) > 0.95).float()

for t in range(T):
    stdp_learner.stdp(s_pre[t], s_post[t], fc, 1e-2)
    trace_pre.append(stdp_learner.trace_pre.item())
    trace_post.append(stdp_learner.trace_post.item())
    w.append(fc.weight.item())

# the 'science' style is provided by the SciencePlots package; remove this line if it is not installed
plt.style.use('science')
fig = plt.figure(figsize=(10, 6))
s_pre = s_pre[:, 0].numpy()
s_post = s_post[:, 0].numpy()
t = np.arange(0, T)
plt.subplot(5, 1, 1)
plt.eventplot((t * s_pre)[s_pre == 1.], lineoffsets=0, colors='r')
plt.yticks([])
plt.ylabel('$S_{pre}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 2)
plt.plot(t, trace_pre)
plt.ylabel('$tr_{pre}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)

plt.subplot(5, 1, 3)
plt.eventplot((t * s_post)[s_post == 1.], lineoffsets=0, colors='r')
plt.yticks([])
plt.ylabel('$S_{post}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 4)
plt.plot(t, trace_post)
plt.ylabel('$tr_{post}$', rotation=0, labelpad=10)
plt.xticks([])
plt.xlim(0, T)
plt.subplot(5, 1, 5)
plt.plot(t, w)
plt.ylabel('$w$', rotation=0, labelpad=10)
plt.xlim(0, T)

plt.show()
```

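[Figure: output of the example, showing $S_{pre}$, $tr_{pre}$, $S_{post}$, $tr_{post}$ and $w$ over T = 256 time steps]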
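For comparison with the trace method described earlier in this thread, here is a minimal sketch in which the traces gate the weight update: a post spike potentiates each synapse by `a_plus` times the pre trace, and a pre spike depresses it by `a_minus` times the post trace. `stdp_linear_step` is a hypothetical name, not a SpikingJelly API. The sketch assumes additive STDP (constant $F_+$ = `a_plus`, $F_-$ = `a_minus`), no batch dimension, and the same Euler trace update with $\Delta t = 1$ used in the `stdp` method.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def stdp_linear_step(fc: nn.Linear, s_pre: torch.Tensor, s_post: torch.Tensor,
                     trace_pre: torch.Tensor, trace_post: torch.Tensor,
                     tau_pre: float, tau_post: float,
                     a_plus: float = 1e-2, a_minus: float = 1e-2):
    """One time step of trace-based STDP for nn.Linear (hypothetical helper).

    s_pre: [N_in] spikes, s_post: [N_out] spikes, fc.weight: [N_out, N_in].
    Updates fc.weight in place and returns the new (trace_pre, trace_post).
    """
    # Euler step of the trace dynamics with dt = 1
    trace_pre = trace_pre - trace_pre / tau_pre + s_pre
    trace_post = trace_post - trace_post / tau_post + s_post

    # LTP: a spike of post neuron i adds a_plus * x_j to w_ij for every pre neuron j
    ltp = a_plus * torch.outer(s_post, trace_pre)
    # LTD: a spike of pre neuron j subtracts a_minus * y_i from w_ij for every post neuron i
    ltd = a_minus * torch.outer(trace_post, s_pre)

    fc.weight += ltp - ltd
    return trace_pre, trace_post
```

Both traces are updated before the weight, so coincident pre and post spikes produce both an LTP and an LTD term. With a batch dimension, the two outer products would become sums over the batch, and a weight-dependent $F$ could replace the constants `a_plus` and `a_minus`.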

fangwei123456, Jun 10 '21

#15

fangwei123456, Jun 10 '21

Okay @fangwei123456, thank you very much for this great reference! I forked the repo and have started improving it as best I can; I will open a PR if I manage to make anything valuable. Also, is there a Discord channel or an email address where I can talk with you about the progress of the code, or should I just write here?

lvntky, Jun 10 '21

@lvntky Thanks! I suggest you just write here, which makes it easy for other developers to join in!

fangwei123456, Jun 10 '21

@lvntky @jcmharry Hi, STDP is now available:

https://spikingjelly.readthedocs.io/zh_CN/latest/activation_based_en/stdp.html

fangwei123456, Aug 31 '22