Compatibility with Flax

Open tims457 opened this issue 3 years ago • 11 comments

Will diffrax.diffeqsolve work inside a Flax linen Module? How would you set up the initialization to use Flax inside of ODETerm instead of Equinox?

tims457 avatar Jun 22 '22 02:06 tims457

I'm not a Flax user, so take this with a pinch of salt. But probably something like the following.

variables = model.init(...)

def vector_field(t, y, args):
    return model.apply(args, y)

diffeqsolve(ODETerm(vector_field), ..., args=variables)
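
Spelled out a bit more fully (with an illustrative Dense(4) vector field, Dopri5 solver, and arbitrary t0/t1/dt0 values, so adjust to taste), that might look like:

import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(4)
y0 = jnp.ones((4,))

# Initialise the Flax module *outside* diffeqsolve, so parameter creation
# never happens inside Diffrax's traced region.
variables = model.init(jax.random.PRNGKey(0), y0)

def vector_field(t, y, args):
    # `args` carries the Flax variables through the solve.
    return model.apply(args, y)

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri5(),
    t0=0.0, t1=1.0, dt0=0.1,
    y0=y0,
    args=variables,
)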

patrick-kidger avatar Jun 22 '22 09:06 patrick-kidger

Thanks. I'll give it a go.

tims457 avatar Jun 22 '22 12:06 tims457

Hmm, I want to do something like the following:

import diffrax
import jax
import jax.numpy as jnp
from flax import linen as nn

class NeuralODE(nn.Module):
    derivative_net: nn.Module

    def __call__(self, coords):
        def f(t, y, args):
            return self.derivative_net(y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Dopri5()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        return solution.ys

coords = jnp.ones((1, 4))
model = NeuralODE(derivative_net=nn.Dense(4))
rng = jax.random.PRNGKey(0)
params = jax.jit(model.init)(rng, coords)

Yet, this gives me:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (4,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was diffeqsolve at ./.venv/lib/python3.10/site-packages/equinox/jit.py:25 traced for xla_call.
------------------------------
The leaked intermediate value was created on line ./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
...
./.venv/lib/python3.10/site-packages/flax/core/scope.py:767 (param)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Can't lift sublevels 2 to 1
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Any ideas @patrick-kidger?

ameya98 avatar Sep 16 '22 01:09 ameya98

I'm afraid not. This looks like something to do with Flax, which I'm not familiar enough with to help debug this.

That said, I'm sure it's possible. Both Equinox and Diffrax operate at "the same level as" normal JAX. (In contrast, Flax is a wrapper around JAX.)

patrick-kidger avatar Sep 16 '22 02:09 patrick-kidger

Thanks for the help. It's at moments like these that I wish I were still at Google :) I wonder if diffrax is calling jit inside diffeqsolve; that would explain the error. Is there a way to disable that?

ameya98 avatar Sep 16 '22 02:09 ameya98

Otherwise, is there a way to use Equinox and Flax together? Do you have any examples?

ameya98 avatar Sep 16 '22 02:09 ameya98

Ah, I figured out a (slightly hacky) way to do this:

import flax
from flax import linen as nn
import jax
import jax.numpy as jnp
import diffrax

class NeuralODE(flax.struct.PyTreeNode):
    """A simple neural ODE."""

    encoder: nn.Module
    derivative_net: nn.Module
    decoder: nn.Module

    def init(self, rng, coords):
        rng, encoder_rng, derivative_net_rng, decoder_rng = jax.random.split(rng, 4)
        coords, encoder_params = self.encoder.init_with_output(encoder_rng, coords)
        coords, derivative_net_params = self.derivative_net.init_with_output(derivative_net_rng, coords)
        coords, decoder_params = self.decoder.init_with_output(decoder_rng, coords)

        return {
            "encoder": encoder_params,
            "derivative_net": derivative_net_params,
            "decoder": decoder_params
        }

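    # By this point the params are concrete arrays (created in `init`, outside
    # diffeqsolve), so no Flax-created tracers can leak out of Diffrax's internal jit.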
    def apply(self, params, coords):
        coords = self.encoder.apply(params["encoder"], coords)

        def f(t, y, args):
            return self.derivative_net.apply(params["derivative_net"], y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Euler()
        solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=coords)
        coords = solution.ys
        coords = self.decoder.apply(params["decoder"], coords)
        return coords

rng = jax.random.PRNGKey(0)
coords = jnp.ones((1, 4))

model = NeuralODE(
    encoder=nn.Dense(10),
    derivative_net=nn.Dense(10),
    decoder=nn.Dense(4))
params = jax.jit(model.init)(rng, coords)

Then you can use this just like any other nn.Module, since it exposes the same init/apply interface:

@jax.jit
def compute_loss(params, coords, true_coords):
    preds = model.apply(params, coords)
    return jnp.abs(preds - true_coords).sum()

grads = jax.grad(compute_loss)(params, coords, jnp.zeros_like(coords))

This just uses flax.struct.PyTreeNode instead of eqx.Module. I didn't want to mix both of them in my codebase. Thanks a lot for the help!

ameya98 avatar Sep 16 '22 13:09 ameya98

Hurrah! I'm glad you figured this out.

patrick-kidger avatar Sep 16 '22 14:09 patrick-kidger

@patrick-kidger I'm using Haiku and also ran into the leaked tracer issue.

I was almost clueless for about an hour, until I noticed that, compared with my other programs, this "buggy" program uses the Haiku model (a simple MLP) only inside the ODETerm, so the Haiku model initialization happens inside diffrax frames.

I worked around it by calling the model once with fake but compatible data before calling diffrax (so the model is already initialized by the time diffrax is called).

I understand that the Haiku/Flax way of bridging JAX-style pure functions with PyTorch-style modules involves a certain degree of "dark" magic, but I wouldn't have expected such a simple usage to trigger this frightening leaked-tracer exception.

Could you suggest the root cause of this exception? If it's infeasible to prevent such a leak, would it be possible to at least provide a more reasonable error or warning?

jjyyxx avatar Apr 17 '23 15:04 jjyyxx

So Haiku (like Flax) was implemented as a wrapper around JAX, and by and large isn't compatible with other libraries in the JAX ecosystem. That's really the root cause -- Haiku assumes that it's only being used in particular ways.

My top recommendation is just to use Equinox instead. This provides a PyTorch-style module without the "dark magic".
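
For instance, a minimal neural ODE vector field in Equinox might look like the following (the MLP sizes and solver settings are purely illustrative):

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, key):
        self.mlp = eqx.nn.MLP(in_size=4, out_size=4, width_size=32, depth=2, key=key)

    def __call__(self, t, y, args):
        return self.mlp(y)

func = Func(jax.random.PRNGKey(0))
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(func),
    diffrax.Dopri5(),
    t0=0.0, t1=1.0, dt0=0.1,
    y0=jnp.ones((4,)),
)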

If you really want to use Haiku, then probably the best thing to do is to pass your MLP through hk.transform before using another library. This should transform the Haiku DSL into "normal JAX".
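
Roughly something like the following (illustrative MLP sizes and solver settings; I haven't run this, so treat it as a sketch):

import diffrax
import haiku as hk
import jax
import jax.numpy as jnp

def mlp_fn(y):
    return hk.nets.MLP([32, 32, 4])(y)

# hk.transform turns the Haiku DSL into a pair of pure (init, apply) functions.
model = hk.without_apply_rng(hk.transform(mlp_fn))

y0 = jnp.ones((4,))
# Initialise *before* calling into Diffrax, so no Haiku state is created
# inside diffeqsolve's traced region.
params = model.init(jax.random.PRNGKey(0), y0)

def vector_field(t, y, args):
    return model.apply(args, y)

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri5(),
    t0=0.0, t1=1.0, dt0=0.1,
    y0=y0,
    args=params,
)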

patrick-kidger avatar Apr 17 '23 17:04 patrick-kidger