elegy
[Bug] Installing elegy is broken by jax v0.2.21
Describe the bug
Importing elegy after a fresh install is now broken. This is because jax v0.2.21 removed jax.api. jax.api is imported by the version of haiku that installs with elegy (haiku v0.0.2), in the file haiku/_src/named_call.py. named_call.py does not exist in newer releases of haiku, so I think upgrading the version of haiku that installs with elegy may solve the problem.
At first, I was confused as to why this old version of haiku installed with elegy, because in the master branch of the elegy GitHub repo, pyproject.toml lists the haiku requirement as dm-haiku = "^0.0.4". However, after some investigation, I discovered that the pyproject.toml served by PyPI lists the haiku requirement as dm-haiku = "^0.0.2", which translates to 'dm-haiku>=0.0.2,<0.0.3' in the setup.py.
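For reference, that translation follows Poetry's documented caret rule: the upper bound bumps the leftmost non-zero version component, so for 0.0.x versions only the next patch release is excluded and nothing newer can be picked up. The sketch below is a simplified, hypothetical helper (caret_to_range is not part of Poetry) that assumes plain numeric versions of up to three components and ignores pre-release tags.

```python
def caret_to_range(version: str) -> str:
    """Translate a caret constraint like ^0.0.2 into an explicit range (simplified)."""
    parts = [int(p) for p in version.split(".")]
    # Bump the first non-zero component; if every given component is zero,
    # bump the last one that was given (so ^0.0 -> <0.1.0 and ^0 -> <1.0.0).
    idx = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    lower = parts + [0] * (3 - len(parts))
    upper = parts[: idx + 1]
    upper[idx] += 1
    upper += [0] * (3 - len(upper))
    return ">={},<{}".format(".".join(map(str, lower)), ".".join(map(str, upper)))


print(caret_to_range("0.0.2"))  # the PyPI requirement
print(caret_to_range("0.0.4"))  # the requirement on master
```

Both constraints pin haiku to a single patch series, which is why the published package could never resolve to a newer haiku.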
I assume this is a mistake, and that the error could easily be fixed by publishing the version of elegy currently on the master branch of the GitHub repo.
In the meantime, there is an easy workaround for users: install an older version of jax first, followed by elegy. That is, run the following commands:
poetry add "jax<0.2.21"
poetry add elegy
You can then import elegy without problem.
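To check whether an existing environment is affected, a small diagnostic script can compare the installed versions. This is a sketch under stated assumptions: it treats jax >= 0.2.21 combined with dm-haiku in the range pinned by ^0.0.2 (below 0.0.3) as the broken combination described in this issue; the function names are hypothetical. importlib.metadata needs Python 3.8+, and on Python 3.7 the importlib_metadata backport must be installed.

```python
import re
from typing import Optional, Tuple

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: assumes `pip install importlib_metadata`
    import importlib_metadata as metadata  # type: ignore


def version_tuple(version: str) -> Tuple[int, ...]:
    """Keep the leading numeric release components, e.g. '0.2.21rc1' -> (0, 2, 21)."""
    nums = []
    for part in version.split(".")[:3]:
        match = re.match(r"\d+", part)
        if match is None:
            break
        nums.append(int(match.group()))
    return tuple(nums)


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def haiku_jax_conflict(jax_version: Optional[str], haiku_version: Optional[str]) -> bool:
    """True for the combination reported here: jax without jax.api plus haiku 0.0.2.x."""
    if jax_version is None or haiku_version is None:
        return False
    return version_tuple(jax_version) >= (0, 2, 21) and version_tuple(haiku_version) < (0, 0, 3)


if __name__ == "__main__":
    jax_v = installed_version("jax")
    haiku_v = installed_version("dm-haiku")
    print(f"jax={jax_v} dm-haiku={haiku_v}")
    if haiku_jax_conflict(jax_v, haiku_v):
        print("This haiku release imports jax.api, which this jax release no longer has.")
```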
Minimal code to reproduce
First, run
poetry new elegyTest
cd elegyTest
Then edit pyproject.toml to set the python version to ">=3.7,<3.9".
Next, install elegy:
poetry add elegy
Finally, try importing elegy:
poetry run python
>>> import elegy
You will get the following error:
...
File ".../elegyTest/.venv/lib/python3.7/site-packages/haiku/_src/named_call.py", line 23, in <module>
from jax import api
ImportError: cannot import name 'api' from 'jax'
Sorry it took a bit longer than expected to get the new version out. Is this still an issue?
Sorry, I just got around to testing this. It looks like there's still an issue.
I followed the steps listed above in "Minimal code to reproduce". Upon running import elegy, I got this error:
>>> import elegy
Traceback (most recent call last):
File "/phys/users/jfc20/Documents/elegyTest/.venv/lib/python3.7/site-packages/treex/nn/haiku_module.py", line 13, in <module>
import haiku as hk
ModuleNotFoundError: No module named 'haiku'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/phys/users/jfc20/Documents/elegyTest/.venv/lib/python3.7/site-packages/elegy/__init__.py", line 5, in <module>
from treex import *
File "/phys/users/jfc20/Documents/elegyTest/.venv/lib/python3.7/site-packages/treex/__init__.py", line 10, in <module>
from treex.nn import *
File "/phys/users/jfc20/Documents/elegyTest/.venv/lib/python3.7/site-packages/treex/nn/__init__.py", line 8, in <module>
from .haiku_module import HaikuModule
File "/phys/users/jfc20/Documents/elegyTest/.venv/lib/python3.7/site-packages/treex/nn/haiku_module.py", line 16, in <module>
raise types.OptionalDependencyNotFound("Haiku Unavailable")
treex.types.OptionalDependencyNotFound: Haiku Unavailable
So it looks like there's a bug that makes Haiku not actually optional.
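The traceback shows treex/nn/haiku_module.py raising OptionalDependencyNotFound at import time, so any import of the package fails when Haiku is missing. A common alternative pattern is to swallow the ImportError at import time and raise only when the Haiku-backed class is actually used. The sketch below is illustrative, not treex's actual code: OptionalDependencyNotFound here is a local stand-in for treex.types.OptionalDependencyNotFound, and optional_import is a hypothetical helper.

```python
import importlib
from types import ModuleType
from typing import Optional


class OptionalDependencyNotFound(ImportError):
    """Local stand-in for treex.types.OptionalDependencyNotFound (illustrative only)."""


def optional_import(name: str) -> Optional[ModuleType]:
    """Import `name` if possible; return None instead of failing when it is unavailable."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Importing this module never fails, whether or not Haiku is installed.
hk = optional_import("haiku")


class HaikuModule:
    """Hypothetical wrapper: the missing-dependency error is deferred to first use."""

    def __init__(self) -> None:
        if hk is None:
            raise OptionalDependencyNotFound("Haiku Unavailable")
        self.hk = hk
```

With this layout, a package __init__ can import HaikuModule unconditionally and only users who instantiate it without Haiku see the error.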
I just opened #212, which hopefully fixes it. Will do a new release after that issue is merged.
#212 wasn't the solution, but I think #213 definitely did the job. I tested it out on Colab.
Now I get this TensorFlow import error:
>>> import elegy
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/elegy/__init__.py", line 18, in <module>
from .model.model import Model
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/elegy/model/model.py", line 11, in <module>
from elegy.model.model_base import ModelBase
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/elegy/model/model_base.py", line 21, in <module>
from elegy.model.model_core import ModelCore, PredStepOutput, TestStepOutput
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/elegy/model/model_core.py", line 13, in <module>
from jax.experimental import jax2tf
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/jax/experimental/jax2tf/__init__.py", line 16, in <module>
from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py", line 52, in <module>
from jax.experimental.jax2tf import shape_poly_tf
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/jax/experimental/jax2tf/shape_poly_tf.py", line 22, in <module>
import tensorflow as tf # type: ignore[import]
ModuleNotFoundError: No module named 'tensorflow'
Looks like from jax.experimental import jax2tf in elegy/model/model_core.py needs to be optional, depending on whether you have TensorFlow installed?
Thanks a lot @jfcrenshaw for these reports! Yeah, jax2tf depends on TensorFlow, so I'll refactor Elegy so that jax2tf is only imported if TF is available. Apart from that, I think I will add a CI step that checks that import works properly without the dev dependencies, to catch this kind of issue.
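One way to make jax2tf "only imported if TF is available" is to probe for TensorFlow with importlib.util.find_spec, which locates a top-level module without importing it, and to move the jax2tf import inside the function that needs it. This is a sketch, not Elegy's actual refactor: module_available and export_with_jax2tf are hypothetical names, while jax2tf.convert is jax's real conversion entry point.

```python
import importlib.util


def module_available(name: str) -> bool:
    """True if a top-level module can be found on the import path, without importing it."""
    return importlib.util.find_spec(name) is not None


TF_AVAILABLE = module_available("tensorflow")


def export_with_jax2tf(fn):
    """Hypothetical export helper: TensorFlow is only touched when this is called."""
    if not TF_AVAILABLE:
        raise ImportError("TensorFlow is required for jax2tf export; install tensorflow to use it")
    # Deferred import: jax.experimental.jax2tf imports tensorflow at module load time.
    from jax.experimental import jax2tf

    return jax2tf.convert(fn)
```

Importing the module that defines export_with_jax2tf then succeeds without TensorFlow, and the error surfaces only for users of the export feature.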
Hey @jfcrenshaw! I believe 0.8.4 finally fixes all the issues.
Thanks for all your work fixing this @cgarciae!
Unfortunately, now I get
>>> import elegy
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/elegy/__init__.py", line 5, in <module>
from treex import *
File "/phys/users/jfc20/.conda/envs/elegyTest/lib/python3.7/site-packages/treex/__init__.py", line 18, in <module>
from treex.nn import *
AttributeError: module 'treex.nn' has no attribute 'HaikuModule'
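One plausible reading of this AttributeError (an assumption, not confirmed in the thread) is that treex.nn lists "HaikuModule" in __all__ while the name itself is only defined when Haiku is installed: a star import fetches every name in __all__ and fails on the missing one. The sketch below reproduces that with stand-in modules registered in sys.modules; all module names are hypothetical.

```python
import sys
import types


def register_fake_nn(name: str, haiku_available: bool, conditional_all: bool) -> None:
    """Register a stand-in for a package like treex.nn in sys.modules (illustrative only)."""
    mod = types.ModuleType(name)
    mod.Linear = type("Linear", (), {})  # a layer that is always available
    if haiku_available:
        mod.HaikuModule = type("HaikuModule", (), {})
    exported = ["Linear"]
    if conditional_all:
        # Export HaikuModule only if it was actually defined
        if haiku_available:
            exported.append("HaikuModule")
    else:
        # Export it unconditionally, even when it was never defined
        exported.append("HaikuModule")
    mod.__all__ = exported
    sys.modules[name] = mod


def star_import_works(name: str) -> bool:
    """Run `from <name> import *` and report whether it raised AttributeError."""
    try:
        exec(f"from {name} import *", {})
    except AttributeError:
        return False
    return True


register_fake_nn("fake_nn_broken", haiku_available=False, conditional_all=False)
register_fake_nn("fake_nn_fixed", haiku_available=False, conditional_all=True)
print(star_import_works("fake_nn_broken"), star_import_works("fake_nn_fixed"))
```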
Maybe treex also needs a CI step that checks import without dev dependencies?
Thanks!
@jfcrenshaw again thanks for the reports, they are really useful!
The culprit this time was a change in treex made after the last fix was published, so CI could not catch it. Testing import on CI was actually added in the previous PR, like this:
https://github.com/poets-ai/elegy/blob/546c50475ac55dcbf9d8dd811c0536d1ad589f38/.github/workflows/ci_test.yml#L58-L87
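The linked workflow is the authoritative version of that check. As a rough illustration of its core idea, the sketch below imports each module in a fresh interpreter, so modules already loaded in the current process cannot hide a failure; the clean environment without dev dependencies still has to come from the CI job itself. import_succeeds and the script usage are hypothetical.

```python
import subprocess
import sys
from typing import Tuple


def import_succeeds(module: str) -> Tuple[bool, str]:
    """Import `module` in a separate interpreter and return (success, stderr)."""
    result = subprocess.run(
        [sys.executable, "-c", f"import {module}"],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0, result.stderr


if __name__ == "__main__":
    # Hypothetical CI usage: python check_imports.py elegy treex
    failures = [name for name in sys.argv[1:] if not import_succeeds(name)[0]]
    for name in failures:
        print(f"import {name} failed", file=sys.stderr)
    if failures:
        raise SystemExit(1)
```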