numpyro
Support for Empirical distributions
Hi,
I have started implementing an Empirical distribution by porting the Pyro one. I'm not sure how to do the log probability. I actually think Pyro's implementation might be a bit problematic on limited-memory GPUs, because it materializes an intermediate N * N tensor, where N is the number of Diracs in the distribution.
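For reference, here is a minimal JAX sketch of the all-pairs-equality approach being discussed. It is not the Pyro or NumPyro code. It assumes values and samples are both shaped `(num_points, d)` and that `log_weights` are unnormalized; `empirical_log_prob_naive` is a hypothetical name. The `(M, N, d)` broadcast comparison is where the quadratic memory comes from.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def empirical_log_prob_naive(value, samples, log_weights):
    """Hypothetical helper: log-probability of each row of `value` under
    a weighted empirical distribution over the rows of `samples`.

    value: (M, d), samples: (N, d), log_weights: (N,) unnormalized.
    """
    # Broadcast to (M, N, d) and reduce over d -> (M, N) equality mask.
    # This (M, N, d) intermediate is the memory problem for large N.
    mask = jnp.all(value[:, None, :] == samples[None, :, :], axis=-1)
    # Non-matching atoms get -inf so they contribute exp(-inf) = 0.
    masked = jnp.where(mask, log_weights[None, :], -jnp.inf)
    # Sum the weights of matching atoms, then normalize by the total weight.
    return logsumexp(masked, axis=-1) - logsumexp(log_weights)
```

A value that matches no atom gets `-inf`, which is the expected log-probability for a point outside the support.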
Does anyone have a better idea?
Code as it stands for now: https://colab.research.google.com/drive/171oY6j290HSRI2NR9rDmVDBfJB2UepKG?usp=sharing
I would think the cleanest way to deal with the O(N * N * d) complexity of all-pairs equality is to implement a matmul-like operation that uses (or optionally uses) less memory under the hood. Naive matmul is also O(N * N * d), but it is implemented with cheap kernels that perform the reduction over one of the N dimensions locally, so the space complexity is reduced to O(N * d) and the time complexity is sub-cubic.
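Short of writing a fused kernel, one way to get a similar memory profile in plain JAX might be to reduce over the samples one query row at a time. The sketch below swaps in `jax.lax.map` for the matmul-like kernel suggested above. It is an assumption, not a tested design: it trades parallelism for memory, because rows are processed sequentially, so only an `(N, d)` comparison is live at once instead of `(M, N, d)`. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def empirical_log_prob_rowwise(value, samples, log_weights):
    """Hypothetical helper: same result as the all-pairs version, but
    reduces over the N samples one query row at a time.

    value: (M, d), samples: (N, d), log_weights: (N,) unnormalized.
    """
    log_norm = logsumexp(log_weights)

    def one_row(v):
        # v: (d,). Compare against all samples -> (N,) boolean mask.
        mask = jnp.all(v[None, :] == samples, axis=-1)
        # Local reduction over N: log of the total weight of matching atoms.
        return logsumexp(jnp.where(mask, log_weights, -jnp.inf)) - log_norm

    # lax.map applies one_row sequentially over the leading axis of value,
    # so peak intermediate memory is O(N * d) rather than O(M * N * d).
    return jax.lax.map(one_row, value)
```

Processing rows in small blocks instead of singly would sit between these two extremes; how best to choose the block size for a given GPU is left open here.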
There used to be an Empirical distribution. Any particular reason it was removed? Thanks!
Update: I just found it under contrib.tfb.distributions, sorry for the false alarm!
I think you can also use TFP distributions directly in NumPyro models. We might want to remove contrib.tfb.distributions.Empirical in the future.
I tried it, but the API seems a bit different (seed vs. rng_key, etc.).
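The main mismatch mentioned is in `sample`. TFP distributions take `sample(sample_shape=(), seed=None)`, while NumPyro distributions take `sample(key, sample_shape=())`. A thin adapter along these lines could bridge the two. `RngKeyAdapter` is a hypothetical name. The sketch only forwards `sample` and `log_prob` and does not make the object a full NumPyro `Distribution`: no support, batch/event shapes, or `numpyro.sample` integration.

```python
class RngKeyAdapter:
    """Hypothetical wrapper: exposes a TFP-style distribution through a
    NumPyro-style sample(key, sample_shape) signature."""

    def __init__(self, tfp_dist):
        self.base = tfp_dist

    def sample(self, key, sample_shape=()):
        # NumPyro passes the PRNG key positionally; TFP expects it as `seed`.
        return self.base.sample(sample_shape=sample_shape, seed=key)

    def log_prob(self, value):
        # log_prob has the same calling convention in both libraries.
        return self.base.log_prob(value)
```

With the JAX substrate of TFP this might be used as `RngKeyAdapter(tfd.Empirical(samples)).sample(jax.random.PRNGKey(0), (3,))`, assuming `tfd` is `tensorflow_probability.substrates.jax.distributions`.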