RegNeuralDE.jl

Significantly Faster Adjoint Solves

Open · jessebett opened this issue on Oct 22 '20 · 1 comment

@ChrisRackauckas have you seen this work by Ricky already?

https://arxiv.org/abs/2009.09457

"Hey, that's not an ODE": Faster ODE Adjoints with 12 Lines of Code Patrick Kidger, Ricky T. Q. Chen, Terry Lyons

Neural differential equations may be trained by backpropagating gradients via the adjoint method, which is another differential equation typically solved using an adaptive-step-size numerical differential equation solver. A proposed step is accepted if its error, *relative to some norm*, is sufficiently small; else it is rejected, the step is shrunk, and the process is repeated. Here, we demonstrate that the particular structure of the adjoint equations makes the usual choices of norm (such as L2) unnecessarily stringent. By replacing it with a more appropriate (semi)norm, fewer steps are unnecessarily rejected and the backpropagation is made faster. This requires only minor code modifications. Experiments on a wide range of tasks, including time series, generative modeling, and physical control, demonstrate a median improvement of 40% fewer function evaluations. On some problems we see as much as 62% fewer function evaluations, so that the overall training time is roughly halved.
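For reference, the accept/reject rule the abstract is talking about is the usual weighted-RMS error test; the paper's proposal amounts to restricting the sum to a subset of the augmented adjoint state. The notation below is my own sketch of the idea, not the paper's:

```latex
% Standard weighted-RMS error control over all n components of the state u:
% accept the step if err <= 1, where e_i is the local error estimate.
\mathrm{err} = \sqrt{\frac{1}{n} \sum_{i=1}^{n}
  \left( \frac{e_i}{\mathrm{atol} + \mathrm{rtol}\,\lvert u_i \rvert} \right)^{2}}

% Seminorm variant: sum only over the state/adjoint components S,
% ignoring the quadrature components that accumulate parameter gradients.
\mathrm{err}_{\mathrm{semi}} = \sqrt{\frac{1}{\lvert S \rvert} \sum_{i \in S}
  \left( \frac{e_i}{\mathrm{atol} + \mathrm{rtol}\,\lvert u_i \rvert} \right)^{2}}
```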

jessebett · Oct 22 '20 19:10

Yes, I saw it. That'll only work on neural ODEs, though. There are more than a few well-known results showing that you need very accurate gradients to optimize physical parameters in differential equations. Adjoint methods in general already struggle because of this, and this change just accentuates the difficulty. Pumas actually has a ton of examples showing how it can fail... The paper needs a limitations section acknowledging that the results don't generalize, and it should mention some of those failure modes.

FWIW, it's one line in DiffEqFlux: you just pass `internalnorm = (u,p,t) -> ...` and write in whatever seminorm you want. We specifically change the norm so that it properly includes the derivative terms, though, because if you don't, you will fail to handle saddle points in things like pharmacokinetic models.
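As a rough illustration of that knob, here's a minimal sketch of swapping in a custom error norm via the `internalnorm` common solver option in OrdinaryDiffEq. It assumes the `internalnorm(u, t)` signature from the DiffEq docs; `seminorm` and the `n_state` split are hypothetical stand-ins for whichever components you want the error test to see:

```julia
using OrdinaryDiffEq

# Hypothetical split: only the first `n_state` components of the (augmented)
# state contribute to the error estimate; the remaining components are
# ignored by the step-acceptance test. This mimics the "seminorm" idea.
const n_state = 2
seminorm(u::AbstractArray, t) = sqrt(sum(abs2, @view(u[1:n_state])) / n_state)
seminorm(u::Number, t) = abs(u)  # scalar method, used e.g. for dt estimation

f(u, p, t) = -u
prob = ODEProblem(f, [1.0, 2.0, 3.0], (0.0, 1.0))

# Swap in the custom norm through the standard `internalnorm` keyword.
sol = solve(prob, Tsit5(); internalnorm = seminorm, abstol = 1e-8, reltol = 1e-8)
```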

I can't share more details on the model here since IIRC it was part of an FDA submission, but this plot really showcases how it fails.

(attached plot)

If the tolerances are too loose, you can get non-smooth changes in the gradient, and then the optimizer can't home in on the saddle point. So what we had to do for the FDA submissions to work was ensure that all gradient calculations were done to at least 1e-8 tolerance, since otherwise you could get divergence due to the stiffness. Saying that this gives "faster ODE adjoints" is therefore misleading: the adjoint just doesn't take the error of the integral into account, and that can have some very adverse effects.
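For context, this is roughly what "all gradient calculations to at least 1e-8 tolerance" looks like with a continuous adjoint in the SciML stack. It's a minimal sketch only, assuming DiffEqSensitivity (now SciMLSensitivity), a toy model and loss of my own, and that the solver tolerances are forwarded to the adjoint pass, as I believe the defaults do:

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote

f(u, p, t) = p[1] .* u          # toy dynamics; stand-in for the real model
u0, tspan, p = [1.0], (0.0, 1.0), [-0.5]
prob = ODEProblem(f, u0, tspan, p)

# Toy loss over the solution. The key point is the tight tolerances combined
# with a continuous adjoint sensealg for the gradient calculation.
loss(p) = sum(abs2, Array(solve(prob, Tsit5(); p = p, saveat = 0.1,
                                abstol = 1e-8, reltol = 1e-8,
                                sensealg = InterpolatingAdjoint())))

grad = Zygote.gradient(loss, p)[1]
```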

In fact, the latest version of the UDE paper has a paragraph citing, I think, 10-15 sources that demonstrate ways adjoint methods can fail because of this accuracy issue.

So 🤷 it's a one-line thing with enough well-known counterexamples that I don't think any reviewer would accept it, so I've effectively ignored it. In fact, Lars Ruthotto has a nice paper demonstrating that these errors have a major effect on the training performance of even neural ODEs, which you can only see if you train with backprop AD.

https://arxiv.org/pdf/2005.13420.pdf

But the traditional discrete sensitivity analysis literature in particular has a ton of damning examples that show you should probably not do this in any general-purpose code.

ChrisRackauckas · Oct 22 '20 21:10