aesara
Add `MultinomialRV` JAX implementation
This draft PR is a work in progress and contains a JAX implementation of `MultinomialRV` for issue #1326. The implementation builds on the multinomial distribution implementation in NumPyro, and its output is similar to that of the NumPy implementation. Below is a brief outline of the functions used to construct `MultinomialRV`.
`def _categorical(key, p, shape)`
- returns, for each of the $n$ trials (experiments), an outcome $k$ drawn with probability $p_k$.
`def _scatter_add_ones(operand, indices, update)`
- returns the outcome counts by calling the `jax.lax.scatter_add()` function:
  - `operand` is a zero-filled array.
  - `indices` is the `outcomes` array with an added dimension; it specifies the positions to which the update is applied.
  - `update` is an array filled with ones and can be thought of as `cnt += 1` for each occurrence of $K = k$.
- In summary, the `operand` array is incremented by `+1` using the `update` array, according to the outcomes in the `indices` array.
I still need to add a test for this. Thanks!
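One possible shape for the missing test is a statistical sanity check on a batch of draws: every draw should be a non-negative count vector that sums to `n`, and the mean counts should approach `n * p`. The helper `check_multinomial_samples` and its `rtol` tolerance are hypothetical. The sketch uses NumPy's `Generator.multinomial` as a reference sampler; the same check could then be applied to draws from the JAX-compiled `MultinomialRV`, since `np.asarray` also accepts JAX arrays.

```python
import numpy as np


def check_multinomial_samples(samples, n, p, rtol=0.05):
    """Hypothetical helper: sanity-check a batch of multinomial draws."""
    samples = np.asarray(samples)
    p = np.asarray(p)
    # One count vector per draw, with one entry per category
    assert samples.shape[-1] == p.shape[-1]
    # Counts are non-negative and each draw sums to the n trials
    assert np.all(samples >= 0)
    assert np.all(samples.sum(axis=-1) == n)
    # The mean count per category should approach n * p
    np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=rtol)


# Example: validate the reference NumPy sampler with a fixed seed
rng = np.random.default_rng(0)
p = np.array([0.2, 0.3, 0.5])
draws = rng.multinomial(10, p, size=10_000)
check_multinomial_samples(draws, 10, p)
```

Because the check is statistical, it needs enough draws for the tolerance to be meaningful, and a fixed seed keeps it reproducible.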
Thanks! We also need to figure out whether the licenses are compatible and how to do proper attribution if you took inspiration from someone else's implementation. It looks like NumPyro is licensed under Apache 2.0.
Based on section 4 of NumPyro's Apache License 2.0, we may reproduce and distribute copies of the Work or Derivative Works (the JAX implementation of `MultinomialRV`) provided that we:
1. give any other recipients of the Work or Derivative Works a copy of this License;
2. cause any modified files to carry prominent notices stating that we changed the files;
3. retain, in the Source form of any Derivative Works that we distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
4. if the Work includes a "NOTICE" text file as part of its distribution, include in any Derivative Works that we distribute a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or within a display generated by the Derivative Works, if and wherever such third-party notices normally appear.
Since inspiration was drawn from a few functions and not an entire file, I suggest that, in addition to (1), (2), and (3), we include something along the lines of the following in the documentation for this RV:
MultinomialRV uses source code from the file xyz.py from <link to src code GitHub file> of the NumPyro project, copyright YYYY, licensed under the Apache 2.0 license.
I rebased your branch on `main` to use the new key-splitting scheme in the JAX backend. You'll have to pull the changes!
Yes, you need to `git pull --rebase` in such cases. Here's a good explanation of how rebasing works, and of course the documentation for `git pull`.
@GStechschulte I think we should follow @AdrienCorenflos's suggestion here. I'll take another look at it this week.
@GStechschulte I rebased your branch on `main`. Do you plan on implementing @AdrienCorenflos's suggestion above?
@rlouf, I had to add a case for `Constant`s in `assert_size_argument_jax_compatible`. You'll need to confirm that this is valid more generally.
That's valid. I was so focused on the complex case that I forgot the simplest one.