
Loss.backward() more stiff than forward simulation?

Open dom-linkevicius opened this issue 2 years ago • 7 comments

Hello,

First of all, many thanks for this library - it is great! I am trying to implement an idea of mine, and I keep finding that forward simulation with odeint_adjoint works fine, but the backward calls constantly underflow in time (some derivatives explode). I think one could reasonably attribute this to numerical stability issues. I am using the dopri5 solver (though I have tried others too, e.g. LSODA via method="scipy_solver") and have tried lowering the tolerances, but the underflow still happens, only after a much longer time. At this point I am just trying to understand: does the adjoint system require solvers with better numerical stability, i.e. is it more stiff? If so, do you know why, and how to avoid it other than what I have already described? Any help is appreciated, many thanks!

dom-linkevicius avatar Sep 27 '21 15:09 dom-linkevicius

I encountered the derivative explosion, too. Do you have any idea to avoid it? Appreciate any pointers, thanks!

YuanYuan98 avatar Nov 24 '21 04:11 YuanYuan98

It is a bit difficult for me to comment in general, but it may be due to the stiffness of the system you are trying to learn, combined with the instability of the adjoint method (if you are using it) on stiff systems during gradient calculation (see this arXiv paper for a bit more info).

What I did personally was move from Python/PyTorch to the Julia programming language, which has many more ODE solvers and in general seems quite a bit faster than what I could get with Python. Using a different ODE solver in Julia resolved my situation.

dom-linkevicius avatar Nov 24 '21 10:11 dom-linkevicius

Thanks for your reply! In addition, I have discovered that the adaptive Runge-Kutta step (used to integrate the ODE) can also drive dt to 0; is this the underflow in time you mentioned?

Do you think modifying the adaptive step function (e.g., if a step is continuously rejected, relax the acceptance condition) is a plausible solution to this problem?

The relevant function is `def _adaptive_step(self, rk_state)` in https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/rk_common.py
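As a sketch of why continuously rejected steps matter: a typical adaptive controller shrinks dt after every rejection, so if the error estimate never drops below tolerance, dt decays geometrically toward zero. The controller below is a simplified, hypothetical stand-in, not torchdiffeq's actual `_adaptive_step`:

```python
def adaptive_step(error_ratio, dt, safety=0.9, min_factor=0.2, max_factor=10.0):
    """Accept the step if the error ratio is <= 1, and propose the next dt.

    error_ratio is (estimated local error) / (tolerance); the -0.2 exponent
    is the classic rule for a 4th/5th-order embedded Runge-Kutta pair.
    """
    accept = error_ratio <= 1.0
    factor = safety * error_ratio ** -0.2 if error_ratio > 0 else max_factor
    factor = min(max_factor, max(min_factor, factor))
    return accept, dt * factor

# A stiff problem keeps the error estimate above tolerance no matter how
# small dt gets, so every step is rejected and dt collapses.
dt = 0.1
for _ in range(50):
    accept, dt = adaptive_step(error_ratio=5.0, dt=dt)
print(accept, dt)  # still rejecting, dt shrunk by orders of magnitude
```

Relaxing the acceptance condition breaks out of this loop, but only by accepting steps whose local error exceeds the tolerance, which is exactly the "derivatives explode" failure mode described above.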

YuanYuan98 avatar Nov 24 '21 10:11 YuanYuan98

Yes, that is what I meant by underflow in time: the step size dt goes to 0 and the solver gets stuck. In theory, adjusting the error tolerance of adaptive-step solvers should help, but in my own applications it has not: either the solver becomes too computationally expensive (takes too long), or the tolerance is too loose and the derivatives explode. YMMV. This is in the context of the solvers easily available in Python/torchdiffeq.

dom-linkevicius avatar Nov 24 '21 10:11 dom-linkevicius

Thanks a lot! I will try some methods to solve the underflow problem.

YuanYuan98 avatar Nov 24 '21 11:11 YuanYuan98

Hi, does the underflow problem mentioned in this issue include the "non-finite values in state y" error, which is also raised in https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/rk_common.py?

May I know how you solved it?

yx-chan131 avatar Jan 21 '22 07:01 yx-chan131

Hello, this was quite a while ago; my solution was to move from Python and PyTorch/torchdiffeq to Julia and DifferentialEquations.jl/Flux.jl.

Looking at the code you mentioned, I think it wasn't the non-finite values in state y; rather, it failed on `assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())`
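That assert guards against exactly the floating-point effect below: once dt drops below the spacing between adjacent floats at the current time, adding it no longer advances t, and the solver would loop forever. A self-contained illustration:

```python
import math

t0 = 1.0
dt = 1e-20

# dt is far smaller than the gap between adjacent doubles near t0
# (~2.2e-16), so t0 + dt rounds back to t0 and time cannot advance.
stalled = not (t0 + dt > t0)
print(stalled)            # True: this is the 'underflow in dt' condition

# The smallest dt that can still move t0 forward is one ULP of t0.
min_dt = math.ulp(t0)
print(t0 + min_dt > t0)   # True: one ULP is enough to advance
```

Note the threshold is relative to t0: the further the integration has progressed, the larger the dt needed to keep moving, which is why underflow tends to appear late in long integrations.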

dom-linkevicius avatar Jan 21 '22 10:01 dom-linkevicius