BrainPy

Questions about backpropagation through delay variables

Open CloudyDory opened this issue 1 year ago • 5 comments

In #626 we mentioned that the rotation update method of delay variables does not support autograd. However, I have tested this in training and found that the parameters can be trained normally. Is there a misunderstanding about this issue?

import functools
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

# Global BrainPy configuration, applied before any model below is constructed.
snn_latency = 20  # number of simulation steps each sample is run for (length of the for_loop)
dt = 1.0  # simulation time step

bm.clear_buffer_memory()  # presumably frees buffers/caches left from earlier runs — confirm in BrainPy docs
bm.set(float_=bm.float32)  # default float dtype: float32
bm.set_platform('cpu')  # run on CPU
bm.set_dt(dt)

#%% Network definition
class Network(bp.DynSysGroup):
    """Two LIF neurons whose output is read back through a length-based delay buffer.

    Each step drives the neurons with ``weight @ data + bias``, pushes the
    resulting spikes into a ``bm.LengthDelay`` (update method 'rotation') and
    returns the buffer entry retrieved at ``delay_len``.
    """

    def __init__(self):
        super().__init__()
        self.neu = bp.dyn.Lif(
            size=2,
            V_rest=0.0,
            V_reset=0.0,
            V_th=1.0,
            spk_fun=bm.surrogate.Arctan(),
        )
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(
            self.neu.spike,
            delay_len=self.delay_len,
            update_method='rotation',
        )
        # Trainable affine projection of the input onto the two neurons.
        self.weight = bm.TrainVar(bm.random.randn(2, 2))
        self.bias = bm.TrainVar(bm.random.randn(2))

    def reset_state(self, *args):
        """Reset the neurons, then re-initialise the delay buffer from their spike variable."""
        neuron = self.neu
        neuron.reset_state(neuron.mode)
        self.spike_buffer.reset(neuron.spike, delay_len=self.delay_len)

    def update(self, data):
        """Advance one step and return the spikes retrieved at ``delay_len``."""
        drive = self.weight @ data + self.bias
        self.spike_buffer.update(self.neu(drive))  # spikes: [batch, 2]
        return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]

#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
    # Build the model and its optimizer inside BrainPy's training environment.
    model = Network()
    optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

