ImportError: cannot import name 'FrozenDict' from 'jax.experimental.maps'

Open carlosgmartin opened this issue 1 year ago • 4 comments

$ python -c "import alpa"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/site-packages/alpa/__init__.py", line 3, in <module>
    from . import api
  File "/usr/local/lib/python3.10/site-packages/alpa/api.py", line 10, in <module>
    from jax.experimental.maps import FrozenDict
ImportError: cannot import name 'FrozenDict' from 'jax.experimental.maps' (/usr/local/lib/python3.10/site-packages/jax/experimental/maps.py)

System information and environment

  • macOS 11.7.4
  • Python 3.10.11
  • alpa 0.2.3
  • jax 0.4.8
  • jaxlib 0.4.7

I do not have CUDA or any GPUs on this device.

carlosgmartin avatar Apr 21 '23 21:04 carlosgmartin

Alpa doesn't work with JAX 0.4.8. https://alpa.ai/install.html#install-from-wheels

gjoliver avatar Apr 21 '23 22:04 gjoliver

@gjoliver Thanks. Are there any plans to add such support in the near future?

carlosgmartin avatar Apr 21 '23 22:04 carlosgmartin

Work is underway.

gjoliver avatar Apr 21 '23 23:04 gjoliver

It seems many of the issues with newer JAX versions come from JAX's API being moved around: things like get_backend, xla_bridge, etc. have moved to other places. Besides those changes, are there any other changes required? If it's just those, I think I can make a PR for that.

For example, doing a jax version check to decide whether to do from jax.experimental.maps import FrozenDict or from jax._src.maps import FrozenDict
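A minimal sketch of that idea follows. Instead of comparing version strings, it tries each candidate module in order and uses the first one that provides the name. The helper names import_first and load_frozen_dict are hypothetical (they are not part of alpa or jax), and the assumption that FrozenDict lives in jax._src.maps on newer releases comes from this thread and has not been verified here.

```python
import importlib


def import_first(name, candidates):
    """Return attribute `name` from the first module in `candidates` that has it.

    Hypothetical helper for illustration; not part of alpa or jax.
    """
    errors = []
    for module_path in candidates:
        try:
            module = importlib.import_module(module_path)
            return getattr(module, name)
        except (ImportError, AttributeError) as exc:
            # Module missing, or present but without the name: try the next one.
            errors.append(f"{module_path}: {exc}")
    raise ImportError(
        f"cannot import {name!r} from any of {list(candidates)}: " + "; ".join(errors)
    )


def load_frozen_dict():
    """Look up FrozenDict in the old location first, then the assumed new one.

    The second path is an assumption taken from the discussion above.
    """
    return import_first("FrozenDict", ["jax.experimental.maps", "jax._src.maps"])
```

Trying the locations in order avoids hard-coding the exact release in which a name moved, but it relies on a private module path (jax._src) that may change again without notice. It also does not help with functions that were removed outright.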

Edit: Besides those changes, some functions have been removed entirely: _check_callable, named_call_p

Lime-Cakes avatar May 06 '23 11:05 Lime-Cakes