Implement `RelaxedOneHotCategoricalStraightThrough`
Following the discussion in #548, and while we wait for discrete latent variables, it would be nice to have a Gumbel-Softmax categorical approximation as featured in Pyro. I didn't realize this was the name Pyro gives to Gumbel-Softmax, but hopefully replicating it will be straightforward?
numpyro (i.e. JAX) seems uniquely suited for problems involving large discrete structures (e.g. networks), so the ability to recover latent discrete variables (or their approximations) would be fantastic!
Pyro link for the original implementation.
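For reference, here is a minimal sketch of the straight-through Gumbel-Softmax trick in plain JAX. It is not Pyro's implementation and not an existing numpyro API; the function name `gumbel_softmax_straight_through` and its `temperature` argument are hypothetical and only illustrate the idea: the forward pass returns a hard one-hot sample, while gradients flow through the relaxed softmax sample.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_straight_through(key, logits, temperature=1.0):
    """Hypothetical helper: one-hot forward value, relaxed (softmax) gradient."""
    # Perturb the logits with Gumbel(0, 1) noise (the Gumbel-max trick).
    gumbels = jax.random.gumbel(key, logits.shape)
    # Relaxed sample: a point on the simplex, differentiable w.r.t. logits.
    soft = jax.nn.softmax((logits + gumbels) / temperature, axis=-1)
    # Hard sample: one-hot encoding of the arg max of the relaxed sample.
    hard = jax.nn.one_hot(
        jnp.argmax(soft, axis=-1), logits.shape[-1], dtype=soft.dtype
    )
    # Straight-through estimator: the forward value equals `hard` (up to
    # floating-point rounding), but the gradient is that of `soft`.
    return soft + jax.lax.stop_gradient(hard - soft)
```

Something like `gumbel_softmax_straight_through(jax.random.PRNGKey(0), jnp.zeros((4, 3)))` would give four approximately one-hot rows, and `jax.grad` of a scalar function of that output with respect to the logits is generally nonzero. A proper distribution class would also need a `log_prob`, which this sketch leaves out.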
If no one else picks this up before then (and no hard feelings if they do), I will take a stab at it starting next weekend.
@daydreamt - Please go ahead, you will be assured a thorough and timely review. If you have any questions around the codebase, please let us know.