Bijectors.jl

Improve `with_logabsdet_jacobian` performance for `SimplexBijector`

Open · torfjelde opened this issue 1 year ago · 1 comment

The PR made me wonder whether it would be possible to improve the performance of `with_logabsdet_jacobian` for `SimplexBijector` by not performing `transform` and `logabsdetjac` separately when both are requested. This doesn't block this PR, though it might lead to more code duplication.

Originally posted by @devmotion in https://github.com/TuringLang/Bijectors.jl/pull/302#pullrequestreview-2000888436

torfjelde · Apr 15 '24 12:04
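For reference, here is a short derivation showing why fusing the two computations pays off. It assumes the Stan-style centered stick-breaking parameterization discussed in this thread, mapping $y \in \mathbb{R}^{K-1}$ to a $K$-simplex $x$. The offsets $\log(K-k)$ follow Stan's centering; whether `SimplexBijector` uses exactly these offsets is an assumption here. Because $x_k$ depends only on $y_1,\dots,y_k$, the Jacobian of $y \mapsto (x_1,\dots,x_{K-1})$ is lower triangular. Its determinant is therefore the product of the diagonal entries $\partial x_k/\partial y_k = r_k z_k (1 - z_k)$:

$$
\begin{aligned}
z_k &= \operatorname{logistic}\bigl(y_k - \log(K-k)\bigr), \qquad k = 1,\dots,K-1,\\
r_1 &= 1, \qquad x_k = r_k z_k, \qquad r_{k+1} = r_k (1 - z_k), \qquad x_K = r_K,\\
\log\lvert\det J\rvert &= \sum_{k=1}^{K-1} \log\bigl(r_k z_k (1 - z_k)\bigr)
= \sum_{k=1}^{K-1} \log\frac{x_k\, r_{k+1}}{r_k}
= \sum_{k=1}^{K} \log x_k .
\end{aligned}
$$

The sum telescopes because $r_1 = 1$ and $r_K = x_K$. This recovers the statement that the Jacobian determinant is the product of the output elements. Both the transform and its log-Jacobian are built from the same $z_k$ and $r_k$, so computing them in separate passes repeats the whole stick-breaking loop.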

Yes, it's absolutely possible. The simplex transform we use (which Stan also uses) is the classic stick-breaking transform, called the inverse multiplicative log-ratio transform in the compositional data analysis literature. It is shifted in the unconstrained space so that, for every symmetric Dirichlet distribution, the mode of the corresponding distribution on the unconstrained space is at the origin. Its Jacobian determinant is simply the product of the elements of the output simplex vector, but for numerical stability it's better to perform the entire transform on the log scale. There's a Stan implementation of this here: https://github.com/mjhajharia/transforms/blob/1207723c4c4208116f80204fe35f1631aaa30f6a/transforms/simplex/StickbreakingLogistic.stan#L2-L17 .

sethaxen · May 24 '24 16:05
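A minimal Julia sketch of the fused, log-scale approach described in the reply, using only Base. It assumes the Stan-style parameterization (unconstrained $y \in \mathbb{R}^{K-1}$, offset `log(K - k)`) and implements the unconstrained-to-simplex direction, returning the simplex point together with log|det J| from a single loop. The names `stickbreaking_with_logjac` and `log_logistic` are hypothetical and not part of Bijectors.jl. This is not Bijectors.jl's actual `SimplexBijector` code, which may differ in dimension conventions and edge-case handling.

```julia
# Hypothetical helpers, not part of Bijectors.jl. Assumes Stan-style centered
# stick-breaking: y ∈ ℝ^(K-1) maps to a point x on the K-simplex.

# Numerically stable log(logistic(t)); avoids overflow of exp for large |t|.
log_logistic(t) = t >= 0 ? -log1p(exp(-t)) : t - log1p(exp(t))

# Returns (x, logjac), where logjac = log|det J| of y ↦ (x_1, …, x_{K-1}),
# which equals sum(log.(x)). Everything is accumulated on the log scale.
function stickbreaking_with_logjac(y::AbstractVector{<:Real})
    K = length(y) + 1
    T = float(eltype(y))
    x = Vector{T}(undef, K)
    logr = zero(T)    # log of the remaining stick length, log(1) = 0 initially
    logjac = zero(T)  # running sum of log x_k
    for k in 1:(K - 1)
        t = y[k] - log(T(K - k))       # centering offset: y = 0 maps to the centroid
        logx = logr + log_logistic(t)  # log x_k = log r_k + log z_k
        logr += log_logistic(-t)       # log r_{k+1} = log r_k + log(1 - z_k)
        x[k] = exp(logx)
        logjac += logx
    end
    x[K] = exp(logr)                   # the last element is the leftover stick
    logjac += logr
    return x, logjac
end
```

Usage: `x, logjac = stickbreaking_with_logjac([0.3, -1.2, 2.0])` gives a vector whose elements sum to 1 up to rounding, plus its log-Jacobian. For the simplex-to-unconstrained direction, the log-Jacobian at the corresponding point is `-logjac`. Summing `log x_k` instead of multiplying the `x_k` keeps `logjac` finite even when some elements are small enough that their product would underflow to zero.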