
Best way to approach KL divergence

Open • lockwo opened this issue 1 year ago • 2 comments

I was wondering what the best way would be to recreate torchsde's KL divergence in Diffrax. I have a hacked-together version of https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py, but it seems like there might be a more general approach. I didn't see any existing issues on this, but there is an older stale PR (https://github.com/patrick-kidger/diffrax/pull/104) that I might be able to revive if this would be of interest to the package.
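A rough sketch of the torchsde-style idea, assuming diagonal noise and that the posterior and prior SDEs share the same diffusion g. The state is augmented with a running `logqp` term that accumulates 0.5·|u|² dt, where u solves g·u = f_post - f_prior (the Girsanov integrand). This is plain Python with Euler–Maruyama, not Diffrax; the function `latent_sde_kl` and its parameters are hypothetical. In Diffrax, one option would be to carry `logqp` as an extra component of the solved state. Only the path term is covered here; torchsde's example also adds a KL term for the initial distribution.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Drift and diagonal-diffusion callables with signature (t, y) -> list of floats.
VectorField = Callable[[float, Sequence[float]], Vector]


def latent_sde_kl(
    f_post: VectorField,
    f_prior: VectorField,
    g_diag: VectorField,
    y0: Sequence[float],
    t0: float,
    t1: float,
    num_steps: int,
    rng: random.Random,
) -> Tuple[Vector, float]:
    """Euler-Maruyama on the posterior SDE, plus an augmented 'logqp' channel.

    Returns the terminal state and a single-path estimate of
    KL(posterior || prior) = E[ integral of 0.5 * |u|^2 dt ], where g * u = f_post - f_prior.
    """
    dt = (t1 - t0) / num_steps
    t = t0
    y = list(y0)
    logqp = 0.0
    for _ in range(num_steps):
        fp, fq, g = f_post(t, y), f_prior(t, y), g_diag(t, y)
        # Diagonal noise: the linear system g * u = f_post - f_prior is elementwise division.
        u = [(a - b) / gi for a, b, gi in zip(fp, fq, g)]
        logqp += 0.5 * sum(ui * ui for ui in u) * dt
        # Posterior step; Brownian increments dW ~ N(0, dt) per dimension.
        dw = [rng.gauss(0.0, math.sqrt(dt)) for _ in y]
        y = [yi + fi * dt + gi * dwi for yi, fi, gi, dwi in zip(y, fp, g, dw)]
        t += dt
    return y, logqp


# Usage: OU prior, posterior drift shifted by a constant c, constant diffusion sigma.
c, sigma = 0.3, 0.5
y_T, kl = latent_sde_kl(
    f_post=lambda t, y: [-yi + c for yi in y],
    f_prior=lambda t, y: [-yi for yi in y],
    g_diag=lambda t, y: [sigma for _ in y],
    y0=[1.0, -1.0],
    t0=0.0,
    t1=1.0,
    num_steps=100,
    rng=random.Random(0),
)
# Here u = c / sigma in every dimension, so kl = 0.5 * 2 * (0.3 / 0.5)**2 * 1.0 = 0.36
# (up to floating-point rounding), independent of the sampled path.
```

Averaging `logqp` over many sampled paths gives the Monte Carlo estimate of the expectation; the constant-shift example is chosen only because its path term is deterministic and easy to check.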

lockwo • Apr 17 '24 05:04

I think it should still be possible in a similar way to #104! I'm undecided on whether this makes sense to include as a function in Diffrax itself, but if you put some functionality together then I'd be interested to see it.

The only other remark I'd make is that we now have Lineax, so we might be able to take a cleaner, more structured approach to solving the linear system, e.g. using Lineax to exploit the diagonal structure when present.
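To illustrate the point about structure: the KL integrand needs u from the linear system g·u = f_post - f_prior at every step, and a diagonal g turns that into an elementwise division instead of a general solve. The pure-Python sketch below uses hypothetical classes `DiagonalDiffusion` and `DenseDiffusion` that share a `solve` method, and assumes g is square and nonsingular in the dense case. In Lineax terms these roughly correspond to `lineax.DiagonalLinearOperator` and `lineax.MatrixLinearOperator` passed to `lineax.linear_solve`, which can choose a solver from the operator's structure. A non-square g (noise dimension different from state dimension) would need a least-squares solve, which this sketch does not cover.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class DiagonalDiffusion:
    """g(t, y) stored as its diagonal only; solving g u = r is O(n)."""

    diag: Tuple[float, ...]

    def solve(self, rhs: Sequence[float]) -> List[float]:
        return [r / d for d, r in zip(self.diag, rhs)]


@dataclass(frozen=True)
class DenseDiffusion:
    """Square, nonsingular g(t, y); solved by Gaussian elimination, O(n^3)."""

    matrix: Tuple[Tuple[float, ...], ...]

    def solve(self, rhs: Sequence[float]) -> List[float]:
        n = len(rhs)
        a = [list(row) + [r] for row, r in zip(self.matrix, rhs)]  # augmented [g | r]
        for col in range(n):
            # Partial pivoting: bring the largest remaining entry into the pivot row.
            pivot = max(range(col, n), key=lambda i: abs(a[i][col]))
            if a[pivot][col] == 0.0:
                raise ValueError("diffusion matrix is singular")
            a[col], a[pivot] = a[pivot], a[col]
            for i in range(col + 1, n):
                factor = a[i][col] / a[col][col]
                for k in range(col, n + 1):
                    a[i][k] -= factor * a[col][k]
        # Back substitution on the upper-triangular system.
        u = [0.0] * n
        for i in range(n - 1, -1, -1):
            s = sum(a[i][k] * u[k] for k in range(i + 1, n))
            u[i] = (a[i][n] - s) / a[i][i]
        return u


def kl_rate(g, f_post: Sequence[float], f_prior: Sequence[float]) -> float:
    """Integrand 0.5 * |u|^2 of the KL term, with g u = f_post - f_prior."""
    u = g.solve([a - b for a, b in zip(f_post, f_prior)])
    return 0.5 * sum(ui * ui for ui in u)


# Usage: the same diffusion stored two ways gives the same rate (about 0.36).
diag_g = DiagonalDiffusion((0.5, 0.5))
dense_g = DenseDiffusion(((0.5, 0.0), (0.0, 0.5)))
rate_diag = kl_rate(diag_g, [0.3, 0.3], [0.0, 0.0])
rate_dense = kl_rate(dense_g, [0.3, 0.3], [0.0, 0.0])
```

Keeping the structure in the type, rather than always materialising a dense matrix, is what lets the diagonal case skip the O(n³) solve at every step.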

patrick-kidger • Apr 17 '24 07:04

It should definitely be possible to clean this up a lot, but any preliminary thoughts on this approach would be much appreciated: https://github.com/patrick-kidger/diffrax/pull/402

lockwo • Apr 17 '24 21:04