jax
JAX implementation of scipy.stats.multinomial
Please:
- [X] Check for duplicate requests.
- [X] Describe your goal, and if possible provide a code snippet with a motivating example.
Hi all,
Whilst working on a recent project, I realised that `scipy.stats.multinomial` doesn't have a `jax.scipy` implementation.
I thought it would be worthwhile to implement, and that it would make a decent first contribution.
I think the work is largely done, but I wanted to open an issue before making a PR.
Is there any reason this hasn't been implemented thus far?
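For reference, here is a minimal sketch of what such a function could compute. It assumes the JAX version mirrors the `logpmf`/`pmf` methods of `scipy.stats.multinomial`. The names `multinomial_logpmf` and `multinomial_pmf` are hypothetical, not an existing `jax.scipy` API, and this is not the code from the PR. It uses the standard formula log P(x) = log n! − Σ log x_i! + Σ x_i log p_i, built from `jax.scipy.special.gammaln` and `xlogy`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def multinomial_logpmf(x, n, p):
    """Hypothetical helper: log-PMF of Multinomial(n, p) evaluated at counts x.

    Reduces over the last axis, so a batch of count vectors can be passed at once.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    p = jnp.asarray(p, dtype=jnp.float32)
    # log n! - sum_i log x_i! + sum_i x_i log p_i
    # xlogy returns 0 when x_i == 0, so zero counts with p_i == 0 are handled.
    logpmf = (
        gammaln(n + 1.0)
        - jnp.sum(gammaln(x + 1.0), axis=-1)
        + jnp.sum(xlogy(x, p), axis=-1)
    )
    # Outside the support (negative, non-integer, or not summing to n) the probability is 0.
    valid = (
        (jnp.sum(x, axis=-1) == n)
        & jnp.all(x >= 0, axis=-1)
        & jnp.all(x == jnp.round(x), axis=-1)
    )
    return jnp.where(valid, logpmf, -jnp.inf)


def multinomial_pmf(x, n, p):
    """Hypothetical helper: PMF obtained by exponentiating the log-PMF."""
    return jnp.exp(multinomial_logpmf(x, n, p))


# Example: 4!/(1! 2! 1!) * 0.2 * 0.5**2 * 0.3 = 12 * 0.015 = 0.18
prob = jax.jit(multinomial_pmf)(jnp.array([1, 2, 1]), 4, jnp.array([0.2, 0.5, 0.3]))
print(prob)
```

Because the function is written with pure `jax.numpy` operations, it can be wrapped in `jax.jit`, `jax.vmap` or `jax.grad` (with respect to `p`) like the other `jax.scipy.stats` functions.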
I don't know of any reason it hasn't been implemented - feel free to open a PR if you'd like to contribute!