"module 'jax.core' has no attribute 'new_main'" using jax>=0.4.36
Hello! Thanks for this library. I've been using it for a while (via Lorax), but after updating to the latest JAX version I now get an error.
Apparently JAX 0.4.36 removed `jax.core.new_main`, which causes qax to fail:
```
File ~/.local/lib/python3.10/site-packages/qax/implicit/implicit_array.py:32, in _implicit_outer(*in_vals)
     30 @lu.transformation
     31 def _implicit_outer(*in_vals):
---> 32 with core.new_main(ImplicitArrayTrace) as main:
     33 outs = yield (main, *in_vals), {}
     34 del main

File ~/.local/lib/python3.10/site-packages/jax/_src/deprecations.py:57, in deprecation_getattr.<locals>.getattr(name)
     55 warnings.warn(message, DeprecationWarning, stacklevel=2)
     56 return fn
---> 57 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.core' has no attribute 'new_main'
```
Is it (easily) possible to support JAX 0.4.36 and up? Unfortunately, I don't know enough about JAX internals to judge.
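The second frame of the traceback shows that the error is raised by JAX's deprecation helper rather than by an ordinary missing attribute: `jax.core` appears to resolve retired names through a module-level `__getattr__` (PEP 562), which warns for names that are deprecated but still available and raises `AttributeError` for anything it no longer knows. The following is a simplified, hypothetical sketch of that pattern in plain Python. The `make_deprecation_getattr` helper, the `fake_core` module, the `(message, replacement)` table layout, and the names `old_helper` and `gone` are all illustrative assumptions, not JAX's actual code:

```python
import types
import warnings


def make_deprecation_getattr(module_name, deprecations):
    """Build a PEP 562 module-level __getattr__ (simplified sketch).

    `deprecations` maps a name to (message, replacement). In this sketch a
    replacement of None means the name has been removed outright.
    """
    def __getattr__(name):
        if name in deprecations:
            message, replacement = deprecations[name]
            if replacement is None:
                # Removed: fail, but with the deprecation message.
                raise AttributeError(message)
            # Deprecated but still available: warn, then hand it back.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return replacement
        # Name not in the table at all (the case hit by `new_main` above).
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
    return __getattr__


# Hypothetical stand-in for a module such as jax.core.
fake_core = types.ModuleType("fake_core")
fake_core.__getattr__ = make_deprecation_getattr(
    "fake_core",
    {
        "old_helper": ("old_helper is deprecated", lambda: "ok"),
        "gone": ("gone was removed", None),
    },
)

print(fake_core.old_helper())  # issues a DeprecationWarning, then prints: ok
print(hasattr(fake_core, "new_main"))  # False
try:
    fake_core.new_main
except AttributeError as err:
    print(err)  # module 'fake_core' has no attribute 'new_main'
```

Because removed names surface as `AttributeError`, `hasattr(module, name)` is a cheap way to feature-detect them; it does not, however, tell you what the replacement API is.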
I think it'll be a little involved, so if time is an issue for you, I'd recommend just pinning to a JAX version below 0.4.36 if that's an option.
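Pinning is normally done at install time with a version specifier such as `jax<0.4.36`. As a complement, a runtime guard can fail early with a clearer message than the `AttributeError` above. The sketch below is hypothetical: `FIRST_UNSUPPORTED`, `parse_version`, and `check_jax_supported` are illustrative names, not part of qax or JAX, and it assumes plain `X.Y.Z` version strings (any pre-release or local suffix after the third number is ignored):

```python
import re

# First JAX release without jax.core.new_main (per this thread).
FIRST_UNSUPPORTED = (0, 4, 36)


def parse_version(version):
    """Parse the leading 'X.Y.Z' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def check_jax_supported(version):
    """Raise a descriptive error if `version` is too new for this qax release."""
    if parse_version(version) >= FIRST_UNSUPPORTED:
        bound = ".".join(map(str, FIRST_UNSUPPORTED))
        raise RuntimeError(
            f"JAX {version} removed jax.core.new_main; "
            f"install an older release, e.g. 'jax<{bound}'"
        )


check_jax_supported("0.4.35")  # passes silently
try:
    check_jax_supported("0.4.36")
except RuntimeError as err:
    print(err)
```

In real code the argument would be `jax.__version__`. Comparing integer tuples rather than raw strings matters here: as strings, `"0.4.9" > "0.4.36"`, which would wrongly reject an old release.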
Pinned for now! I needed some extra functionality from a later JAX version, but I just reimplemented it manually.
Hi @davisyoshida,
I hope you're doing well! I wanted to reach out regarding qax usage in fjformer for quantization workflows within the EasyDel project.
Since the JAX 0.4.35 release, I noticed that qax hasn't yet been updated for compatibility with JAX >=0.4.36. To unblock our work, I developed a temporary fork called jaximus to cover our immediate needs.
If it aligns with your roadmap, I’d be happy to:
- Open a pull request to update qax core for JAX 0.5.0 support.
- Contribute back any changes from jaximus that might benefit the main library.

Let me know your thoughts.
@erfanzar Sure, I'd be happy to take a look at a PR!