Add `MultinomialRV` JAX implementation

Open GStechschulte opened this issue 2 years ago • 7 comments

This draft PR is a work in progress and contains a JAX implementation of `MultinomialRV` for issue #1326. The implementation builds on the Multinomial distribution implementation in NumPyro, and its output is similar to that of the NumPy implementation. Below is a brief outline of the functions used to construct `MultinomialRV`.

`def _categorical(key, p, shape)`

  • returns the sampled outcome $k$, drawn with probabilities $p$, for each of the $n$ trials / experiments.
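As a rough illustration of this step (not the PR's actual code), categorical outcomes can be drawn with `jax.random.categorical`. The helper name `categorical_sketch` and the example values are assumptions made for this sketch only.

```python
import jax
import jax.numpy as jnp


def categorical_sketch(key, p, shape):
    """Hypothetical stand-in for `_categorical`: draw outcome indices.

    `p` holds the outcome probabilities along its last axis and `shape`
    is the shape of the returned array of sampled outcome indices.
    """
    # jax.random.categorical expects (unnormalised) log-probabilities.
    logits = jnp.log(p)
    return jax.random.categorical(key, logits, axis=-1, shape=shape)


key = jax.random.PRNGKey(0)
p = jnp.array([0.2, 0.3, 0.5])  # K = 3 possible outcomes
outcomes = categorical_sketch(key, p, shape=(10,))  # n = 10 trials
```

Each entry of `outcomes` is an integer in `[0, K)`, one per trial.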

`def _scatter_add_ones(operand, indices, update)`

  • returns the outcome counts by using the `jax.lax.scatter_add()` function.
  • `operand` is a zero-filled array.
  • `indices` is the outcomes array with an added dimension; it specifies the positions in `operand` to which the update is applied.
  • `update` is an array filled with ones and can be thought of as a `cnt += 1` for each $K = k$ occurrence.
  • In summary, `operand` is incremented by one, using the `update` array, at each position given by the outcomes in the `indices` array.
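The counting step can be sketched as follows. Note that this sketch swaps in the indexed-update syntax `x.at[idx].add(...)` rather than calling `jax.lax.scatter_add()` directly, because it performs the same accumulate-at-index operation without needing explicit scatter dimension numbers. The helper name `count_outcomes` and the example data are assumptions, not the PR's code.

```python
import jax.numpy as jnp


def count_outcomes(outcomes, num_categories):
    """Hypothetical helper: count how often each outcome index occurs."""
    # Zero-filled array that accumulates the counts (the "operand").
    operand = jnp.zeros(num_categories, dtype=jnp.int32)
    # One "+1" per sampled outcome (the "update").
    update = jnp.ones(outcomes.shape, dtype=jnp.int32)
    # Repeated indices accumulate, so each occurrence of k adds 1 to operand[k].
    return operand.at[outcomes].add(update)


outcomes = jnp.array([2, 0, 2, 1, 2])  # five trials, K = 3 outcomes
counts = count_outcomes(outcomes, num_categories=3)
# counts is [1, 1, 3]: outcome 0 once, outcome 1 once, outcome 2 three times
```

The counts sum to the number of trials, which is the defining property of a multinomial draw.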

I still need to add a test for this. Thanks!

GStechschulte avatar Dec 11 '22 16:12 GStechschulte

Thanks! We also need to figure out whether the licenses are compatible and how to give proper attribution if you took inspiration from someone else's implementation. It looks like NumPyro is licensed under Apache 2.0.

rlouf avatar Dec 12 '22 09:12 rlouf

> Thanks! We also need to figure out if the licenses are compatible and how to do proper attribution if you took inspiration from someone else's implementation. It looks like Numpyro is licensed under Apache 2.0

Based on section 4 of the Apache License 2.0, under which NumPyro is distributed, we may reproduce and distribute copies of the Work or Derivative Works (here, the JAX implementation of `MultinomialRV`) provided we:

  1. give any other recipients of the Work or Derivative Works a copy of the License;
  2. make any modified files carry prominent notices stating that the files were changed;
  3. retain, in the Source form of any Derivative Works that we distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work; and
  4. if the Work includes a "NOTICE" text file as part of its distribution, include in any Derivative Works that we distribute a readable copy of the attribution notices contained within that NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation

Since inspiration was drawn from a few functions and not an entire file, I suggest that, in addition to satisfying (1), (2), and (3), we include something along these lines in the documentation for this RV:

MultinomialRV uses source code from the file xyz.py from <link to src code GitHub file> of the NumPyro project, copyright YYYY, licensed under the Apache 2.0 license

GStechschulte avatar Dec 12 '22 20:12 GStechschulte

I rebased your branch on main to use the new key splitting scheme in the JAX backend. You'll have to pull the changes!
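The thread does not describe the details of the backend's key splitting scheme. As a general illustration only, this is how explicit key splitting works in JAX itself: a key is split into a carried key and a subkey, and only the subkey is consumed by a sampling call.

```python
import jax

# Illustrative only: this shows JAX's own key handling,
# not Aesara's internal scheme.
key = jax.random.PRNGKey(0)

# Split once: keep `key` for future draws, consume `subkey` now.
key, subkey = jax.random.split(key)
sample = jax.random.uniform(subkey, (3,))

# A later draw splits the carried key again instead of reusing `subkey`.
key, next_subkey = jax.random.split(key)
next_sample = jax.random.uniform(next_subkey, (3,))
```

Reusing the same key for two draws would produce identical samples, which is why each draw gets a fresh subkey.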

rlouf avatar Dec 14 '22 13:12 rlouf

Yes, you need to run `git pull --rebase` in such cases. Here's a good explanation of how rebasing works. And of course the documentation for `git pull`.

rlouf avatar Dec 14 '22 21:12 rlouf

@GStechschulte I think we should follow @AdrienCorenflos's suggestion here. I'll take another look at it this week.

rlouf avatar Jan 16 '23 19:01 rlouf

@GStechschulte I rebased your branch on main. Do you plan on implementing @AdrienCorenflos's suggestion above?

rlouf avatar Feb 21 '23 15:02 rlouf

> @rlouf, I had to add a case for `Constant`s in `assert_size_argument_jax_compatible`. You'll need to confirm that this is valid more generally.

That's valid. I was so focused on the complex case that I forgot the simplest one.

rlouf avatar Mar 10 '23 14:03 rlouf