
"module 'jax.core' has no attribute 'new_main'" using jax>=0.4.36

Open bminixhofer opened this issue 11 months ago • 4 comments

Hello! Thanks for this library. I've been using it for a while (via Lorax), but I now get an error after updating to the latest JAX version.

Apparently JAX 0.4.36 removed jax.core.new_main, which causes qax to fail:

File ~/.local/lib/python3.10/site-packages/qax/implicit/implicit_array.py:32, in _implicit_outer(*in_vals)
     30 @lu.transformation
     31 def _implicit_outer(*in_vals):
---> 32     with core.new_main(ImplicitArrayTrace) as main:
     33         outs = yield (main, *in_vals), {}
     34         del main

File ~/.local/lib/python3.10/site-packages/jax/_src/deprecations.py:57, in deprecation_getattr.<locals>.getattr(name)
     55   warnings.warn(message, DeprecationWarning, stacklevel=2)
     56   return fn
---> 57 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.core' has no attribute 'new_main'
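To check up front whether an environment is affected, one can probe for the removed attribute directly. The sketch below is illustrative only: `jax_has_new_main` is a hypothetical helper, not part of qax or JAX. It relies on the behavior visible in the traceback above: JAX's module-level deprecation `__getattr__` raises `AttributeError` for removed names, and `hasattr` turns that into `False`.

```python
import importlib


def jax_has_new_main() -> bool:
    """Return True if the installed JAX still exposes jax.core.new_main.

    Hypothetical helper (not part of qax or JAX). Returns False both when
    JAX is not installed and when the attribute has been removed
    (JAX >= 0.4.36, per this issue).
    """
    try:
        core = importlib.import_module("jax.core")
    except ImportError:
        # JAX is not installed at all.
        return False
    # hasattr() catches the AttributeError raised by JAX's module-level
    # deprecation __getattr__ for names that no longer exist.
    return hasattr(core, "new_main")


if __name__ == "__main__":
    if not jax_has_new_main():
        print("jax.core.new_main is unavailable; qax versions that call it will fail.")
```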

Is it (easily) possible to support JAX 0.4.36 and up? Unfortunately, I don't know enough about JAX internals to judge.

bminixhofer, Jan 14 '25 18:01

I think it'll be a little involved, so if time is an issue for you, I'd recommend just pinning JAX below 0.4.36, if that's an option.
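Pinning can be as simple as a `jax<0.4.36` line in a requirements file. For a fail-fast check at import time, a sketch like the following compares the installed version against that bound. The helper names and the hard-coded bound are assumptions drawn from this thread, not an official qax API.

```python
from importlib import metadata

# Per this issue, jax.core.new_main was removed in JAX 0.4.36.
# Assumed exclusive upper bound for qax releases that still call new_main.
MAX_EXCLUSIVE = (0, 4, 36)


def parse_release(version: str) -> tuple[int, ...]:
    """Turn '0.4.35' or '0.4.35.dev20241001' into (0, 4, 35)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric (pre/dev/local) segment
        parts.append(int(piece))
    return tuple(parts)


def jax_version_ok(version: str) -> bool:
    """True if `version` is below the assumed bound."""
    return parse_release(version) < MAX_EXCLUSIVE


def check_installed_jax() -> None:
    """Raise if the installed JAX is too new for a new_main-based qax."""
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return  # JAX is not installed; nothing to check
    if not jax_version_ok(installed):
        raise RuntimeError(
            f"jax {installed} no longer provides jax.core.new_main; "
            'pin it with: pip install "jax<0.4.36"'
        )
```

Calling `check_installed_jax()` before importing qax turns the `AttributeError` deep inside a trace into an explicit message. Comparing integer tuples rather than raw strings avoids the pitfall where `"0.4.100" < "0.4.36"` is true as a string comparison.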

davisyoshida, Jan 17 '25 04:01

Done for now! I needed some extra functionality from a later version, but I just re-implemented it manually.

bminixhofer, Jan 19 '25 11:01

Hi @davisyoshida,

I hope you're doing well! I wanted to reach out regarding qax usage in fjformer for quantization workflows within the EasyDel project.

After the 0.4.35 release, I noticed some delays in updates for JAX >=0.4.36 compatibility. To unblock our work, I developed a temporary fork called jaximus to address immediate needs.

If it aligns with your roadmap, I’d be happy to:

  1. Open a pull request to update qax core for JAX 0.5.0 support.
  2. Contribute back any changes from jaximus that might benefit the main library.

Let me know your thoughts.

erfanzar, Feb 15 '25 15:02

@erfanzar Sure I'd be happy to take a look at a PR!

davisyoshida, Feb 16 '25 01:02