Langevin PR
Hi Patrick,
Thanks for the heads-up; here's the re-uploaded PR (with another quick fix).
This PR contains all the new Langevin solvers. All of them inherit from AbstractLangevinSRK in langevin_srk.py. Another important addition is LangevinTerm in _term.py; I explained why it is needed in a comment below.
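To give a feel for the intended usage, here is a rough sketch (the exact LangevinTerm constructor signature and the ALIGN argument are illustrative and may well change during review):

```python
import jax
import jax.numpy as jnp
import diffrax

# Underdamped Langevin diffusion (ULD):
#   dx = v dt
#   dv = -gamma v dt - u grad_f(x) dt + sqrt(2 gamma u) dW
gamma, u = 2.0, 1.0
grad_f = lambda x: x  # quadratic potential f(x) = |x|^2 / 2
y0 = (jnp.zeros(3), jnp.zeros(3))  # (position, velocity)

# These solvers use space-time Levy area, so the BM must provide it.
bm = diffrax.VirtualBrownianTree(
    0.0, 1.0, tol=1e-3, shape=(3,),
    key=jax.random.PRNGKey(0), levy_area=diffrax.SpaceTimeLevyArea,
)
term = diffrax.LangevinTerm((gamma, u, grad_f), bm)  # hypothetical signature
solver = diffrax.ALIGN(0.1)  # argument assumed to be the Taylor threshold
sol = diffrax.diffeqsolve(term, solver, 0.0, 1.0, dt0=0.01, y0=y0)
```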
I haven't added the new solvers to the docs and autocite yet, because 1) the relevant paper is not on arXiv yet, but might be in a month or two, and 2) I expect you might suggest several changes, so I might as well write the docs once the rest has settled. Still, I think the docstrings and comments are quite comprehensive.
I'm making this PR now so you have ample time to have a look at it, but I will be away for the next few weeks, so there is absolutely no hurry.
Best, Andraž
Hi @patrick-kidger, just bumping this in case you didn't notice that the tests passed. No hurry though :)
Hi Patrick,
Thanks so much for your review! I addressed almost all of your comments. Two of them I will address in a later commit (it's getting late today haha).
I think you didn't quite understand why LangevinTerm and the changes in _integrate.py are needed, so I tried to give an explanation; if it's still unclear we can have a call at some point. Let me know :)
Also you might want to read my reply here, but it's a bit hidden way up the conversation: https://github.com/patrick-kidger/diffrax/pull/453#discussion_r1685800805
Quick heads up: I now made all the edits you suggested and the tests all passed :)
@patrick-kidger a note about interpolation for ALIGN:
to maintain 2nd order convergence at interpolated points, the interpolated value cannot depend only on t0, t, t1, y0, y1, but must also depend on W0, H0, W1, H1. Do you think it would be feasible to modify the interpolation code to allow for that? Maybe I could store W and H as part of the solution? Do you have any ideas how to do this in a way that fits into Diffrax naturally?
I think that should be totally fine -- it can go into the dense_info output of a solver step. Take a look at how Runge--Kutta methods output their intermediate stages, for example.
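Roughly this pattern, as a toy illustration (field names illustrative; in Diffrax the step lives on the solver and the interpolation on an AbstractLocalInterpolation):

```python
# Whatever a step stashes in dense_info is available to the interpolation.
def toy_step(t0, t1, y0, W, H):
    y1 = y0 + W  # stand-in update
    dense_info = dict(y0=y0, y1=y1, W=W, H=H)  # keep W and H around
    return y1, dense_info

def toy_interpolate(t0, t, t1, dense_info):
    theta = (t - t0) / (t1 - t0)
    # The interpolant may now depend on W and H, not just y0 and y1:
    return (
        (1 - theta) * dense_info["y0"]
        + theta * dense_info["y1"]
        + 6 * theta * (1 - theta) * dense_info["H"]  # illustrative term
    )
```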
That's good to know, thanks! I'll still make it a separate PR - I have a few other tasks to complete beforehand. Otherwise I think the only comment that remains unresolved is https://github.com/patrick-kidger/diffrax/pull/453#discussion_r1701888367. Let me know if there is anything else you'd like me to improve.
@patrick-kidger I now made a temporary fix, but there are two things that remain to be solved:
- For some reason `eqx.filter_eval_shape(term.vf, 0.0, y0, args)` doesn't work when `term` is a `LangevinTerm` (the stacktrace is in a comment above). I genuinely do not understand what is going wrong, so please help (a minimal sanity check on a plain function is below).
- I think we are still not on the same page as to why I introduced `LangevinTerm` in the first place. Please let's have a call at some point to discuss it, and then hopefully it will make sense why using `MultiTerm[LangevinDriftTerm, LangevinDiffusionTerm]` could lead to incorrect results.
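For reference, the call behaves as expected on a plain function, so the failure really does seem specific to LangevinTerm:

```python
import equinox as eqx
import jax.numpy as jnp

def vf(t, y, args):
    return -y  # stand-in vector field

y0 = jnp.zeros(3)
# Returns the output's ShapeDtypeStruct without evaluating vf:
out = eqx.filter_eval_shape(vf, 0.0, y0, None)
print(out)  # ShapeDtypeStruct(shape=(3,), dtype=float32)
```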
Great news @patrick-kidger @lockwo: what we discussed worked! Thanks for your advice!
Patrick, if there is nothing else about the code itself that you'd like me to change, then I'll add all of the new things to the docs. Do you think I should add a short example of how to use the Langevin solvers as well? Maybe a simple Langevin Monte Carlo example?
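Something along these lines, perhaps (a sketch using only existing Diffrax APIs, with overdamped Langevin and Euler-Maruyama for brevity, sampling a standard Gaussian):

```python
import jax
import jax.numpy as jnp
import diffrax

# Overdamped Langevin: dy = grad log p(y) dt + sqrt(2) dW.
def drift(t, y, args):
    return -y  # grad log p for a standard normal target

t0, t1 = 0.0, 100.0
bm = diffrax.VirtualBrownianTree(
    t0, t1, tol=1e-3, shape=(2,), key=jax.random.PRNGKey(0)
)
terms = diffrax.MultiTerm(
    diffrax.ODETerm(drift),
    diffrax.ControlTerm(lambda t, y, args: jnp.sqrt(2.0) * jnp.eye(2), bm),
)
sol = diffrax.diffeqsolve(
    terms, diffrax.Euler(), t0, t1, dt0=0.05, y0=jnp.zeros(2),
    saveat=diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1001)),
    max_steps=10_000,
)
samples = sol.ys  # correlated draws, approximately from the target
```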
I like the sound of that! An example would be great.
I'll emphasise 'short' -- I really try to keep the examples pedagogical.
Thanks so much for the review, Patrick! And sorry for all my silly mistakes. I made all the smaller edits already; tomorrow I'll write a short example and a test for the backward solve, and put everything into the docs.
I added a brief Langevin example and a backwards-in-time test, and fixed up the docs. I also did another check of the code, so hopefully there are no more rogue comments lurking in there, and made sure all the docs render correctly.
I don't want to violate the "short" directive (proceeds to suggest lengthening the example, I know), but I think the Langevin example could do with a little more meat, depending on which audience/usage it is catered to. E.g. to target the sampling/ML people a little more, maybe add one more plot showing the autocorrelation time/ESS of ULD-specific solvers vs EM on the HO problem; or for the physics community you could add a plot showing convergence as a function of friction. Just my opinion, since it's mostly a tutorial on "how do I use ULD solvers", which seems like docs, whereas examples are more "what is interesting/what can I uniquely do with these solvers".
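For concreteness, the ESS estimate itself can be as simple as this sketch (a quick-and-dirty version that truncates the autocorrelation sum at the first negative lag; there are fancier conventions):

```python
import numpy as np

def ess(chain):
    # Effective sample size of a 1D chain via the integrated
    # autocorrelation time.
    x = np.asarray(chain) - np.mean(chain)
    acf = np.correlate(x, x, mode="full")[len(x) - 1:]
    acf = acf / acf[0]
    cutoff = np.argmax(acf < 0) if np.any(acf < 0) else len(acf)
    tau = 1.0 + 2.0 * np.sum(acf[1:cutoff])
    return len(x) / tau
```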
That's a good point; I wanted to keep it short because of what Patrick said, but if he agrees I'm happy to make it longer. It also seems you have some good ideas, so I certainly wouldn't mind if you wrote up something like that. I do have some notebooks lying around where I use it for Bayesian logistic regression, but that is very far from short.
Sure, I can write something up that I think maintains the spirit of the "short" directive while adding a little extra spice to get people excited about these solvers, and see what Patrick thinks.
Thanks so much Owen! Not sure if this is useful, but in case you'd like to see the Bayesian logistic regression code, it's here: https://github.com/andyElking/diffrax_STLA/blob/devel/mcmc/bayes_log_reg.ipynb
That's a good example. I got some good examples working too, but they rely on more complex neural network energy functions (which adds up in terms of LoC); this version at least adds a few more pretty pictures: https://github.com/andyElking/diffrax_STLA/pull/4.
I'm actually revisiting your paper (https://arxiv.org/pdf/2405.06464), and I'm curious how these methods can have such a high strong order but basically the same ESS as constant-step Euler. Is ESS/sampling just a problem that doesn't require convergence in a strong sense (seems like pathwise convergence might matter for autocorrelations, but idk)? Did overdamped Euler (i.e. traditional Langevin MCMC) also perform comparably?
I don't wanna derail this PR lol, so we can take this off GitHub, but I'm very curious as I've played around with it more.
Two things to mention:
- Even the Euler discretisation of ULD has negative Lyapunov exponents (i.e. two paths with different initial conditions but the same BM will converge together over time), as long as it doesn't explode; see the sketch after this list. In other words, while the solution is stable, the importance of the initial condition (or past errors) vanishes exponentially.
- Other than that, if you keep adding sufficiently large independent Gaussians (or other sufficiently nice RVs), the autocorrelation goes down fast. And as you can see, Euler (mistakenly) explores a lot of space (due to instability), so its ESS is actually too big.
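Here is the sketch I mentioned (assumed parameters; plain Euler-Maruyama written out by hand, on the quadratic potential f(x) = x^2/2, with two initial conditions driven by the SAME Brownian increments):

```python
import jax
import jax.numpy as jnp

gamma, u, dt, n_steps = 2.0, 1.0, 0.01, 5000

def step(state, dW):
    x, v = state
    x = x + v * dt
    v = v - gamma * v * dt - u * x * dt + jnp.sqrt(2 * gamma * u) * dW
    return (x, v), None

dWs = jnp.sqrt(dt) * jax.random.normal(jax.random.PRNGKey(0), (n_steps,))
(xa, va), _ = jax.lax.scan(step, (jnp.array(5.0), jnp.array(0.0)), dWs)
(xb, vb), _ = jax.lax.scan(step, (jnp.array(-5.0), jnp.array(0.0)), dWs)
# The gap between the two paths shrinks like exp(-t): essentially zero here.
print(jnp.abs(xa - xb), jnp.abs(va - vb))
```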
So you can converge to the wrong distribution and have a large ESS. Or you can converge to the correct distribution (as anything Metropolis-adjusted should, in theory), but have a very low ESS and thus take forever to explore the state space sufficiently well. The remarkable thing about SORT, however, is that it has both a high ESS AND the correct distribution.
LMK if that answers your question.
"In other words, while the solution is stable, the importance of the initial condition (or past errors) vanishes exponentially." This is generally true of samplers right? Like exponential convergence away from initial condition, at least for discrete samplers, is necessary (since it's just the second largest eigenvalue of the transition matrix). I assume the same is true for continuous samplers.
"So you can converge to the wrong distribution and have a large ESS" I guess this is my confusion. Is this just because they don't meet the requirements of usual samplers. That is to say, any (valid, ergodic, detailed balance fulfilling, etc etc) sampler (e.g. HMC, MALA, RWMH, etc.) will necessarily converge to the true distribution (in infinite time). So any sample you draw after waiting equilibration time (in ULA this is characterized by the eigenvalues of the Fokker Planck operator) is a sample from the true distribution. And ESS dictates the amount of time I have to wait before drawing another sample (to erase the autocorellation), which seems good. So I guess my question is, is the problem that like ULA, Euler (and I guess QUICSORT) ULD are biased samplers (that do not meet my assumption above) because we don't have decreasing stepwise and we don't have any metropolis step/hastings correction? Or is your point just about the first step/getting to the true distribution (and measuring ESS before reaching it)?
Let's continue over email.
Hi Patrick, I made all the fixes you recommended. I went through the code again today and I think I don't see any more issues anywhere (but you'll probably still find some).
I left a few conversations unresolved because I still need some related guidance.
Quick question @patrick-kidger:
Should I make AbstractFosterLangevinSRK public and add it under Abstract Solvers in the docs? I haven't done it so far, because, unlike RK and SRK, where the user just has to specify a tableau, making a custom child of AbstractFosterLangevinSRK is much more involved, so I doubt users who are just using a packaged version of Diffrax would do that. WDYT?
I think it probably should be public + in the abstract solvers. I agree that writing your own here is incredibly niche, but I think it's useful for inquisitive users to be able to poke at such things.
So I added the check that drift and diffusion have the same arguments in AbstractFosterLangevinSRK.init and also added a short test for this in test_underdamped_langevin.py.
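The check is along these lines (a simplified sketch; the gamma/u field names on the terms are assumptions for illustration):

```python
import equinox as eqx

def _check_matching_args(drift, diffusion):
    # Fail loudly if the drift and diffusion terms were constructed
    # with different Langevin parameters.
    same_gamma = eqx.tree_equal(drift.gamma, diffusion.gamma)
    same_u = eqx.tree_equal(drift.u, diffusion.u)
    if same_gamma is not True or same_u is not True:
        raise ValueError(
            "The drift and diffusion terms must share the same gamma and u."
        )
```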
I will now do the scan trick.
I think I have now addressed everything you suggested, including the scan trick in both QUICSORT and ShOULD. I'll do another quick check and then pass it back to you.
I went through the code and the docs again, and I think I have now addressed everything you mentioned. Also, please take a look at this comment here and let me know if I should revert it to how it was before. Other than that, there are no major changes.
Also please take a look at the conversations I left unresolved, namely this and this. Thanks!
Sorry, I left a tiny problem in the test I added, I fixed it now.
Aaaaand... merged! Great job getting this one done, I'm really happy to have it in! :)
Thanks for bearing with me and for taking the time, Patrick!! I really appreciate it!