
Q: Normal Distribution Compatibility with Reparameterization Trick?

Open · gil2rok opened this issue · 2 comments

I am interested in using Optax's stochastic gradient estimators and control variates with FlowJAX. In particular, I would like to use the reparameterization gradient (also known as the pathwise estimator).

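The reparameterization gradient relies on the "reparameterization trick" to estimate the gradient of an expectation with respect to the distribution's parameters. For a normally distributed variable $x \sim \mathcal{N}(\mu, \sigma^2)$ with mean $\mu$ and scale (standard deviation) $\sigma$, this is done by writing $x = \mu + \sigma z$ with $z \sim \mathcal{N}(0, 1)$. Because the distribution of $z$ no longer depends on the parameters, the gradient can be moved inside the expectation: $\nabla_{\mu,\sigma}\, \mathbb{E}_{x}[f(x)] = \mathbb{E}_{z}[\nabla_{\mu,\sigma} f(\mu + \sigma z)]$. For more details see: https://gregorygundersen.com/blog/2018/04/29/reparameterization/.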
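To make the trick concrete, here is a minimal pure-JAX sketch (no FlowJAX involved) that estimates $\nabla_{\mu,\sigma}\, \mathbb{E}[x^2]$ by sampling $z$ and differentiating through $x = \mu + \sigma z$ with jax.grad. The test function $f(x) = x^2$, the sample size, and the helper names are illustrative choices. Since $\mathbb{E}[x^2] = \mu^2 + \sigma^2$, the exact gradient is $(2\mu, 2\sigma)$, which the Monte Carlo estimate should approximate.

```python
import jax
import jax.numpy as jnp


def reparam_sample(key, mu, sigma, n):
    # Draw z ~ N(0, 1); this randomness does not depend on mu or sigma
    z = jax.random.normal(key, (n,))
    # Deterministic, differentiable transform: x = mu + sigma * z
    return mu + sigma * z


def mc_objective(params, key, n=200_000):
    mu, sigma = params
    x = reparam_sample(key, mu, sigma, n)
    # Monte Carlo estimate of E[f(x)] with f(x) = x**2
    return jnp.mean(x ** 2)


key = jax.random.PRNGKey(0)
params = (jnp.array(1.5), jnp.array(0.5))
# jax.grad differentiates through the sampling path (the pathwise estimator)
grad_mu, grad_sigma = jax.grad(mc_objective)(params, key)
# Exact values: 2 * mu = 3.0 and 2 * sigma = 1.0
print(grad_mu, grad_sigma)
```

The printed values should be close to 3.0 and 1.0, up to Monte Carlo error that shrinks as `n` grows.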

Are the Normal and MultivariateNormal distributions in FlowJAX compatible with the reparameterization trick by default when using jax.grad? I believe the answer is yes, because both are built by taking the StandardNormal distribution (the $z$ above) and transforming it with a bijection, which I think is affine.
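The structure described above (a standard normal base pushed through an affine map) can be sketched in pure JAX for the multivariate case as $x = \mathrm{loc} + L z$ with $z \sim \mathcal{N}(0, I)$ and covariance $L L^\top$. This is only an illustration of why such a construction is differentiable; it does not reproduce FlowJAX's actual MultivariateNormal internals (for example, how it stores or constrains its parameters), and the helper names and the Cholesky parameterization here are assumptions. With $f(x) = \lVert x \rVert^2$ we have $\mathbb{E}[f(x)] = \lVert \mathrm{loc} \rVert^2 + \lVert L \rVert_F^2$, so the exact gradients are $2\,\mathrm{loc}$ and $2L$.

```python
import jax
import jax.numpy as jnp


def mvn_reparam_sample(key, loc, chol, n):
    # z ~ N(0, I): the parameter-free standard normal base distribution
    z = jax.random.normal(key, (n, loc.shape[0]))
    # Affine bijection x = loc + L z, applied to each row; covariance is L L^T
    return loc + z @ chol.T


def objective(params, key, n=200_000):
    loc, chol = params
    chol = jnp.tril(chol)  # keep only the lower triangle (Cholesky factor)
    x = mvn_reparam_sample(key, loc, chol, n)
    # Monte Carlo estimate of E[||x||^2] = ||loc||^2 + trace(L L^T)
    return jnp.mean(jnp.sum(x ** 2, axis=-1))


key = jax.random.PRNGKey(1)
loc = jnp.array([1.0, -2.0])
chol = jnp.array([[1.0, 0.0], [0.5, 2.0]])
g_loc, g_chol = jax.grad(objective)((loc, chol), key)
# Exact gradients: 2 * loc and 2 * chol (upper triangle gets zero gradient)
print(g_loc)
print(g_chol)
```

If FlowJAX's Normal and MultivariateNormal are indeed a StandardNormal base with an affine bijection, differentiating a function of their samples with respect to the parameters should behave the same way, but that is exactly the point the question asks the maintainer to confirm.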

Could @danielward27 please confirm whether this is true? If so, it may be worth mentioning somewhere in the docs. Thank you!

gil2rok · Sep 29 '24 17:09