equinox
equinox copied to clipboard
using batchnorm/dropout layers from flax.linen along with diffrax package
Hi Patrick,
I am using the diffrax ode solver in my code and will need to use batchnorm/dropout layers in the function that will be passed to the solver. However this is the error I am getting:
solution = diffrax.diffeqsolve( File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 99, in call return self._call(False, args, kwargs) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 95, in _call out = self._cached(dynamic, static) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 37, in fun_wrapped out = fun(*args, **kwargs) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/integrate.py", line 676, in diffeqsolve solver_state = solver.init(terms, t0, tnext, y0, args) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/solver/runge_kutta.py", line 269, in init return terms.vf(t0, y0, args) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 364, in vf return self.term.vf(t, y, args) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 173, in vf return self.vector_field(t, y, args) File "/data/ne12/Kuramoto/model/neuralODE.py", line 56, in fn y0 = self.batchnorm(y0, use_running_average=not training) File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/linen/normalization.py", line 256, in call ra_mean = self.variable('batch_stats', 'mean', File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/core/tracers.py", line 36, in check_trace_level raise errors.JaxTransformError() flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
I am not using equinox but instead using flax.linen.Module to build up my Neural Network.
`class NeuralODE(nn.Module):
act_fn: Callable
actfn: Callable
node_size: int
hdims : str
batchnorm : Callable = nn.BatchNorm()
dropout : Callable = nn.Dropout(0.1)
@nn.compact
def call(self, ts, ys, training):
y, args = ys
if self.hdims!=None:
hdims_list = [int(i) for i in self.hdims.split('-')]
else:
hdims_list = []
kernels = []
first_dim = self.node_size
for i, dim in enumerate(hdims_list):
kernels.append(self.param(f'kernel{i}', nn.initializers.normal(), [dim, self.node_size, first_dim]))
first_dim = dim
kernels.append(self.param('kernel', nn.initializers.normal(), [self.node_size, self.node_size, first_dim]))
def fn(t, y, args):
y0 = y
bias, data_adj = args
if len(y0.shape) == 2:
y0 = jnp.expand_dims(y0, -1)
elif len(y0.shape) == 1:
y0 = jnp.expand_dims(jnp.expand_dims(y0, -1), 0)
for kernel in kernels:
y0 = jnp.einsum('ijk,lmj->ilm', y0, kernel)
y0 = self.act_fn(y0)
y0 = self.batchnorm(y0, use_running_average=not training)
y0 = self.dropout(y0, deterministic=not training)
if y0.shape[0] == 1:
y0 = jnp.squeeze(y0, 0)
if len(y0.shape) == 2:
out = jnp.einsum('ij,ij->ij', data_adj, y0).sum(-1)
elif len(y0.shape) == 3:
out = jnp.einsum('aij,aij->aij', data_adj, y0).sum(-1)
out = jnp.squeeze(bias, -1) - out # B*N
return out#, bias, data_adj
solution = diffrax.diffeqsolve(
diffrax.ODETerm(fn),
diffrax.Dopri5(),
t0=ts[0],
t1=ts[-1],
dt0=0.01, # ts[1] - ts[0],
y0=y,
args=args,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=0.1),
saveat=diffrax.SaveAt(ts=ts),
made_jump=True,
)
return solution.ys`
The error is raised when I initialize the model using model.init(model_key, ts, inputs, training=False).
Any idea how to solve this? Thanks so much.