aesara
Add JAX implementation for `MultinomialRV`
Hey @rlouf and others, I will give this PR a go. I am following #1335 and #1284 for guidance on how the implementations should be structured. I will post progress updates and questions here as well.
Much appreciated, @GStechschulte!
Here is an explanation of how to go about adding an implementation: https://github.com/aesara-devs/aesara/issues/1335#issuecomment-1344510169
After using NumPyro, I remembered that it has a JAX implementation of the Multinomial distribution, albeit one that follows the design of the PyTorch distributions module. I therefore adapted that code to match the parameters passed to the respective RV's `__call__` function in this file.
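For readers unfamiliar with the approach, here is a minimal sketch of one way to sample a Multinomial in JAX, written from scratch rather than taken from NumPyro: draw `n` category indices with `jax.random.categorical` and count how often each category occurs. The names `multinomial_sample` and `multinomial_sample_fn` are hypothetical. The wrapper assumes a `(rng, size, dtype, *parameters)` sampler convention with the JAX key stored under `rng["jax_state"]`, which we assume matches the other samplers in Aesara's JAX backend; the sketch also assumes a scalar, concrete `n` and a probability vector `p` of shape `(..., k)`.

```python
import jax
import jax.numpy as jnp


def multinomial_sample(key, n, p, size=None):
    """Hypothetical sketch: draw Multinomial(n, p) samples with JAX.

    Assumes ``n`` is a concrete scalar (not a traced value), ``p`` has shape
    ``(..., k)`` with rows summing to 1, and ``size`` is the batch shape of
    the output, excluding the trailing support dimension ``k``.
    """
    n = int(n)  # a concrete trial count is needed to build the draw shape
    p = jnp.asarray(p)
    k = p.shape[-1]
    batch_shape = tuple(size) if size is not None else p.shape[:-1]
    # Draw n category indices per batch element; the logits' batch shape
    # p.shape[:-1] must broadcast against batch_shape.
    draws = jax.random.categorical(
        key, jnp.log(p), axis=-1, shape=(n,) + batch_shape
    )
    # One-hot encode each draw and sum over the trial axis to get counts.
    return jax.nn.one_hot(draws, k, dtype=jnp.int32).sum(axis=0)


def multinomial_sample_fn(rng, size, dtype, n, p):
    """Hypothetical wrapper following an assumed (rng, size, dtype, *params)
    convention, where ``rng`` is a dict holding a JAX key under "jax_state"."""
    rng_key = rng["jax_state"]
    # Split so the stored state differs from the key used for this draw.
    rng_key, sampling_key = jax.random.split(rng_key, 2)
    sample = multinomial_sample(sampling_key, n, p, size).astype(dtype)
    rng["jax_state"] = rng_key
    return rng, sample
```

Drawing one categorical per trial costs O(`n`) work per sample, so this sketch is only reasonable for small `n`, and it cannot be used when `n` is a traced (symbolic) value.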
I still need to add the tests. As this is my first time contributing, how should we handle using a code snippet from another library? In this case, I used three functions from the NumPyro library for this PR. My idea is to add a reference in a docstring. Thanks!
Here is the link to my branch of the JAX implementation of MultinomialRV.
Feel free to create a PR for that branch. If you're still working on it, no worries; you can make it a draft PR.
Regarding reusing code from NumPyro: we'll need to make sure that the licenses are compatible and see what they require.