extensions module

Status: Open · froystig opened this issue 1 year ago · 3 comments

Several projects and libraries critically rely on jax internals, such as the Jaxpr IR. Oryx is one example among many. These projects use the jax core as a sort of library, e.g. for staging and transforming Python functions, for adding new primitives to jax, and more. But jax's internals were not exactly designed for this sort of external use.
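To make "staging a Python function into the Jaxpr IR" concrete, here is a minimal sketch that uses only the public `jax.make_jaxpr` API to trace a function and walk the equations of the resulting program. This is the kind of inspection that downstream projects build on, although they typically go much deeper into internals. The function `f` is purely illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    # An ordinary Python function built from jax.numpy operations (illustrative only).
    return jnp.sin(x) * x

# Stage f into a ClosedJaxpr by tracing it with an example argument.
closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)

# Walk the staged program: each equation applies a single primitive.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

Anything past this point (interpreting the equations, rewriting them, or registering new primitives) is where external code starts depending on non-public symbols.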

We could organize some of our non-public symbols more like a library, and properly hide the rest (under jax._src). One idea is to do this incrementally, organizing things under a top-level module called jax.extend, with a much weaker support policy than that of the main jax public API.

froystig · May 03 '23 22:05

We did a few PRs related to this recently: #17139, #17307, #17350, #17651, #17666, #17983, #18102, #18107, #18156

There's still some work to do, but I think we're beyond the point of needing a tracking issue now. Closing.

jakevdp · Nov 03 '23 20:11

Reopening, since there are a few more big ones to do (jax.core, jax.interpreters).

jakevdp · Nov 04 '23 02:11

Two more:

  • #20217
  • #20768

froystig · Apr 20 '24 00:04