Shape hint arguments in lax_numpy functions to enable compilation
Several functions in `jax.numpy` are not `jit`-compilable for shape-level reasons: we can't determine their output shape from the input shapes alone, because it depends on the input values.
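A minimal sketch of the problem, using `jnp.nonzero` as an example and assuming a recent JAX version (the wrapper name `nonzero_jit` is hypothetical). Run eagerly, the output length comes from the concrete values. Under `jax.jit`, the input is an abstract tracer with only a shape and dtype, so JAX raises an error instead of guessing the output shape.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])

# Eager mode: x has concrete values, so the number of nonzero entries
# (and therefore the output shape) can be computed.
print(jnp.nonzero(x)[0])  # [1 3]


@jax.jit
def nonzero_jit(x):
    # Under jit, x is a tracer: its shape (4,) is known but its values are
    # not, so the output length cannot be determined at trace time.
    return jnp.nonzero(x)


try:
    nonzero_jit(x)
except jax.errors.ConcretizationTypeError as err:
    print("cannot jit:", type(err).__name__)
```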
We've given `jax.numpy.bincount` an optional `length` argument that's not in standard NumPy, because it determines the output shape (when it is static/concrete). It's not uncommon for callers to have this information on hand.
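As an illustration (a sketch; the helper `counts` is hypothetical), here is `bincount` inside `jax.jit` with `length` passed as a static argument. Without `length`, the output size is derived from `x.max()`, which requires concrete values that are unavailable under `jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `length` must be a static Python int under jit, so mark it as static.
@partial(jax.jit, static_argnames="length")
def counts(x, length):
    # The output shape is (length,) regardless of the values in x.
    return jnp.bincount(x, length=length)


x = jnp.array([0, 1, 1, 3])
print(counts(x, length=5))  # [1 2 0 1 0]
```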
We could consider extending other functions similarly to accept optional shape hints:
- [x] argwhere (#6915)
- [x] compress (#21090)
- [x] extract (#21090)
- [x] nonzero (#6501 #7592)
- [x] flatnonzero (#6913)
- [x] 1-argument where (#6910)
- [x] repeat (#3670)
- [x] unique (#6912, #8121, #8186)
- [x] setdiff1d (#8144)
- [x] union1d (#6930, #8143)
- [ ] intersect1d
- [ ] setxor1d
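For the items already checked off, the hint is a keyword argument (`size` for most of them, `total_repeat_length` for `repeat`). A minimal sketch assuming current JAX; the wrapper names are hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def nonzero_padded(x):
    # size=3 fixes the output length. Unused slots are padded with
    # fill_value, which defaults to 0 for nonzero.
    (idx,) = jnp.nonzero(x, size=3)
    return idx


@jax.jit
def repeat_traced(x, repeats):
    # `repeats` is traced here, so jit cannot compute sum(repeats);
    # total_repeat_length supplies the output length instead.
    return jnp.repeat(x, repeats, total_repeat_length=6)


print(nonzero_padded(jnp.array([0, 3, 0, 5])))  # [1 3 0]
print(repeat_traced(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])))  # [1 2 2 3 3 3]
```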
For something like `unique`, if the requested size exceeds the number of unique values, what would the remaining elements of the result be filled with? It seems we would probably need to add another parameter indicating the `fill_value`.
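In current JAX, `jnp.unique` takes exactly such a parameter. A minimal sketch (the helper name `unique_padded` is hypothetical): with `size` set, slots beyond the number of distinct values are filled with `fill_value`, and when `fill_value` is omitted they are padded with the smallest unique value.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unique_padded(x):
    # Ask for exactly 5 outputs. Only 3 distinct values exist, so the last
    # 2 slots are filled with fill_value.
    return jnp.unique(x, size=5, fill_value=-1)


x = jnp.array([3, 1, 3, 2, 1])
print(unique_padded(x))  # [ 1  2  3 -1 -1]

# Without fill_value, the padding defaults to the smallest unique value.
print(jnp.unique(x, size=5))  # [1 2 3 1 1]
```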