jax
extensions module
Several projects and libraries critically rely on jax internals, such as the Jaxpr IR. Oryx is one example among many others. These projects use the jax core as a sort of library (e.g. for staging and transforming Python functions, for adding new primitives to jax, and more). But jax's internals were not exactly designed for this sort of external use.
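To make "staging a Python function" concrete, here is a minimal sketch using the public `jax.make_jaxpr` entry point (not `jax.extend` itself). It traces a small function into a Jaxpr and lists the primitives it applies. The function `f` is a hypothetical example; downstream projects like Oryx go further and walk or rewrite these equations, which is where they end up depending on internals.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function to stage out.
    return jnp.sin(x) * 2.0


# jax.make_jaxpr traces f with abstract values and returns a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0)

# The underlying Jaxpr is a flat list of equations, one per primitive application.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(closed_jaxpr)
print(primitive_names)
```

Reading `eqn.primitive` and similar fields is exactly the kind of access to jax internals that the proposal aims to give a more deliberate home.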
We could organize some of our non-public symbols more like a library, and properly hide the rest (under jax._src). One idea is to do this incrementally, organizing things under a top-level module called jax.extend, with a much weaker support policy than that of the main jax public API.
We did a few PRs related to this recently: #17139, #17307, #17350, #17651, #17666, #17983, #18102, #18107, #18156
There's still some work to do, but I think we're beyond the point of needing a tracking issue now. Closing.
Reopening, since there are a few more big ones to do (jax.core, jax.interpreters).
Two more:
- #20217
- #20768