diffrax

Add KL divergence terms for Latent SDEs

Open • lockwo opened this issue 1 year ago • 9 comments

Addresses https://github.com/patrick-kidger/diffrax/issues/401. Revives https://github.com/patrick-kidger/diffrax/pull/104. Starting from that PR, I made the minimal changes needed to bring it up to date with the current version. For example, the KL terms take callables instead of ODE terms, since we can't wrap these as a term's .vf: _broadcast_and_upcast requires aug_y and drift(aug_y) to have the same shape, but they don't.
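For context on what the augmented state carries: in a latent SDE with posterior drift f, prior drift h and a shared diffusion g, the KL divergence between the two path measures accumulates at rate 0.5 * ||u||^2, where u solves g u = f - h. A minimal JAX sketch, assuming a square, invertible diffusion matrix; kl_rate is a hypothetical helper, not part of diffrax or this PR.

```python
import jax.numpy as jnp


def kl_rate(f, h, g):
    """Hypothetical helper: instantaneous KL rate 0.5 * ||u||^2 with g @ u = f - h.

    f: posterior drift evaluated at (t, y), shape (d,)
    h: prior drift evaluated at (t, y), shape (d,)
    g: diffusion matrix evaluated at (t, y), shape (d, d); assumed invertible
    """
    # Solve the linear system rather than forming g^{-1} explicitly.
    u = jnp.linalg.solve(g, f - h)
    return 0.5 * jnp.sum(u**2)
```

For diagonal noise the solve reduces to an elementwise division, u = (f - h) / g.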

lockwo, Apr 17 '24 21:04

Before going further (there is a lot I still plan to improve and polish), I wanted to get your thoughts on the general approach: representing the KL divergence as terms, and exposing to the user a function that converts their problem into one that also tracks the KL. An alternative would be something like torchsde, where it's part of the integration method, i.e. the user flags it at integration time.
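The two options differ mainly in where the augmentation happens. A minimal sketch of the convert-the-problem option, assuming diagonal noise (diffusion(t, y) returns an array shaped like y, so u = (f - h) / g elementwise); make_kl_augmented_sde is a hypothetical name, not the PR's API. The integration-time alternative would build the same augmented system inside the solve call instead (torchsde exposes this as the logqp flag of sdeint).

```python
import jax.numpy as jnp


def make_kl_augmented_sde(posterior_drift, prior_drift, diffusion):
    """Hypothetical helper: lift (f, h, g) to an SDE on the PyTree state (y, kl).

    Assumes diagonal noise: diffusion(t, y) returns an array shaped like y.
    """

    def aug_drift(t, aug_y):
        y, _ = aug_y
        f = posterior_drift(t, y)
        # Diagonal noise, so g u = f - h is solved elementwise.
        u = (f - prior_drift(t, y)) / diffusion(t, y)
        # y follows the posterior drift; kl grows at rate 0.5 * ||u||^2.
        return f, 0.5 * jnp.sum(u**2)

    def aug_diffusion(t, aug_y):
        y, _ = aug_y
        # The KL component is driven by dt only, so its diffusion is zero.
        return diffusion(t, y), jnp.zeros(())

    return aug_drift, aug_diffusion
```

Because the result is an ordinary (drift, diffusion) pair on a (y, kl) PyTree, any solver that handles PyTree states could integrate it unchanged; the trade-off is that the user sees the augmented state.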

lockwo, Apr 17 '24 21:04

On the topic of Lineax: indeed, this should definitely make handling PyTrees much easier.

patrick-kidger, Apr 21 '24 20:04

I think your idea makes a lot of sense, and I made a fair amount of progress on the solver wrapper approach.

lockwo, Apr 24 '24 01:04

Ok, I polished things up. I went with a hybrid approach: the user specifies the SDEs as you described, then wraps a solver, and everything works smoothly. I did create internal terms, since that was the best way I could think of to get an arbitrary solver to integrate through the KL computation, but they are completely hidden from the user. I also added an example (it can be modified to add more text, or to remove pmap, although I like having an example that uses distributed computation, especially since it's painfully slow without it), added a test, and updated the docs. Taking it out of draft now, since it's a real PR.
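A minimal sketch of integrating the augmented system and averaging the KL over independent Brownian samples. It assumes diagonal noise and uses a hand-rolled fixed-step Euler-Maruyama scheme in place of a diffrax solver; solve_with_kl is a hypothetical name. jax.vmap is mapped over PRNG keys so this runs on a single device; jax.pmap could map over devices instead, as the PR's example does, with the leading axis sized to the device count.

```python
import jax
import jax.numpy as jnp


def solve_with_kl(posterior_drift, prior_drift, diffusion, y0, t0, t1, num_steps, key):
    """Hypothetical sketch: fixed-step Euler-Maruyama on the augmented state (y, kl).

    Assumes diagonal noise: diffusion(t, y) has the same shape as y.
    Returns y(t1) and the KL accumulated along the sampled path.
    """
    dt = (t1 - t0) / num_steps
    # One Brownian increment per step, each distributed as Normal(0, dt).
    dws = jnp.sqrt(dt) * jax.random.normal(key, (num_steps,) + y0.shape)

    def step(carry, inputs):
        y, kl = carry
        i, dw = inputs
        t = t0 + i * dt
        f = posterior_drift(t, y)
        g = diffusion(t, y)
        u = (f - prior_drift(t, y)) / g
        # The KL state has no noise term, only a dt contribution.
        kl = kl + 0.5 * jnp.sum(u**2) * dt
        y = y + f * dt + g * dw
        return (y, kl), None

    init = (y0, jnp.zeros(()))
    (y1, kl1), _ = jax.lax.scan(step, init, (jnp.arange(num_steps), dws))
    return y1, kl1


# Usage: 8 Monte Carlo samples, one PRNG key each.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
ys, kls = jax.vmap(
    lambda k: solve_with_kl(
        lambda t, y: -y + 1.0,          # posterior drift
        lambda t, y: -y,                # prior drift
        lambda t, y: jnp.ones_like(y),  # diffusion
        jnp.zeros(2), 0.0, 1.0, 100, k,
    )
)(keys)
mean_kl = kls.mean()
```

In this toy setup f - h = 1 in each of the two dimensions with g = 1, so the KL rate is constant at 1.0 and every sample accumulates a KL of about 1.0 over [0, 1].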

lockwo, Apr 27 '24 06:04

This is a very cool feature/example! It looks like one needs to specify levy_area=diffrax.BrownianIncrement in diffrax.UnsafeBrownianPath.

frankschae, May 08 '24 21:05

Thanks @frankschae , good catch!

lockwo, May 08 '24 23:05