
shape hint arguments in lax_numpy functions to enable compilation

Open • froystig opened this issue 4 years ago • 1 comment

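Several functions in jax.numpy cannot be jit-compiled when their arguments are known only at the shape level (i.e., as abstract tracers). For these functions, the output shape depends on the input values, so it cannot be determined from the input shapes alone.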
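To make the problem concrete, here is a minimal sketch using jnp.nonzero, whose output length equals the number of nonzero entries. The helper name `nonzero_positions` is hypothetical. The exception class shown (`jax.errors.ConcretizationTypeError`) is what recent JAX versions raise here; older versions may differ.

```python
import jax
import jax.numpy as jnp

def nonzero_positions(x):
    # Output length equals the number of nonzero entries of x,
    # which depends on x's values, not just on its shape.
    return jnp.nonzero(x)[0]

x = jnp.array([0, 3, 0, 5])

# Eager (un-jitted) execution works because x is concrete here.
print(nonzero_positions(x).tolist())  # [1, 3]

# Under jit, x is an abstract tracer (only shape and dtype are known),
# so the output shape cannot be computed and tracing fails.
try:
    jax.jit(nonzero_positions)(x)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)
```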

We've given jax.numpy.bincount an optional length argument, which standard NumPy does not have, because when it is static (concrete) it determines the output shape. Callers often have this information on hand.
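A minimal sketch of how a caller can use this hint under jit: `length` is passed as a static argument, so it is a concrete Python int at trace time and fixes the output shape to `(length,)`. The wrapper name `count_values` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `length` is static, so it is concrete during tracing and
# determines the output shape (length,) independently of x's values.
@partial(jax.jit, static_argnames="length")
def count_values(x, length):
    return jnp.bincount(x, length=length)

x = jnp.array([0, 1, 1, 3])
print(count_values(x, length=5).tolist())  # [1, 2, 0, 1, 0]
```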

We could consider extending other functions similarly to accept optional shape hints:

  • [x] argwhere (#6915)
  • [x] compress (#21090)
  • [x] extract (#21090)
  • [x] nonzero (#6501, #7592)
  • [x] flatnonzero (#6913)
  • [x] 1-argument where (#6910)
  • [x] repeat (#3670)
  • [x] unique (#6912, #8121, #8186)
  • [x] setdiff1d (#8144)
  • [x] union1d (#6930, #8143)
  • [ ] intersect1d
  • [ ] setxor1d
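For several of the checked items, current JAX exposes the hint as a keyword-only argument: `size` (with `fill_value` for padding) for nonzero and 1-argument where, and `total_repeat_length` for repeat. The sketch below uses these three under jit; the wrapper name `shape_hinted_ops` is hypothetical, and exact signatures may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def shape_hinted_ops(x, reps):
    # nonzero: `size` fixes the output length; unused slots get `fill_value`.
    (nz,) = jnp.nonzero(x, size=4, fill_value=-1)
    # 1-argument where behaves like nonzero and accepts the same hints.
    (wh,) = jnp.where(x > 0, size=4, fill_value=-1)
    # repeat: `total_repeat_length` fixes the output length even though
    # the repeat counts are traced values.
    rep = jnp.repeat(x, reps, total_repeat_length=6)
    return nz, wh, rep

x = jnp.array([0, 3, 0, 5])
reps = jnp.array([1, 2, 0, 3])
nz, wh, rep = shape_hinted_ops(x, reps)
print(nz.tolist())   # [1, 3, -1, -1]
print(wh.tolist())   # [1, 3, -1, -1]
print(rep.tolist())  # [0, 3, 3, 5, 5, 5]
```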

froystig, Jun 30 '20 21:06

For something like unique, what would the remaining elements of the result be filled with when there are fewer unique values than the requested length? We would probably need to add another parameter to specify the fill_value.

shoyer, Jun 30 '20 22:06
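Pairing the size hint with a fill value addresses this question. In current versions of jax.numpy, unique accepts both `size` and `fill_value` keyword arguments; a minimal sketch under jit follows (the wrapper name `padded_unique` is hypothetical).

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="size")
def padded_unique(x, size):
    # With `size` given, the result always has `size` entries; if there are
    # fewer unique values, the trailing slots are filled with fill_value.
    return jnp.unique(x, size=size, fill_value=-1)

x = jnp.array([3, 1, 3, 2])
print(padded_unique(x, size=5).tolist())  # [1, 2, 3, -1, -1]
```

Padding entries cannot be told apart from a genuine -1 in the data, so callers typically pick a fill value outside the range of valid inputs.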