KL Divergence for Latent SDEs
A continuation of https://github.com/patrick-kidger/diffrax/pull/402 with the new 0.6.0 lineax changes.
Relevant origin issue: https://github.com/patrick-kidger/diffrax/issues/401
Okay, I really like the example here!
I'm afraid this might still take a bit more iteration to get to something clean enough to merge, though -- see my comments. :)
Happy to iterate on cleaning it up. I think the biggest question is the design one (on solvers, terms, and how to represent the problem in diffrax). Once that is resolved, I can iterate quickly to get the rest in :)
Ok, I took the feedback from Andraz's Langevin PR regarding terms and incorporated it into this PR. I think it makes things simpler and more in line with the diffrax philosophy; let me know what you think. As with Langevin, there is now just a function that accepts a MultiTerm and returns a MultiTerm of private terms that can be consumed by any solver.
The reason I went with returning a single MultiTerm is that you are really only solving one SDE. The prior SDE is used to inform the KL divergence, but it isn't integrated on its own.
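To make the "one SDE, prior only informs the KL" point concrete, here is a minimal sketch in plain Python. It is deliberately not the diffrax API and not the code in this PR: `augment_with_kl` and `euler_maruyama` are hypothetical names, the state is a scalar, and the prior and posterior are assumed to share one nonzero diffusion. The idea is that the posterior state is the only thing being stepped, while the prior drift only enters through an extra drift-only channel that accumulates 0.5 * ((f_post - f_prior) / g)^2.

```python
import math
import random


def augment_with_kl(posterior_drift, prior_drift, diffusion):
    """Build drift/diffusion for the augmented state (z, kl).

    Hypothetical helper: assumes a scalar state and a shared, nonzero
    diffusion g(t, z) for both the prior and the posterior SDE.
    """

    def aug_drift(t, state):
        z, _ = state
        f_post = posterior_drift(t, z)
        # The prior drift is only evaluated here; it never steps a state.
        u = (f_post - prior_drift(t, z)) / diffusion(t, z)
        return (f_post, 0.5 * u * u)

    def aug_diffusion(t, state):
        z, _ = state
        # The KL channel has no noise: it is a plain time integral.
        return (diffusion(t, z), 0.0)

    return aug_drift, aug_diffusion


def euler_maruyama(drift, diffusion, state0, t0, t1, n_steps, rng):
    """Fixed-step Euler-Maruyama driven by one scalar Brownian motion."""
    dt = (t1 - t0) / n_steps
    state, t = tuple(state0), t0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment over dt
        f, g = drift(t, state), diffusion(t, state)
        state = tuple(s + fi * dt + gi * dw for s, fi, gi in zip(state, f, g))
        t += dt
    return state


# Toy problem: the drifts differ by a constant 1 and g = 0.5,
# so the KL integrand is 0.5 * (1 / 0.5) ** 2 = 2 at every step.
rng = random.Random(0)
drift, diffusion = augment_with_kl(
    lambda t, z: -z + 1.0,  # posterior drift
    lambda t, z: -z,        # prior drift
    lambda t, z: 0.5,       # shared diffusion
)
z1, kl = euler_maruyama(drift, diffusion, (0.0, 0.0), 0.0, 1.0, 100, rng)
```

In this toy case the accumulated `kl` comes out at approximately 2.0 regardless of the noise sample, since the integrand is constant. In general the value from one path is a single-sample estimate, and the KL divergence is its expectation over posterior paths.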