diffrax
Gradient systems
Hi,
I am trying to implement a so-called "gradient system" (i.e., an ODE with a gradient vector field) of the form $$\dot{\mathbf{x}} = - \frac{\partial L(\mathbf{x}; \theta)}{\partial \mathbf{x}}$$ where $\mathbf{x}$ is a list of vectors of different dimensions (corresponding to different model layers) and $L$ is a loss function depending on model parameters $\theta$. I have just started to learn JAX and Diffrax and was wondering whether Diffrax supports this kind of system and, if so, what would be an efficient way of implementing it. Thanks!
Yes, this is easily doable. Something like the following:
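```python
import diffrax
import jax

def L(y, args):
    ...  # your implementation here

def vector_field(t, y, args):
    # y may be a pytree (e.g. a list of per-layer vectors), so negate leaf-by-leaf.
    return jax.tree_util.tree_map(lambda g: -g, jax.grad(L)(y, args))

diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    y0=x,
    args=theta,
    # ... remaining arguments (solver, t0, t1, dt0, etc.), passed by keyword
)
```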
Hi 👋
Sorry to "reopen" this. I just had a quick question related to the above. Is it possible for diffeqsolve to also return the loss over which the gradient system is defined? The vector_field could simply be made to return the loss as well, using value_and_grad, but then the ODETerm would somehow need to "know" about this. Is there some workaround?
Thanks a lot in advance! 🙏
Versions of this request (outputting additional auxiliary information from the vector field) come up every now and again.
This often isn't really a well-defined concept, as in general the times at which we numerically evaluate the vector field and the times at which we return output (as specified by SaveAt) may be completely different.
As such, I'm afraid the appropriate thing to do is to solve the differential equation, and then evaluate your L on the output.
That makes sense, thanks!
Do you have any recommendations for "jitting" the above gradient system for performance? This is probably due to my JAX ignorance, but I wonder whether it is more efficient to jit every function separately (the loss, the vector field, and a wrapper around the ODE solver), or whether it is best to jit just the highest-level function and let jit figure out how best to optimise every subroutine. I also wonder about this because some arguments of the sub-functions are static and "non-jittable".
Jit everything in one go. See point 1 here: https://kidger.site/thoughts/torch2jax/
You can handle static arguments either by using jax.jit(..., static_argnums=...) or (probably easier) equinox.filter_jit.
Thanks a lot for the support! Feel free to resolve this :)
Will do, thanks!