print('Creating data... ')
# Two standard-normal blobs of 100 points each: centred at (-1, -1) for class 0
# and at (1, 1) for class 1.
train_data = np.concatenate(
    [np.random.randn(100, 2) + center
     for center in (np.array([[-1, -1]]), np.array([[1, 1]]))],
    axis=0,
)  # [batch, 2]
train_label = bm.concatenate(
    [bm.zeros(100, dtype=bm.int32), bm.ones(100, dtype=bm.int32)],
    axis=0,
)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Run the network on one sample and return its loss and accuracy.

    Inputs:
        x_single: [feature]
        y_single: [1]
    Returns:
        (loss, acc): scalar NLL loss and scalar accuracy.
    '''
    indices = np.arange(snn_latency)  # sequence length

    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]; epsilon keeps log() finite for silent neurons
    # Normalize per sample over the class axis. Summing over every axis would
    # mix samples together if the batch dimension were ever larger than 1.
    predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]

    # NOTE(review): -predict is passed, which assumes bp.losses.nll_loss selects the
    # target entry without negating it (so the loss is -log p) — confirm.
    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

# Differentiate loss_fun w.r.t. the model's trainable variables; with has_aux=True
# and return_value=True a call returns (grads, loss, acc).
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
    '''
    Scan body: compute one sample's gradients and add them to the running sum.

    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    Returns:
        (new_grad, (loss, acc)): accumulated gradients and this sample's metrics.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_map is deprecated (removed in recent JAX); tree_util.tree_map is the stable API.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Run one optimization step over a whole batch using gradient accumulation.

    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    Returns:
        (loss, acc): batch-mean loss and accuracy.
    '''
    train_vars = model.train_vars().unique()

    # Gradient accumulation: start from zeros shaped like each trainable variable,
    # then sum the per-sample gradients with a scan over the batch.
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    grads, (losses, accs) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    # The summed (not averaged) gradients are handed to the optimizer.
    optimizer.update(grads)

    loss = losses.mean()  # scalar
    acc = accs.mean()  # scalar
    return loss, acc

#%% Start training
print('Start training...')
train_epochs = 10
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)

for epoch in range(train_epochs):
    # To debug eagerly, wrap this call in `with jax.disable_jit():`.
    train_loss[epoch], train_acc[epoch] = train(train_data, train_label)
    print(f"Epoch {epoch}, train_loss={train_loss[epoch]:.3f}, "
          f"train_acc={train_acc[epoch] * 100.0:.2f}%")

print('Done!')

Outputs:

Creating network... 
Creating data... 
Start training...
Epoch 0, train_loss=4.572, train_acc=50.00%
Epoch 1, train_loss=1.373, train_acc=77.00%
Epoch 2, train_loss=0.757, train_acc=87.00%
Epoch 3, train_loss=0.767, train_acc=89.50%
Epoch 4, train_loss=0.636, train_acc=91.00%
Epoch 5, train_loss=0.642, train_acc=91.00%
Epoch 6, train_loss=0.568, train_acc=89.50%
Epoch 7, train_loss=0.570, train_acc=89.50%
Epoch 8, train_loss=0.576, train_acc=90.00%
Epoch 9, train_loss=0.582, train_acc=91.00%
Done!

CloudyDory avatar Feb 23 '24 02:02 CloudyDory

Thanks for the report. The rotation mode may have been fixed at some point. But I will check whether the gradients are correct.

chaoming0625 avatar Feb 29 '24 06:02 chaoming0625

> Thanks for the report. The rotation mode may be fixed by sometimes before. But i will check whether the gradients is correct.

This is also what I would like to know. How can I check gradients in BrainPy?

CloudyDory avatar Feb 29 '24 07:02 CloudyDory

I wrote a simple script to check whether the gradients computed with the 'rotation' and 'concat' update methods are the same. The answer is yes.

import functools

import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm

# Global BrainPy configuration, applied before any model below is constructed.
snn_latency = 20  # number of simulation steps each sample is run for (length of the for_loop)
dt = 1.0  # simulation time step

bm.clear_buffer_memory()  # presumably frees buffers/caches left from earlier runs — confirm in BrainPy docs
bm.set(float_=bm.float32, mode=bm.training_mode)  # default float dtype and default computation mode
bm.set_platform('cpu')  # run on CPU
bm.set_dt(dt)


# %% Network definition
class Network(bp.DynSysGroup):
  """Two LIF neurons whose output is read back through a length-based delay buffer.

  ``method`` is forwarded to ``bm.LengthDelay`` as its ``update_method`` so the
  same network can be built with different delay implementations.
  """

  def __init__(self, method):
    super().__init__()
    self.neu = bp.dyn.Lif(
      size=2,
      V_rest=0.0,
      V_reset=0.0,
      V_th=1.0,
      spk_fun=bm.surrogate.Arctan(),
    )
    self.delay_len = 2
    self.spike_buffer = bm.LengthDelay(
      self.neu.spike,
      delay_len=self.delay_len,
      update_method=method,
    )
    # Trainable affine projection of the input onto the two neurons.
    self.weight = bm.TrainVar(bm.random.randn(2, 2))
    self.bias = bm.TrainVar(bm.random.randn(2))

  def reset_state(self, *args):
    """Reset the neurons, then re-initialise the delay buffer from their spike variable."""
    neuron = self.neu
    neuron.reset_state(neuron.mode)
    self.spike_buffer.reset(neuron.spike, delay_len=self.delay_len)

  def update(self, data):
    """Advance one step and return the spikes retrieved at ``delay_len``."""
    drive = self.weight @ data + self.bias
    self.spike_buffer.update(self.neu(drive))  # spikes: [batch, 2]
    return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]


def train1(method='rotation'):
  '''
  Build a fresh Network and return a jitted single-step training function.

  Inputs:
      method: ``update_method`` forwarded to ``bm.LengthDelay`` (e.g. 'rotation' or 'concat').
  Returns:
      train(x_batch, y_batch): applies one Adam update using gradients summed
      over the batch and returns those gradients (a PyTree with the same
      structure as the model's trainable variables).
  '''
  # %% Create network and fake data
  model = Network(method)
  optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

  # %% Training functions
  def loss_fun(x_single, y_single):
    '''
    Run the network on one sample and return its loss and accuracy.

    Inputs:
        x_single: [feature]
        y_single: [1]
    '''
    indices = np.arange(snn_latency)  # sequence length

    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]; epsilon keeps log() finite for silent neurons
    # Normalize per sample over the class axis. Summing over every axis would
    # mix samples together if the batch dimension were ever larger than 1.
    predict = bm.log(firerate / bm.sum(firerate, axis=-1, keepdims=True))  # log-probabilities, [batch=1, n_class]

    # NOTE(review): -predict is passed, which assumes bp.losses.nll_loss selects the
    # target entry without negating it (so the loss is -log p) — confirm.
    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

  # Build the gradient transform once instead of re-creating it inside every scan step.
  # With has_aux=True and return_value=True a call returns (grads, loss, acc).
  grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

  def grad_fun(last_grad, x_y_single):
    '''
    Scan body: compute one sample's gradients and add them to the running sum.

    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_map is deprecated (removed in recent JAX); tree_util.tree_map is the stable API.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

  @bm.jit
  def train(x_batch, y_batch):
    '''
    Apply one optimizer step over the batch and return the summed gradients.

    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    '''
    train_vars = model.train_vars().unique()
    # Gradient accumulation: zeros shaped like each trainable variable, then a
    # scan that sums the per-sample gradients.
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    grads, _ = bm.scan(grad_fun, grads, (x_batch, y_batch))  # per-sample (loss, acc) are not needed here
    optimizer.update(grads)
    return grads

  return train


train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

# Re-seed the RNG (and clear BrainPy's name cache) before each build so that the
# 'rotation' and 'concat' models start from identical random parameters.
bm.random.seed(0)
bm.clear_name_cache()
f1 = train1('rotation')

bm.random.seed(0)
bm.clear_name_cache()
f2 = train1('concat')

# Each call applies one optimizer step to its own model; if both delay update
# methods propagate gradients identically, every printed leaf stays True.
for e in range(10):
  # To debug eagerly, wrap these calls in `with jax.disable_jit():`.
  grad1 = f1(train_data, train_label)
  grad2 = f2(train_data, train_label)
  # jax.tree_map is deprecated (removed in recent JAX); tree_util.tree_map is the stable API.
  print(jax.tree_util.tree_map(bm.allclose, grad1, grad2))



chaoming0625 avatar Feb 29 '24 07:02 chaoming0625

Hi, what I actually want to know is where gradients are stored in BrainPy. For example, in PyTorch each trainable parameter has a `.grad` field that stores its gradient values. Is there a similar field in BrainPy variables?

CloudyDory avatar Feb 29 '24 08:02 CloudyDory

Gradients are not stored in a fixed place; they are only returned as the output of the gradient function. In the following example, the gradients are returned as `grads`:

# "grad_vars" specifies the variables to differentiate with respect to;
# has_aux=True passes loss_fun's extra output (acc) through, and
# return_value=True also returns the loss value itself.
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
# The gradients come back as the first output ("grads"); they are not stored on the variables.
grads, loss, acc = grad_f(x_single, y_single[None])

chaoming0625 avatar Feb 29 '24 08:02 chaoming0625

Thank you very much for the information!

CloudyDory avatar Mar 04 '24 01:03 CloudyDory