
Using batchnorm/dropout layers from flax.linen along with the diffrax package

Open · Negar-Erfanian opened this issue 1 year ago · 14 comments

Hi Patrick,

I am using the diffrax ODE solver in my code, and I need to use batchnorm/dropout layers in the function that is passed to the solver. However, this is the error I am getting:

```
solution = diffrax.diffeqsolve(
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 99, in __call__
    return self._call(False, args, kwargs)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 95, in _call
    out = self._cached(dynamic, static)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 37, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/integrate.py", line 676, in diffeqsolve
    solver_state = solver.init(terms, t0, tnext, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/solver/runge_kutta.py", line 269, in init
    return terms.vf(t0, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 364, in vf
    return self.term.vf(t, y, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 173, in vf
    return self.vector_field(t, y, args)
  File "/data/ne12/Kuramoto/model/neuralODE.py", line 56, in fn
    y0 = self.batchnorm(y0, use_running_average=not training)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/linen/normalization.py", line 256, in __call__
    ra_mean = self.variable('batch_stats', 'mean',
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/core/tracers.py", line 36, in check_trace_level
    raise errors.JaxTransformError()
flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
```
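From what I can tell, the root cause is that `diffrax.diffeqsolve` is wrapped in a JAX transform (equinox's JIT, per the top of the traceback), and flax's stateful layers refuse to create or read their variables across a transform boundary that flax did not apply itself. The snippet below is my attempt at a minimal standalone reproduction (the `Model` class here is purely illustrative, not my actual code, and the exact behaviour may depend on the flax version):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Model(nn.Module):
    batchnorm: nn.Module = nn.BatchNorm()

    @nn.compact
    def __call__(self, x):
        # Calling the stateful submodule under a jit that flax did not apply
        # itself crosses flax's trace-level check: BatchNorm touches the
        # 'batch_stats' collection from inside the transform.
        return jax.jit(lambda y: self.batchnorm(y, use_running_average=True))(x)


x = jnp.ones((4, 3))
# Raises flax.errors.JaxTransformError, same as the traceback above.
Model().init(jax.random.PRNGKey(0), x)
```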

I am not using equinox directly; instead I build my neural network with `flax.linen.Module`.

```python
from typing import Callable

import diffrax
import flax.linen as nn
import jax.numpy as jnp


class NeuralODE(nn.Module):
    act_fn: Callable
    node_size: int
    hdims: str
    batchnorm: Callable = nn.BatchNorm()
    dropout: Callable = nn.Dropout(0.1)

    @nn.compact
    def __call__(self, ts, ys, training):
        y, args = ys
        if self.hdims is not None:
            hdims_list = [int(i) for i in self.hdims.split('-')]
        else:
            hdims_list = []

        # One einsum kernel per hidden layer, plus a final output kernel.
        kernels = []
        first_dim = self.node_size
        for i, dim in enumerate(hdims_list):
            kernels.append(self.param(f'kernel{i}', nn.initializers.normal(),
                                      [dim, self.node_size, first_dim]))
            first_dim = dim
        kernels.append(self.param('kernel', nn.initializers.normal(),
                                  [self.node_size, self.node_size, first_dim]))

        def fn(t, y, args):
            y0 = y
            bias, data_adj = args
            if len(y0.shape) == 2:
                y0 = jnp.expand_dims(y0, -1)
            elif len(y0.shape) == 1:
                y0 = jnp.expand_dims(jnp.expand_dims(y0, -1), 0)
            for kernel in kernels:
                y0 = jnp.einsum('ijk,lmj->ilm', y0, kernel)
                y0 = self.act_fn(y0)
                # neuralODE.py line 56 in the traceback: flax raises
                # JaxTransformError on this call.
                y0 = self.batchnorm(y0, use_running_average=not training)
                y0 = self.dropout(y0, deterministic=not training)

            if y0.shape[0] == 1:
                y0 = jnp.squeeze(y0, 0)
            if len(y0.shape) == 2:
                out = jnp.einsum('ij,ij->ij', data_adj, y0).sum(-1)
            elif len(y0.shape) == 3:
                out = jnp.einsum('aij,aij->aij', data_adj, y0).sum(-1)
            out = jnp.squeeze(bias, -1) - out  # B*N
            return out

        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(fn),
            diffrax.Dopri5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=0.01,  # ts[1] - ts[0]
            y0=y,
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=0.1),
            saveat=diffrax.SaveAt(ts=ts),
            made_jump=True,
        )
        return solution.ys
```
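One direction I am considering as a workaround (just a sketch of the idea, with made-up names like `make_fn` and toy shapes, not my real model): keep the vector field purely functional, so nothing inside `diffeqsolve` touches flax's variable machinery. Any parameters would be created outside `fn` (e.g. via `self.param`) and closed over as plain arrays, the batchnorm replaced by a stateless normalization, and the dropout mask drawn manually from an explicit key:

```python
import diffrax
import jax
import jax.numpy as jnp


def make_fn(scale, offset, dropout_key, rate=0.1, training=True):
    """Vector field that closes over plain arrays only (no flax state)."""
    def fn(t, y, args):
        h = jnp.tanh(y)
        # Stateless normalization in place of nn.BatchNorm: there is no
        # 'batch_stats' collection for flax to create or mutate.
        mean, var = h.mean(), h.var()
        h = scale * (h - mean) / jnp.sqrt(var + 1e-5) + offset
        if training:
            # Manual dropout in place of nn.Dropout, with an explicit key.
            # Note the mask is fixed for the whole solve, unlike per-call
            # nn.Dropout inside a network.
            keep = jax.random.bernoulli(dropout_key, 1.0 - rate, h.shape)
            h = jnp.where(keep, h / (1.0 - rate), 0.0)
        return -h
    return fn


fn = make_fn(scale=jnp.ones(()), offset=jnp.zeros(()),
             dropout_key=jax.random.PRNGKey(1))
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(fn), diffrax.Dopri5(),
    t0=0.0, t1=1.0, dt0=0.01, y0=jnp.ones((8,)),
)
```

I am not sure this is faithful to what batchnorm/dropout should mean inside an ODE right-hand side, though, which is partly why I am asking.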

The error is raised when I initialize the model with `model.init(model_key, ts, inputs, training=False)`.
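For completeness, the call site looks roughly like this (the shapes and values below are placeholders I've made up so the snippet is self-contained; my real data is different):

```python
import jax
import jax.numpy as jnp

model = NeuralODE(act_fn=jnp.tanh, node_size=4, hdims='8-8')

ts = jnp.linspace(0.0, 1.0, 50)
y = jnp.ones((4,))             # initial state, one value per node
bias = jnp.ones((4, 1))        # passed through to fn via args
data_adj = jnp.ones((4, 4))    # adjacency used in the final einsum
inputs = (y, (bias, data_adj))

model_key = jax.random.PRNGKey(0)
variables = model.init(model_key, ts, inputs, training=False)  # raises here
```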

Any idea how to solve this? Thanks so much.

Negar-Erfanian · Jun 11 '23 16:06