mlx-examples
Normalizing flow example
Implements a normalizing flow generative model, specifically the basic Real NVP, for density estimation and sampling.
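For readers new to Real NVP, the standard masked affine coupling layer and its density formula can be written as below. This is the textbook formulation, not something taken from this PR's code; the symbols $b$ (binary mask), $s$ and $t$ (scale and shift networks), and $f$ (the full flow) are notation assumed here for illustration.

```latex
% One masked affine coupling layer (b is a binary mask, \odot is elementwise product)
y = b \odot x + (1 - b) \odot \bigl( x \odot \exp\bigl(s(b \odot x)\bigr) + t(b \odot x) \bigr)

% Its log-determinant only involves the transformed (unmasked) coordinates
\log \left| \det \frac{\partial y}{\partial x} \right| = \sum_j (1 - b_j)\, s(b \odot x)_j

% Density estimation via change of variables, with base density p_Z and flow z = f(x)
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|
```

Sampling runs the same layers in reverse: draw $z \sim p_Z$ and apply $f^{-1}$, which is cheap because the masked coordinates pass through unchanged and so $s$ and $t$ can be recomputed from them.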
@smsharma This looks amazing! We would love to add it as an example :). I need a little time to review etc, just want to let you know it's not getting ignored!
Thanks, @awni, no hurry obviously!
Closing this as the examples repo seems primarily focused on LLM-style examples.
@smsharma believe it or not, I was still hoping to include this! I think it is a really clean NVP example and we definitely want to diversify beyond just LLMs.
If you're OK with that, I'll get to it today. Would you reopen? Really sorry for the delay here; you can always ping us in case we aren't on it.
@awni Ah, happy to reopen! Just thought I may have misunderstood the scope of the repo. Thanks!
Thanks! The LM stuff is keeping us busy for sure. But you did not misunderstand at all. We'd like to keep this for simple and diverse examples. (The LM package stuff might get moved out into its own thing in the future; that's the part that is starting to be "out of scope" 😄)
@smsharma this looks great!! I pushed a few mostly cosmetic changes to your branch.
It runs and it works well :).
One thing I wanted to run by you:
As you say in the README, the level of class structure is a bit more than necessary. I'm wondering if we should simplify it?
My thought is that the base classes add noise and make it harder to understand the example without adding much in terms of simplifying future extensions. What do you think about simplifying the class structure a bit?
Also, I moved the example from `flow` to `real_nvp` to be a bit more descriptive and help with discovery.
Thanks a lot @awni, everything looks good!
Re: class structure, agreed it doesn't add very much, in particular for `distributions.py`, so I've gone ahead and simplified it by removing the base class.
For `bijectors.py`, since there's a composition of two bijectors (a masked coupling that takes in an affine bijector), I've kept the base class as a convenience, but if you have thoughts on simplifying the structure, I'm all for it!
Lastly, and minor: I've simplified the `MaskedCoupling` class a bit by defining a separate function for the common code used in the forward and reverse transforms.
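To make the structure described above concrete, here is a minimal sketch of a masked coupling that wraps an affine bijector and routes both directions through one shared helper. It is written in plain Python rather than MLX so it runs anywhere, and it is not the code from this PR: the method names (`forward_and_log_det`, `inverse_and_log_det`), the `_apply` helper, and `toy_conditioner` are all hypothetical, chosen only for illustration.

```python
import math


class AffineBijector:
    """Elementwise affine map y = x * exp(log_scale) + shift (illustrative)."""

    def __init__(self, shift, log_scale):
        self.shift = shift
        self.log_scale = log_scale

    def forward_and_log_det(self, x):
        y = [xi * math.exp(s) + t
             for xi, s, t in zip(x, self.log_scale, self.shift)]
        return y, list(self.log_scale)  # per-element log-determinants

    def inverse_and_log_det(self, y):
        x = [(yi - t) * math.exp(-s)
             for yi, s, t in zip(y, self.log_scale, self.shift)]
        return x, [-s for s in self.log_scale]


class MaskedCoupling:
    """Coupling layer: entries where mask is True pass through unchanged."""

    def __init__(self, mask, conditioner):
        self.mask = mask
        self.conditioner = conditioner  # maps masked input -> (shift, log_scale)

    def _apply(self, x, method_name):
        # Shared by both directions: the conditioner only sees the entries
        # that pass through unchanged, so it yields the same parameters
        # whether we are going forward or in reverse.
        x_cond = [xi if keep else 0.0 for xi, keep in zip(x, self.mask)]
        shift, log_scale = self.conditioner(x_cond)
        inner = AffineBijector(shift, log_scale)
        out, log_dets = getattr(inner, method_name)(x)
        # Keep the masked entries as-is and drop their log-det contribution.
        out = [xi if keep else oi
               for xi, oi, keep in zip(x, out, self.mask)]
        log_det = sum(0.0 if keep else ld
                      for ld, keep in zip(log_dets, self.mask))
        return out, log_det

    def forward_and_log_det(self, x):
        return self._apply(x, "forward_and_log_det")

    def inverse_and_log_det(self, y):
        return self._apply(y, "inverse_and_log_det")


def toy_conditioner(x_cond):
    # Hypothetical stand-in for a neural network: any function of the
    # masked input works, since it never needs to be inverted.
    total = sum(x_cond)
    n = len(x_cond)
    return [0.5 * total] * n, [math.tanh(total)] * n


layer = MaskedCoupling(mask=[True, False], conditioner=toy_conditioner)
x = [0.3, -1.2]
y, log_det = layer.forward_and_log_det(x)
x_back, inv_log_det = layer.inverse_and_log_det(y)
```

With this layout the forward and reverse methods are one-liners, and the only direction-specific logic lives in the inner affine bijector. In the toy run, `x_back` recovers `x` up to floating-point error and the two log-determinants cancel.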
Thank you @awni! Been fun playing around with MLX, and I think it could be nice for general Bayesian inference in the scientific context, for which this was a little initial test case.