BrainPy
Support vectorization and parallelization for class objects
Please:
- [x] Check for duplicate requests.
- [x] Describe your goal, and if possible provide a code snippet with a motivating example.
If the instances of `brainpy.math.Variable` in a class are mapped, then there is no need to write a dedicated vmap for class objects, because `jax.vmap` can be applied to them directly.
However, a user may define an input-based class object whose variables have no batch axis. For example:
```python
import brainpy as bp
import brainpy.math as bm


class A(bp.Base):
    def __init__(self):
        super(A, self).__init__()
        # State without a batch axis: shape (1,)
        self.v = bm.Variable(bm.zeros(1))

    def add(self, inp):
        # In-place update of the object's state
        self.v += inp
```
Later, if users try to pass in a batch of inputs, directly applying `jax.vmap` may cause errors.
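To make the shape problem concrete, here is a minimal sketch in plain JAX rather than BrainPy. The function `add_fn` and the explicit state argument are assumptions for illustration: a functional stand-in for `A.add`, not anything BrainPy provides. It only shows that a state of shape `(1,)` comes back with a leading batch axis once the update is mapped over a batch of inputs, so it no longer fits the `self.v` declared in the class.

```python
import jax
import jax.numpy as jnp


def add_fn(v, inp):
    # Hypothetical functional stand-in for `A.add`: returns the new state
    # instead of mutating `self.v` in place.
    return v + inp


v = jnp.zeros(1)                        # un-batched state, like `self.v`
batch = jnp.arange(4.0).reshape(4, 1)   # a batch of 4 inputs

# Broadcast the state (in_axes=None) and map over the inputs (in_axes=0).
mapped_v = jax.vmap(add_fn, in_axes=(None, 0))(v, batch)

print(v.shape)         # (1,)
print(mapped_v.shape)  # (4, 1): one copy of the state per input
```

Something still has to decide how the four per-input states are written back into a single `self.v`.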
For this kind of model, we need to process the vmapped `self.v` variable and sum over the mapped `v`. However, if the operation is a subtraction, the result of this reduction is also wrong. So what do we really need in order to support vectorization and parallelization for class objects?
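One way to probe how fragile a fixed reduction is: the sketch below, again in plain JAX, compares applying the inputs one at a time with mapping the update and then summing the mapped states. The helper names (`sequential`, `vmap_then_sum`) are hypothetical, and the nonzero starting state is an assumption that differs from the `bm.zeros(1)` in the example above. It is chosen because a zero start hides the mismatch.

```python
import jax
import jax.numpy as jnp


def add(v, inp):
    # Functional stand-in for `self.v += inp`
    return v + inp


def sub(v, inp):
    # The same update with a subtraction
    return v - inp


def sequential(update, v0, inputs):
    # Reference behaviour: feed the inputs one after another.
    v = v0
    for inp in inputs:
        v = update(v, inp)
    return v


def vmap_then_sum(update, v0, inputs):
    # Map the update over the batch, then sum the mapped states.
    lanes = jax.vmap(update, in_axes=(None, 0))(v0, inputs)
    return lanes.sum(axis=0)


v0 = jnp.ones(1)  # assumption: nonzero initial state
inputs = jnp.array([[1.0], [2.0], [3.0]])

print(sequential(add, v0, inputs), vmap_then_sum(add, v0, inputs))  # [7.] [9.]
print(sequential(sub, v0, inputs), vmap_then_sum(sub, v0, inputs))  # [-5.] [-3.]
```

In this sketch the summed result counts the initial state once per input, so whether a plain sum is the right way to merge the mapped `v` depends on both the update rule and the starting state.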
It seems that mapping transformations, like `vmap` and `pmap`, are not compatible with object-oriented programming.