Paper - Correcting auto-differentiation in neural-ODE training
This may be of interest.
Does the use of auto-differentiation yield reasonable updates to deep neural networks that represent neural ODEs? Through mathematical analysis and numerical evidence, we find that when the neural network employs high-order forms to approximate the underlying ODE flows (such as the Linear Multistep Method (LMM)), brute-force computation using auto-differentiation often produces non-converging artificial oscillations. In the case of Leapfrog, we propose a straightforward post-processing technique that effectively eliminates these oscillations, rectifies the gradient computation and thus respects the updates of the underlying flow.
https://arxiv.org/abs/2306.02192
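For context, here is a minimal JAX sketch of the setting the paper analyses -- a weight-tied vector field, a two-step Leapfrog unroll, and a terminal loss differentiated by plain reverse-mode autodiff through the solve. This is purely illustrative (hypothetical names like `f`, `leapfrog_solve`, `theta`; it is not diffrax code and not the paper's implementation):

```python
import jax
import jax.numpy as jnp

# Sketch of the setting (not diffrax internals): a weight-tied vector field
# f(theta, u), discretised with the two-step Leapfrog rule
#     u_{n+1} = u_{n-1} + 2*h*f(theta, u_n),
# and a terminal loss differentiated by "brute-force" reverse-mode autodiff.

def f(theta, u):
    w, b = theta
    return jnp.tanh(w @ u + b)

def leapfrog_solve(theta, u0, h, n_steps):
    # One Euler step to generate the second history value, then Leapfrog.
    u_prev = u0
    u_curr = u0 + h * f(theta, u0)

    def step(carry, _):
        u_prev, u_curr = carry
        u_next = u_prev + 2 * h * f(theta, u_curr)
        return (u_curr, u_next), None

    (_, u_final), _ = jax.lax.scan(step, (u_prev, u_curr), None, length=n_steps - 1)
    return u_final

def loss(theta, u0, target):
    u_final = leapfrog_solve(theta, u0, h=0.01, n_steps=100)
    return jnp.sum((u_final - target) ** 2)

key = jax.random.PRNGKey(0)
theta = (0.1 * jax.random.normal(key, (4, 4)), jnp.zeros(4))
u0 = jnp.ones(4)
grads = jax.grad(loss)(theta, u0, jnp.zeros(4))  # gradient through the unrolled solve
```

The question the paper raises is whether `grads` here is a faithful discretisation of the continuous adjoint -- their claim is that for two-step schemes like this, plain autodiff introduces artificial oscillations unless a correction is applied.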
This is a really nice paper -- thank you for sharing it!
At least in the weight-tied regime (common for "true" neural ODEs, uncommon for ResNets), this seems to suggest that we can halve the cost of backpropagation when using Leapfrog -- we only really need to track the branch with non-negligible gradients.
I wonder how this problem would generalise to other linear multistep methods -- most practical usage of autodiff+diffeqsolves has involved Runge--Kutta methods.
They state on p. 3:
Fundamentally, the oscillation arises from the fact that Leapfrog calls for two history steps, while in contrast, mismatch used in the back-propagation in auto-differentiation can provide only one value at the final step t = 1. As a consequence, the chain-rule-type gradient computation provides a non-physical assignment to the final two data points as the initialization for the backpropagation. We provide calculations only for Leapfrog, but the same phenomenon is expected to hold for general LMM type neural ODEs.
But as far as I can see, no more information is given.
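For a concrete picture of what that initialization issue does, here is a toy sketch (my own code, not from the paper): Leapfrog on the linear ODE du/dt = a*u over [0, 1] with terminal loss L = u_N. Backprop seeds the two-step backward recursion with the pair (1, 0) at the final two points -- the "non-physical assignment" above -- and the per-step cotangents dL/du_n then split into two interleaved sequences instead of tracking the true sensitivity exp(a*(1 - t_n)). The zero-valued `eps` probes are just a trick to read those cotangents off with `jax.grad`:

```python
import jax
import jax.numpy as jnp

# Toy illustration (my own sketch, not code from the paper): Leapfrog on
# du/dt = a*u over [0, 1], loss L = u_N.  The zero-valued probes eps[n] enter
# additively at step n, so dL/d eps[n] equals the cotangent dL/du_n that
# reverse-mode autodiff propagates at that step.

a, N = 1.0, 50
h = 1.0 / N
u0 = 1.0

def terminal_value(eps):
    u_prev = u0 + eps[0]
    u_curr = u_prev + h * a * u_prev + eps[1]  # Euler step to start the scheme
    for n in range(2, N + 1):                  # Leapfrog: u_{n+1} = u_{n-1} + 2h*a*u_n
        u_prev, u_curr = u_curr, u_prev + 2 * h * a * u_curr + eps[n]
    return u_curr

adjoints = jax.grad(terminal_value)(jnp.zeros(N + 1))
true_sensitivity = jnp.exp(a * (1.0 - h * jnp.arange(N + 1)))  # dL/du(t_n) for the exact flow

# The computed cotangents split into two interleaved branches (roughly cosh-
# and sinh-like for this linear problem) rather than following
# exp(a*(1 - t_n)); right at the end they alternate between ~1 and ~O(h).
print(adjoints[-6:])
print(true_sensitivity[-6:])
```

For this toy problem the alternation is exactly the two-history-step effect they describe; I have not tried to reproduce their proposed post-processing, which is presumably where the details on p. 3 onwards come in.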