Gradient systems

francesco-innocenti opened this issue on Oct 25 '22

Hi,

I am trying to implement a so-called "gradient system" (i.e., an ODE whose vector field is a gradient) of the form $$\dot{\mathbf{x}} = - \frac{\partial L(\mathbf{x}; \theta)}{\partial \mathbf{x}}$$ where $\mathbf{x}$ is a list of vectors of different dimensions (corresponding to different model layers) and $L$ is a loss function depending on model parameters $\theta$. I have just started learning JAX and Diffrax, and was wondering whether Diffrax supports this kind of system and, if so, what an efficient way of implementing it would be. Thanks!
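Written out per layer, with $\mathbf{x} = (\mathbf{x}_1, \dots, \mathbf{x}_N)$ and $\mathbf{x}_l \in \mathbb{R}^{d_l}$ (the layer index $l$, count $N$ and sizes $d_l$ are notation introduced here for illustration), the system is one ODE per layer, each driven by the gradient of the same scalar $L$. For fixed $\theta$, the chain rule then shows that $L$ never increases along a trajectory:

$$\dot{\mathbf{x}}_l = -\frac{\partial L(\mathbf{x}; \theta)}{\partial \mathbf{x}_l}, \qquad l = 1, \dots, N$$

$$\frac{\mathrm{d}}{\mathrm{d}t} L(\mathbf{x}(t); \theta) = \sum_{l=1}^{N} \frac{\partial L}{\partial \mathbf{x}_l} \cdot \dot{\mathbf{x}}_l = -\sum_{l=1}^{N} \left\lVert \frac{\partial L}{\partial \mathbf{x}_l} \right\rVert^2 \le 0$$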

francesco-innocenti, Oct 25 '22 13:10

Yes, this is easily doable. Something like the following:

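```python
import diffrax
import jax

def L(y, args):
    ...  # your implementation here

def vector_field(t, y, args):
    # y (and hence its gradient) may be a PyTree such as a list of arrays,
    # so negate leaf by leaf rather than with a bare unary minus.
    grads = jax.grad(L)(y, args)
    return jax.tree_util.tree_map(lambda g: -g, grads)

diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    y0=x,
    args=theta,
    # ... plus the remaining arguments (solver, t0, t1, dt0, ...)
)
```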

patrick-kidger, Oct 25 '22 16:10

Hi 👋

Sorry to "reopen" this, but I have a quick follow-up question. Is it possible for diffeqsolve to also return the loss that defines the gradient system? The vector_field can easily be made to return the loss as well, using value_and_grad, but then the ODETerm would somehow need to "know" about it. Is there a workaround?

Thanks a lot in advance! 🙏

francesco-innocenti, Jun 06 '24 14:06

Versions of this request (outputting additional auxiliary information from the vector field) come up every now and again.

This often isn't a well-defined concept: in general, the times at which the vector field is evaluated numerically and the times at which results are returned via SaveAt may be completely different.

As such, I'm afraid the appropriate thing to do is to solve the differential equation and then evaluate your L on the output.

patrick-kidger, Jun 07 '24 18:06

That makes sense, thanks!

Do you have any recommendations for "jitting" the above gradient system for performance? This is probably due to my JAX ignorance, but I wonder whether it is more efficient to jit every function (the loss, the vector field, and a wrapper around the ODE solver) or to jit just the highest-level function and let jit figure out how best to optimise every subroutine. I also wonder about this because some arguments of the sub-functions are static and "non-jittable".

francesco-innocenti, Jun 08 '24 13:06

Jit everything in one go. See point 1 here: https://kidger.site/thoughts/torch2jax/

You can handle static arguments either by using jax.jit(..., static_argnums=...) or (probably easier) equinox.filter_jit.

patrick-kidger, Jun 08 '24 15:06

Thanks a lot for the support! Feel free to resolve this :)

francesco-innocenti, Jun 12 '24 08:06

Will do, thanks!

patrick-kidger, Jun 12 '24 12:06