
Support for Empirical distributions

Open • AdrienCorenflos opened this issue 5 years ago • 4 comments

Hi,

I have started implementing an Empirical distribution by porting the Pyro one. I'm unsure how to do the log probability: I think the Pyro implementation may be problematic on memory-limited GPUs, because it materializes an N * N tensor, where N is the number of Diracs in the distribution.

Does anyone have a better idea?

Code as it stands for now: https://colab.research.google.com/drive/171oY6j290HSRI2NR9rDmVDBfJB2UepKG?usp=sharing

AdrienCorenflos, Jul 15 '20 13:07
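To make the memory issue concrete, here is a minimal JAX sketch of an all-pairs-equality log-probability for a weighted empirical distribution. It is not the code from the linked notebook and not a line-by-line port of Pyro's implementation; the function name `empirical_log_prob_all_pairs` and the `(N, d)` samples / `(M, d)` query layout are assumptions for illustration. Broadcasting the queries against the samples builds an `(M, N, d)` boolean tensor, which becomes N * N * d when the distribution is evaluated at its own Diracs.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def empirical_log_prob_all_pairs(samples, log_weights, values):
    """Log-probability of `values` under a weighted empirical distribution.

    Hypothetical layout (an assumption, not Pyro's or NumPyro's API):
      samples:     (N, d) Dirac locations
      log_weights: (N,) unnormalized log-weights, one per Dirac
      values:      (M, d) query points
    Returns an (M,) array; -inf where a value matches no Dirac.
    """
    # All-pairs equality: (M, 1, d) == (1, N, d) -> (M, N, d) boolean tensor.
    # With M == N this is the N * N (* d) intermediate that costs GPU memory.
    eq = values[:, None, :] == samples[None, :, :]
    # A query matches a Dirac only if every coordinate is equal.
    mask = jnp.all(eq, axis=-1)  # (M, N)
    # Keep the log-weights of matching Diracs, -inf elsewhere.
    matched = jnp.where(mask, log_weights[None, :], -jnp.inf)
    # Sum matching weights in log-space and normalize by the total weight.
    return logsumexp(matched, axis=-1) - logsumexp(log_weights)
```

For example, with Diracs at [0, 0] (weight 1), [1, 1] (weight 1) and [0, 0] again (weight 2), the query [0, 0] gets probability 0.75, [1, 1] gets 0.25, and [2, 2] gets 0.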

I think the cleanest way to deal with the O(N * N * d) complexity of all-pairs equality is to implement a matmul-like operation that (optionally) uses less memory under the hood. Naive matmul is also O(N * N * d), but it is implemented via cheap kernels that perform the reduction over one of those N axes locally, so the space complexity is reduced to O(N * d) and the time complexity is sub-cubic.

fritzo, Jul 15 '20 15:07
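The core of the suggestion above is to reduce over one of the N axes locally instead of materializing the full tensor. A minimal JAX sketch of that idea, assuming the same hypothetical `(N, d)` / `(M, d)` layout: `jax.lax.map` evaluates the query points one at a time, so each step only compares one value against the N Diracs and immediately reduces the result with `logsumexp`. This is a plain sequential loop, not the fused matmul-style kernel described in the comment; the function name `empirical_log_prob_streamed` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def empirical_log_prob_streamed(samples, log_weights, values):
    """Empirical log-probability without an (M, N, d) intermediate.

    Hypothetical layout: samples (N, d), log_weights (N,), values (M, d).
    Each step of the loop only holds an (N, d) comparison, so the peak
    memory for the equality test is roughly O(N * d) instead of O(M * N * d).
    """
    log_norm = logsumexp(log_weights)

    def one_value(v):
        # Compare a single query point against all N Diracs: (N, d) -> (N,).
        mask = jnp.all(samples == v, axis=-1)
        matched = jnp.where(mask, log_weights, -jnp.inf)
        # Reduce over the N axis right away; nothing of size M * N is stored.
        return logsumexp(matched) - log_norm

    # lax.map applies one_value sequentially over the leading axis of values.
    return jax.lax.map(one_value, values)
```

The trade-off is parallelism for memory: the all-pairs version runs as one vectorized operation, while this one loops over the M queries.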

There used to be an Empirical distribution. Any particular reason it was removed? Thanks!

Update: I just found it under contrib.tfb.distributions, sorry for the false alarm!

ma-sadeghi, Dec 25 '23 13:12

I think you can also use TFP distributions directly in NumPyro models. We might want to remove contrib.tfb.distributions.Empirical in the future.

fehiepsi, Dec 26 '23 16:12

I tried it, but the API seems a bit different (seed vs. rng_key, etc.).

ma-sadeghi, Dec 26 '23 19:12
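On the seed vs. rng_key point: as far as I know, NumPyro distributions take an explicit JAX PRNG key as the first positional argument of `sample(key, sample_shape)`, while TFP distributions take it as a `seed=` keyword. A minimal pure-JAX sketch of sampling a weighted empirical distribution in the NumPyro-style convention, assuming the hypothetical `(N, d)` samples / `(N,)` log-weights layout; `empirical_sample` is an illustrative name, not an existing NumPyro or TFP function.

```python
import jax
import jax.numpy as jnp


def empirical_sample(key, samples, log_weights, sample_shape=()):
    """Draw from a weighted empirical distribution with an explicit PRNG key.

    key:          a jax.random PRNG key, passed positionally (NumPyro style)
    samples:      (N, d) Dirac locations (hypothetical layout)
    log_weights:  (N,) unnormalized log-weights
    sample_shape: leading shape of the output
    Returns an array of shape sample_shape + (d,).
    """
    # Pick Dirac indices with probability proportional to exp(log_weights).
    idx = jax.random.categorical(key, log_weights, shape=sample_shape)
    # Gather the selected Dirac locations.
    return samples[idx]
```

A Dirac with log-weight -inf is never selected, so `empirical_sample(jax.random.PRNGKey(0), samples, jnp.log(jnp.array([0., 1.])), (5,))` always returns copies of the second Dirac.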