
Support vectorization and parallelization for class objects

Open chaoming0625 opened this issue 2 years ago • 1 comments

Please:

  • [x] Check for duplicate requests.
  • [x] Describe your goal, and if possible provide a code snippet with a motivating example.

If the brainpy.math.Variable instances in a class are already mapped (that is, they carry a batch axis), there is no need to write a dedicated vmap for class objects, because jax.vmap can be applied to them directly.
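A minimal sketch of this mapped case, written in plain JAX rather than BrainPy. The pure function `add` and the array shapes are illustrative assumptions, not code from BrainPy; the point is only that a state with a leading batch axis can be mapped together with the inputs.

```python
import jax
import jax.numpy as jnp

def add(v, inp):
    # Pure-function stand-in for a method that updates a state variable.
    return v + inp

batch = 3
v = jnp.zeros((batch, 1))            # state already has a leading batch axis
inputs = jnp.array([1.0, 2.0, 3.0])  # one input per batch element

# Map over axis 0 of both the state and the inputs.
new_v = jax.vmap(add, in_axes=(0, 0))(v, inputs)
print(new_v.shape)  # (3, 1)
```

Each batch element owns its own slice of the state, so no reduction across the batch is needed afterwards.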

However, suppose a user defines a class whose state has no batch axis and is updated from its inputs. For example:

import brainpy as bp
import brainpy.math as bm

class A(bp.Base):
    def __init__(self):
        super(A, self).__init__()
        self.v = bm.Variable(bm.zeros(1))

    def add(self, inp):
        self.v += inp

Later, if users try to pass a batch of inputs to add, applying jax.vmap directly may cause errors.

For this kind of model, we need to post-process the vmapped self.v variable, for example by summing over the mapped axis of v. However, once the operation is a subtraction, the result of this approach is also wrong. So what do we really need in order to support vectorization and parallelization for class objects?
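The reduction problem can be sketched in plain JAX, under some assumptions. `add` is a functional stand-in for `A.add`, and `rsub` is a hypothetical order-dependent update involving a subtraction; the issue does not say exactly which subtraction it has in mind, so this is only one possible reading. The sketch compares summing the vmapped state against feeding the inputs one at a time.

```python
import jax
import jax.numpy as jnp

def add(v, inp):
    # Functional stand-in for A.add: returns the new state instead of mutating it.
    return v + inp

def rsub(v, inp):
    # Hypothetical order-dependent update that involves a subtraction.
    return inp - v

v0 = jnp.zeros(1)                    # unbatched state, as in A.__init__
inputs = jnp.array([1.0, 2.0, 3.0])  # a batch of three inputs

def run_sequential(fn):
    # Reference result: feed the inputs one at a time.
    v = v0
    for x in inputs:
        v = fn(v, x)
    return v

def run_vmapped(fn):
    # Broadcast the state (in_axes=None) and map over the inputs (in_axes=0).
    return jax.vmap(fn, in_axes=(None, 0))(v0, inputs)

mapped_add = run_vmapped(add)    # shape (3, 1): one copy of v per input
mapped_rsub = run_vmapped(rsub)

print(mapped_add.shape)                                # (3, 1)
print(mapped_add.sum(axis=0), run_sequential(add))     # [6.] [6.]
print(mapped_rsub.sum(axis=0), run_sequential(rsub))   # [6.] [2.]
```

In this sketch the mapped state no longer has the shape of the original `v`, and summing over the mapped axis only happens to match the sequential result for the accumulating update. Which reduction is right depends on the update rule, which is why a single generic rule for class objects is hard to pick.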

chaoming0625 avatar May 14 '22 15:05 chaoming0625

It seems that mapping transformations such as vmap and pmap are not compatible with object-oriented programming.

chaoming0625 avatar May 14 '22 15:05 chaoming0625