
Implement `RelaxedOneHotCategoricalStraightThrough`

Open rtbs-dev opened this issue 4 years ago • 3 comments

Following the discussion in #548, and while we wait for support for discrete latent variables, it would be nice to have a Gumbel-Softmax categorical approximation like the one featured in Pyro. I didn't realize this was the name Pyro gave to its Gumbel-Softmax distribution, but hopefully replicating it will be straightforward?
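As a rough illustration of the sampling behaviour being requested, here is a minimal JAX sketch of the straight-through Gumbel-Softmax trick. The function name `relaxed_one_hot_straight_through` is hypothetical and is not part of NumPyro's API. The sketch covers only reparameterized sampling, not `log_prob` or wrapping it as a NumPyro `Distribution`.

```python
import jax
import jax.numpy as jnp


def relaxed_one_hot_straight_through(key, logits, temperature):
    """Hypothetical sketch of a straight-through Gumbel-Softmax sample.

    Forward pass: a hard one-hot vector.
    Backward pass: gradients of the relaxed (softmax) sample.
    """
    # Gumbel(0, 1) noise, one draw per category
    gumbels = jax.random.gumbel(key, shape=logits.shape)
    # Relaxed (Concrete / Gumbel-Softmax) sample on the probability simplex
    soft = jax.nn.softmax((logits + gumbels) / temperature, axis=-1)
    # Hard one-hot vector at the argmax of the relaxed sample
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), logits.shape[-1], dtype=soft.dtype)
    # Straight-through: the value equals `hard`, but gradients flow through `soft`
    return soft + jax.lax.stop_gradient(hard - soft)


key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
sample = relaxed_one_hot_straight_through(key, logits, temperature=0.5)
```

The `stop_gradient` trick makes the forward value a (numerically) one-hot vector while still backpropagating through the softmax relaxation. Lowering `temperature` brings the relaxed sample closer to one-hot, at the cost of higher-variance gradients.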

NumPyro (i.e., JAX) seems uniquely suited to problems involving large discrete structures (e.g., networks), so the ability to recover latent discrete variables (or their approximations) would be fantastic!

rtbs-dev · Mar 13 '20 13:03

Pyro link for the original implementation.

rtbs-dev · Mar 13 '20 13:03

If no one else commits this before then (and no hard feelings if someone does), I will give this a stab starting next weekend.

daydreamt · Mar 22 '20 16:03

@daydreamt - Please go ahead, you will be assured a thorough and timely review. If you have any questions around the codebase, please let us know.

neerajprad · Mar 23 '20 17:03