flowjax
Q: Normal Distribution Compatibility with Reparameterization Trick?
I am interested in using Optax's stochastic gradient estimators and control variates with FlowJAX, in particular the reparameterization gradient (also known as the pathwise estimator).
The reparameterization gradient relies on the "reparameterization trick" to compute the gradient of an expectation: each sample is written as a differentiable function of the parameters and of parameter-free noise, so the gradient can be moved inside the expectation. For a normally distributed variable $x \sim \mathcal{N}(\mu, \sigma^2)$, this means rewriting it as $x = \mu + \sigma z$ for mean $\mu$, scale $\sigma$, and $z \sim \mathcal{N}(0, 1)$. For more details see: https://gregorygundersen.com/blog/2018/04/29/reparameterization/.
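For illustration, here is a minimal sketch of the pathwise estimator in plain JAX (not FlowJAX): draw $z$ with `jax.random.normal`, form $x = \mu + \sigma z$, and let `jax.grad` differentiate a Monte Carlo estimate of $\mathbb{E}[x^2]$ straight through the sampling step. The objective $f(x) = x^2$, the name `mc_objective`, and the sample count are arbitrary choices for the example; they are picked because $\mathbb{E}[x^2] = \mu^2 + \sigma^2$, so the exact gradient $(2\mu, 2\sigma)$ is easy to check.

```python
import jax
import jax.numpy as jnp


def mc_objective(params, key, num_samples=100_000):
    """Monte Carlo estimate of E[x**2] for x ~ N(mu, sigma**2), reparameterized."""
    mu, sigma = params
    # Sample from the parameter-free base distribution z ~ N(0, 1) ...
    z = jax.random.normal(key, (num_samples,))
    # ... then push it through the affine map x = mu + sigma * z,
    # so every sample is a differentiable function of (mu, sigma).
    x = mu + sigma * z
    return jnp.mean(x**2)


# Pathwise (reparameterization) gradient: jax.grad differentiates through the samples.
pathwise_grad = jax.jit(jax.grad(mc_objective))

params = (jnp.array(1.5), jnp.array(0.5))
grads = pathwise_grad(params, jax.random.PRNGKey(0))
print(grads)  # approximately the exact gradient (2 * mu, 2 * sigma) = (3.0, 1.0)
```

Nothing special is needed for `jax.grad` here: because the randomness enters only through `z`, the estimator's gradient is an ordinary derivative of a deterministic function of the parameters.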
Are the `Normal` and `MultivariateNormal` distributions in FlowJAX compatible with the reparameterization trick by default when using `jax.grad`? I believe the answer is yes, because they both take the `StandardNormal` distribution (equivalent to $z$ above) and transform it with some (affine?) bijection.
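Assuming, as the question suggests, that a multivariate normal is obtained by pushing a standard normal sample through an affine bijection, the multivariate trick is $x = \mu + L z$ with $z \sim \mathcal{N}(0, I)$ and $L$ a lower-triangular factor of the covariance, $\Sigma = L L^\top$. The sketch below uses plain JAX rather than FlowJAX's `MultivariateNormal`, so it illustrates the idea, not that class's internals. The objective $\mathbb{E}[\lVert x \rVert^2] = \lVert \mu \rVert^2 + \operatorname{tr}(L L^\top)$ is chosen only because its exact gradient, $(2\mu, 2L)$, is easy to check; the name `mc_objective_mvn` is hypothetical.

```python
import jax
import jax.numpy as jnp


def mc_objective_mvn(params, key, num_samples=100_000):
    """Monte Carlo estimate of E[||x||^2] for x ~ N(loc, L @ L.T), reparameterized."""
    loc, chol = params
    chol = jnp.tril(chol)  # keep only the lower-triangular factor L
    # Base samples z ~ N(0, I), shape (num_samples, dim).
    z = jax.random.normal(key, (num_samples, loc.shape[0]))
    # Affine bijection applied to each sample: x = loc + L @ z (row-wise form).
    x = loc + z @ chol.T
    return jnp.mean(jnp.sum(x**2, axis=-1))


loc = jnp.array([1.0, -2.0])
chol = jnp.array([[1.0, 0.0],
                  [0.5, 0.8]])
grad_loc, grad_chol = jax.grad(mc_objective_mvn)((loc, chol), jax.random.PRNGKey(1))
print(grad_loc)   # approximately 2 * loc = [2.0, -4.0]
print(grad_chol)  # approximately 2 * chol; the upper triangle is exactly 0
```

Masking with `jnp.tril` inside the objective means the upper-triangular entries receive zero gradient, so an optimizer updating `chol` keeps it lower triangular.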
Could @danielward27 please confirm whether this is true? If so, it may be worth mentioning somewhere in the docs. Thank you!