Best way to approach KL divergence
I was wondering what the best way to recreate the KL divergence computation from torchsde would be in diffrax. I have a hacked-together version of https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py, but it seems like there might be a more general approach. I didn't see any open issues about this, but there is an older stale PR (https://github.com/patrick-kidger/diffrax/pull/104); I might be able to revive that if it would be of interest to the package.
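For reference, here's a minimal sketch of the approach from the torchsde example ported to diffrax, under the usual latent-SDE setup: augment the state with a scalar channel that accumulates the KL integrand `0.5 * |u|^2` along the path, where `u = (f - h) / g` is the Girsanov drift correction for diagonal noise. The drifts `f`, `h` and diffusion `g` below are stand-ins, not a real model:

```python
import jax
import jax.numpy as jnp
import diffrax

dim = 2  # state dimension (hypothetical)

def f(t, y, args):   # approximate-posterior drift (stand-in)
    return -y + jnp.sin(t)

def h(t, y, args):   # prior drift (stand-in)
    return -y

def g(t, y, args):   # shared diagonal diffusion (stand-in)
    return 0.5 * jnp.ones_like(y)

def aug_drift(t, y_aug, args):
    y = y_aug[:-1]  # last entry accumulates the KL integral
    u = (f(t, y, args) - h(t, y, args)) / g(t, y, args)  # Girsanov correction
    return jnp.append(f(t, y, args), 0.5 * jnp.sum(u**2))

def aug_diffusion(t, y_aug, args):
    y = y_aug[:-1]
    # (dim+1, dim) matrix: diagonal noise on the state, none on the KL channel.
    return jnp.vstack([jnp.diag(g(t, y, args)), jnp.zeros((1, dim))])

key = jax.random.PRNGKey(0)
bm = diffrax.VirtualBrownianTree(0.0, 1.0, tol=1e-3, shape=(dim,), key=key)
terms = diffrax.MultiTerm(
    diffrax.ODETerm(aug_drift),
    diffrax.ControlTerm(aug_diffusion, bm),
)
sol = diffrax.diffeqsolve(
    terms, diffrax.Euler(), 0.0, 1.0, dt0=0.01,
    y0=jnp.append(jnp.zeros(dim), 0.0),
)
kl = sol.ys[-1, -1]  # pathwise estimate of KL(posterior || prior)
```

In training you'd average this pathwise estimate over samples and add it to the reconstruction loss, as in the torchsde example.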
I think it should still be possible in a similar way to #104! I'm undecided on whether this makes sense to include as a function in Diffrax itself, but if you put some functionality together then I'd be interested to see it.
I think probably the only thing I'd remark is that we now have Lineax, and might be able to use a cleaner + more structured approach to solving the linear system -- using Lineax to exploit the diagonal structure if present, etc.
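To illustrate what I mean (a hypothetical `drift_correction` helper, not existing Diffrax API): rather than hardcoding an elementwise division, represent `g` as a Lineax operator and solve `g u = f - h`, letting the solver dispatch on the operator's structure:

```python
import jax.numpy as jnp
import lineax as lx

def drift_correction(g_op: lx.AbstractLinearOperator, f_minus_h):
    # Solve g u = (f - h). AutoLinearSolver picks a solver from the
    # operator's structure, so a DiagonalLinearOperator costs O(d).
    sol = lx.linear_solve(g_op, f_minus_h, lx.AutoLinearSolver(well_posed=True))
    return sol.value

# Diagonal noise: the structure is exploited automatically.
u_diag = drift_correction(
    lx.DiagonalLinearOperator(jnp.array([0.5, 0.5])),
    jnp.array([1.0, -1.0]),
)

# General square noise matrix: falls back to a dense solve (e.g. LU).
u_dense = drift_correction(
    lx.MatrixLinearOperator(jnp.eye(2)),
    jnp.array([1.0, -1.0]),
)
```

The same code then handles diagonal, general, and other structured diffusions without special-casing each one.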
It should definitely be possible to clean it up a lot, but any preliminary thoughts on the approach would be much appreciated: https://github.com/patrick-kidger/diffrax/pull/402