Langevin PR

Open andyElking opened this issue 1 year ago • 21 comments

Hi Patrick,

Thanks for the heads up, here's the reuploaded PR (with another quick fix).

This PR contains all the new Langevin solvers. All of these inherit from AbstractLangevinSRK in langevin_srk.py. Another important addition is LangevinTerm in _term.py. I explained why it is needed in a comment below.

I haven't added the new solvers to the docs and autocite yet, because 1) the relevant paper is not on arxiv yet, but might be in a month or two and 2) I expect you might suggest several changes, so might as well write the docs once the rest is stationary. Still, I think the docstrings and comments are quite comprehensive.

I'm making this PR now so you have ample time to have a look at it, but I will be away for the next few weeks, so there is absolutely no hurry.

Best, Andraž

andyElking avatar Jul 01 '24 10:07 andyElking

Hi @patrick-kidger, just bumping this in case you didn't notice that the tests passed. No hurry though :)

andyElking avatar Jul 20 '24 13:07 andyElking

Hi Patrick,

Thanks so much for your review! I addressed almost all of your comments. Two of them I will address in a later commit (it's getting late today haha).

I think you didn't quite understand the reason why the LangevinTerm and the changes in _integrate.py are needed, so I tried to give an explanation, but if that is unclear we can have a call at some point. Let me know :)

Also you might want to read my reply here, but it's a bit hidden way up the conversation: https://github.com/patrick-kidger/diffrax/pull/453#discussion_r1685800805

andyElking avatar Jul 21 '24 21:07 andyElking

Quick heads up: I now made all the edits you suggested and the tests all passed :)

andyElking avatar Jul 25 '24 21:07 andyElking

@patrick-kidger a note about interpolation for ALIGN: to maintain a 2nd order of convergence at interpolated points, the interpolated value cannot depend just on t0, t, t1, y0, y1, but must also depend on W0, H0, W1, H1. Do you think it would be feasible to modify the interpolation code in order to allow for that? Maybe I could store W and H as part of the solution? Do you have any ideas how to do this in a way that fits into Diffrax naturally?

andyElking avatar Jul 29 '24 11:07 andyElking

to maintain a 2nd order of convergence at interpolated points, the interpolated value cannot depend just on t0, t, t1, y0, y1, but must also depend on W0, H0, W1, H1. Do you think it would be feasible to modify the interpolation code in order to allow for that? Maybe I could store W and H as part of the solution? Do you have any ideas how to do this in a way that fits into Diffrax naturally?

I think that should be totally fine -- it can go into the dense_info output of a solver step. Take a look at how Runge--Kutta methods output the intermediate stages, for example.
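To make the idea concrete, here is a toy sketch of an interpolation that reads W and H from a per-step dictionary. Several things in it are assumptions for illustration only: the SDE is the trivial dy = sigma dW (not a Langevin system), the dictionary keys are hypothetical and do not reflect Diffrax's actual dense_info layout, and the formula is the standard "Brownian parabola" (the conditional mean of the Brownian path given W and the space-time Lévy area H over the step), not the interpolation ALIGN would use.

```python
def brownian_parabola_interp(t, dense_info, sigma=1.0):
    """Interpolate y inside one step for the toy SDE dy = sigma dW.

    `dense_info` is a plain dict with hypothetical keys; H is the
    space-time Levy area of the Brownian motion over [t0, t1].
    """
    t0, t1 = dense_info["t0"], dense_info["t1"]
    y0, y1 = dense_info["y0"], dense_info["y1"]
    H = dense_info["H"]
    s = (t - t0) / (t1 - t0)  # relative position in the step, in [0, 1]
    # Linear part: for this toy SDE, y1 - y0 already equals sigma * W.
    linear = y0 + s * (y1 - y0)
    # Quadratic correction: vanishes at both endpoints and depends on H,
    # which cannot be recovered from t0, t1, y0, y1 alone.
    correction = 6.0 * s * (1.0 - s) * sigma * H
    return linear + correction


# One step's worth of (made-up) data, as a solver step might emit it.
dense_info = {"t0": 0.0, "t1": 0.1, "y0": 1.0, "y1": 1.3, "W": 0.3, "H": -0.05}
mid = brownian_parabola_interp(0.05, dense_info)
print(mid)
```

The interpolant matches y0 and y1 at the endpoints, and the midpoint value shifts with H, which is why the extra quantities would have to be saved alongside the step.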

patrick-kidger avatar Aug 02 '24 12:08 patrick-kidger

I think that should be totally fine -- it can go into the dense_info output of a solver step. Take a look at how Runge--Kutta methods output the intermediate stages, for example.

That's good to know, thanks! I'll still make it a separate PR - I have a few other tasks to complete beforehand. Otherwise I think the only comment that remains unresolved is https://github.com/patrick-kidger/diffrax/pull/453#discussion_r1701888367. Let me know if there is anything else you'd like me to improve.

andyElking avatar Aug 02 '24 14:08 andyElking

@patrick-kidger I now made a temporary fix, but there are two things that remain to be solved:

  1. For some reason eqx.filter_eval_shape(term.vf, 0.0, y0, args) doesn't work when term is a LangevinTerm (the stacktrace is in a comment above). I genuinely do not understand what is going wrong, so please help.
  2. I think we are still not on the same page as to why I introduced LangevinTerm in the first place. Please let's have a call at some point to discuss it. And then hopefully it will make sense why using MultiTerm[LangevinDriftTerm, LangevinDiffusionTerm] could lead to incorrect results.

andyElking avatar Aug 04 '24 15:08 andyElking

Great news @patrick-kidger @lockwo: what we discussed worked! Thanks for your advice!

Patrick, if there is nothing else about the code itself that you'd like me to change, then I'll add all of the new things to the docs. Do you think I should add a short example of how to use the langevin solvers as well? Maybe a simple Langevin Monte Carlo example?

andyElking avatar Aug 09 '24 23:08 andyElking

Do you think I should add a short example of how to use the langevin solvers as well? Maybe a simple Langevin Monte Carlo example?

I like the sound of that! An example would be great.

I'll emphasise 'short' -- I really try to keep the examples pedagogical.

patrick-kidger avatar Aug 17 '24 07:08 patrick-kidger

Thanks so much for the review, Patrick! And sorry for all my dummy mistakes 😅. I made all the smaller edits already and tomorrow I'll write a short example, a test for the backward solve and put everything into the docs.

andyElking avatar Aug 17 '24 22:08 andyElking

I added a brief Langevin example, a backwards in time test and fixed up the docs. I also did another check of the code, so hopefully there's no more rogue comments lurking in there. I also made sure all the docs render correctly.

andyElking avatar Aug 19 '24 17:08 andyElking

I don't want to violate the "short" directive (proceeds to suggest lengthening the example, I know), but I think the Langevin example could be good with a little more meat (depending on which sort of audience/usage it is catered for). E.g. to target the sampling/ML people a little more, maybe adding one more plot showing autocorrelation time/ESS of ULD-specific solvers vs EM on the HO problem. Or for the physics community you could just add a plot showing some convergence as a function of friction. Just my opinion, since it's mostly a tutorial on "how do I use ULD solvers", which seems like docs, whereas examples are like "what is interesting/what can I uniquely do with these solvers".

lockwo avatar Aug 19 '24 17:08 lockwo

That's a good point, I wanted to keep it short due to what Patrick said, but if he agrees I'm happy to make it longer. Also it seems you have some good ideas, so I certainly wouldn't mind if you wrote up something like that 😁. I do have some notebooks lying around where I use it for Bayesian logistic regression, but that is very far from short 😅.

andyElking avatar Aug 19 '24 17:08 andyElking

Sure I can write something up that I think maintains the spirit of the "short" directive while adding a little extra spice to get people excited by these solvers and see what Patrick thinks

lockwo avatar Aug 19 '24 17:08 lockwo

Thanks so much Owen! Not sure if this is useful, but in case you'd like to see the Bayesian logistic regression code, it's here: https://github.com/andyElking/diffrax_STLA/blob/devel/mcmc/bayes_log_reg.ipynb

andyElking avatar Aug 19 '24 17:08 andyElking

That's a good example. I got some good examples working, but they also rely on more complex neural network energy functions (which just adds up in terms of LoC), at least adding a little more pretty pictures: https://github.com/andyElking/diffrax_STLA/pull/4.

lockwo avatar Aug 20 '24 06:08 lockwo

I'm actually revisiting your paper (https://arxiv.org/pdf/2405.06464), and I'm curious how these methods can have such a high strong order but have basically the same ESS as constant Euler? Is ESS/sampling just a problem that doesn't require convergence in a strong sense (seems like pathwise convergence might matter for autocorrelations but idk)? Did overdamped Euler (i.e. traditional Langevin MCMC) also perform comparably?

I don't wanna derail this PR lol, so we can take this off github, but I'm very curious as I played around with it more

lockwo avatar Aug 20 '24 06:08 lockwo

Two things to mention:

  • Even the Euler discretisation of ULD has negative Lyapunov exponents (i.e. two paths with different initial conditions but same BM will converge together over time) as long as it doesn't explode. In other words, while the solution is stable, the importance of the initial condition (or past errors) vanishes exponentially.
  • Other than that, if you keep adding sufficiently large independent Gaussians (or other sufficiently nice RVs), then the autocorrelation will go down fast. And as you can see Euler (mistakenly) explores a lot of space (due to instability), so its ESS is actually too big.

So you can converge to the wrong distribution and have a large ESS. Or you can converge to the correct distribution (like anything Metropolis adjusted should, in theory), but have a very low ESS and thus take forever to explore the state space sufficiently well. The remarkable thing about SORT, however, is that it has both a high ESS AND the correct distribution.
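The first point can be checked with a small experiment. The sketch below is only an illustration under stated assumptions: it takes ULD in the form dx = v dt, dv = -gamma v dt - u grad f(x) dt + sqrt(2 gamma u) dW with the quadratic potential f(x) = x^2 / 2, uses gamma = u = 1 and step size h = 0.01 (all chosen here, not taken from the PR), and drives two Euler-Maruyama paths with different initial conditions by the same Brownian increments.

```python
import math
import random


def euler_uld_step(x, v, h, dw, gamma=1.0, u=1.0):
    """One Euler-Maruyama step of ULD with f(x) = x**2 / 2, so grad f(x) = x."""
    x_new = x + h * v
    v_new = v - h * gamma * v - h * u * x + math.sqrt(2.0 * gamma * u) * dw
    return x_new, v_new


def coupled_distance(h, n_steps, seed=0):
    """Run two paths on the same noise; return initial and final distance."""
    rng = random.Random(seed)
    xa, va = 3.0, 0.0   # first initial condition
    xb, vb = -2.0, 1.0  # second initial condition
    d0 = math.hypot(xa - xb, va - vb)
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(h))  # shared Brownian increment
        xa, va = euler_uld_step(xa, va, h, dw)
        xb, vb = euler_uld_step(xb, vb, h, dw)
    return d0, math.hypot(xa - xb, va - vb)


d0, d_final = coupled_distance(h=0.01, n_steps=2000)
print(f"initial distance {d0:.3f}, final distance {d_final:.2e}")
```

Because the potential is quadratic, the difference between the two paths evolves by a fixed linear map whose eigenvalues have modulus below one for this step size, so the distance decays geometrically. With a much larger step size the same map has eigenvalues outside the unit circle and the scheme explodes.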

LMK if that answers your question.

andyElking avatar Aug 20 '24 08:08 andyElking

"In other words, while the solution is stable, the importance of the initial condition (or past errors) vanishes exponentially." This is generally true of samplers right? Like exponential convergence away from initial condition, at least for discrete samplers, is necessary (since it's just the second largest eigenvalue of the transition matrix). I assume the same is true for continuous samplers.
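For the discrete case, the eigenvalue statement can be verified on the smallest possible example. This is a minimal sketch with a two-state chain whose transition probabilities (a = 0.2, b = 0.1) are made up for illustration; its second eigenvalue is 1 - a - b, and the distance to the stationary distribution shrinks by exactly that factor at every step.

```python
def step(mu, P):
    """Push a row-vector distribution mu one step through transition matrix P."""
    n = len(mu)
    return [sum(mu[i] * P[i][j] for i in range(n)) for j in range(n)]


a, b = 0.2, 0.1                   # illustrative transition probabilities
P = [[1 - a, a], [b, 1 - b]]      # rows sum to one
pi = [b / (a + b), a / (a + b)]   # stationary distribution of P
lam2 = 1 - a - b                  # second-largest eigenvalue of P

mu = [1.0, 0.0]                   # start deterministically in state 0
err0 = abs(mu[0] - pi[0])
errors = []
for n in range(1, 21):
    mu = step(mu, P)
    errors.append(abs(mu[0] - pi[0]))

# Compare the observed error with the predicted geometric decay.
for n in (1, 5, 20):
    print(n, errors[n - 1], err0 * lam2 ** n)
```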

"So you can converge to the wrong distribution and have a large ESS" I guess this is my confusion. Is this just because they don't meet the requirements of usual samplers? That is to say, any (valid, ergodic, detailed-balance-fulfilling, etc.) sampler (e.g. HMC, MALA, RWMH) will necessarily converge to the true distribution (in infinite time). So any sample you draw after waiting for the equilibration time (in ULA this is characterized by the eigenvalues of the Fokker-Planck operator) is a sample from the true distribution. And ESS dictates the amount of time I have to wait before drawing another sample (to erase the autocorrelation), which seems good. So I guess my question is: is the problem that, like ULA, Euler (and I guess QUICSORT) ULD are biased samplers (that do not meet my assumption above) because we don't have a decreasing step size and we don't have any Metropolis-Hastings correction? Or is your point just about the first step/getting to the true distribution (and measuring ESS before reaching it)?
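For reference, ESS as used here is the chain length divided by the integrated autocorrelation time, ESS = N / (1 + 2 * sum of autocorrelations). The sketch below estimates it for a synthetic AR(1) chain; the chain, its coefficient phi = 0.9, and the truncation rule (stop at the first non-positive autocorrelation, capped at 200 lags) are assumptions chosen for illustration, not what the paper or the PR uses.

```python
import random


def effective_sample_size(xs, max_lag=200):
    """Estimate ESS = N / (1 + 2 * sum_k rho_k) from a scalar chain."""
    n = len(xs)
    mean = sum(xs) / n
    centred = [x - mean for x in xs]
    var = sum(c * c for c in centred) / n
    tau = 1.0  # integrated autocorrelation time
    for lag in range(1, max_lag + 1):
        cov = sum(centred[i] * centred[i + lag] for i in range(n - lag)) / n
        rho = cov / var
        if rho <= 0.0:  # simple truncation once the estimate is pure noise
            break
        tau += 2.0 * rho
    return n / tau


rng = random.Random(0)
phi = 0.9  # AR(1) coefficient; the lag-k autocorrelation is phi**k
x = 0.0
chain = []
for _ in range(20000):
    x = phi * x + rng.gauss(0.0, 1.0)
    chain.append(x)

ess = effective_sample_size(chain)
ess_theory = len(chain) * (1 - phi) / (1 + phi)  # exact value for AR(1)
print(f"estimated ESS {ess:.0f}, theoretical ESS {ess_theory:.0f}")
```

Note that the estimator only looks at how quickly the chain decorrelates. Nothing in it checks whether the chain's stationary distribution is the intended target, so a large value on its own does not rule out bias.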

lockwo avatar Aug 20 '24 15:08 lockwo

Let's continue over email.

andyElking avatar Aug 20 '24 15:08 andyElking

Hi Patrick, I made all the fixes you recommended. I went through the code again today and I think I don't see any more issues anywhere (but you'll probably still find some 😅).

I left a few conversations unresolved because I still need some related guidance.

andyElking avatar Aug 22 '24 17:08 andyElking

Quick question @patrick-kidger:

Should I make AbstractFosterLangevinSRK public and add it under Abstract Solvers in the docs? I haven't done it so far, because, unlike RK and SRK, where the user just has to specify a tableau, making a custom child of AbstractFosterLangevinSRK is much more involved, so I doubt users who are just using a packaged version of Diffrax would do that. WDYT?

andyElking avatar Sep 01 '24 12:09 andyElking

I think it probably should be public + in the abstract solvers. I agree that writing your own here is incredibly niche, but I think it's useful for inquisitive users to be able to poke at such things.

patrick-kidger avatar Sep 01 '24 12:09 patrick-kidger

So I added the check that drift and diffusion have the same arguments in AbstractFosterLangevinSRK.__init__ and also added a short test for this in test_underdamped_langevin.py.

I will now do the scan trick.

andyElking avatar Sep 01 '24 14:09 andyElking

I think I now addressed everything you suggested, including the scan trick in both QUICSORT and ShOULD. I'll do another quick check and then I pass it back to you.

andyElking avatar Sep 01 '24 15:09 andyElking

I went through the code and the docs again and now I think I addressed everything you mentioned. Also please take a look at this comment here and let me know if I should revert it to how it was before. Other than that there are no major changes.

Also please take a look at the conversations I left unresolved, namely this and this. Thanks!

andyElking avatar Sep 01 '24 16:09 andyElking

Sorry, I left a tiny problem in the test I added, I fixed it now.

andyElking avatar Sep 01 '24 16:09 andyElking

Aaaaand... merged! 🎉 Great job getting this one done, I'm really happy to have it in! :)

patrick-kidger avatar Sep 01 '24 19:09 patrick-kidger

Thanks for bearing with me and taking your time Patrick!! I really appreciate it!

andyElking avatar Sep 01 '24 19:09 andyElking