axlearn
Remove usage of jax.core and jax.interpreters
These are internal APIs that will be deprecated and removed.
Uses of jax.core.Primitive should be replaced with jax.extend.core.Primitive.
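A minimal sketch of what the migration looks like, assuming a JAX version where jax.extend.core exposes Primitive. The primitive name "double" and the helpers below are hypothetical and exist only for illustration. Only the import changes. The def_impl / def_abstract_eval / bind calls stay the same. Using the primitive under jax.jit would also need a lowering rule, which is not shown here because the usual place to register one is under jax.interpreters.

```python
import jax
import jax.numpy as jnp
from jax.extend.core import Primitive  # public replacement for jax.core.Primitive

# Hypothetical primitive used only for illustration: doubles its input.
double_p = Primitive("double")

def _double_impl(x):
    # Eager (non-traced) evaluation rule.
    return x * 2

def _double_abstract_eval(x):
    # Shape/dtype rule used while tracing; the output aval matches the input aval.
    return x

double_p.def_impl(_double_impl)
double_p.def_abstract_eval(_double_abstract_eval)

def double(x):
    return double_p.bind(x)

print(double(jnp.arange(3)))        # eager: runs _double_impl
print(jax.make_jaxpr(double)(1.0))  # tracing: uses _double_abstract_eval
```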
There is also an isinstance check against jax.core.Tracer, and replacing that logic may require more thought.
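One possible direction, offered only as a sketch and not as the agreed replacement: detect tracers using public APIs alone by attempting a host conversion. Converting a tracer with np.asarray raises jax.errors.TracerArrayConversionError. The helper name is_traced is hypothetical. This approach assumes array-like inputs. It also forces a device-to-host copy for concrete arrays, so whether it is acceptable depends on what the existing check is guarding.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_traced(x) -> bool:
    """Hypothetical helper: True if x is an abstract (traced) value."""
    try:
        # Concrete arrays convert fine (at the cost of a host copy);
        # tracers raise TracerArrayConversionError.
        np.asarray(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False

seen = []

@jax.jit
def f(x):
    # Runs at trace time, so x is a tracer here.
    seen.append(is_traced(x))
    return x

f(jnp.ones(3))
print(is_traced(jnp.ones(3)), seen)
```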
I can help with this work in collaboration with the JAX team.
Relevant info:
- https://jax.readthedocs.io/en/latest/api_compatibility.html
- https://jax.readthedocs.io/en/latest/jax.extend.html