mesh-transformer-jax
AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'
Hello, I have followed the (very much appreciated) howto_finetune.md guide and, upon attempting to run the magic python device_train.py command, received the error noted above. The only Google search result that seems to mention something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515
The answer to that question seems to imply it has to do with a version incompatibility between jax and jaxlib, but the solution they link to doesn't work here. Any tips or advice for working around this would be greatly appreciated!
Off the top of my head: pip install jax==0.2.12 jaxlib==0.1.67
I can't try it right now, but that version combination should work on TPU-VM.
Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.
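Before running device_train.py, it can help to confirm which jax/jaxlib versions are actually installed. Below is a minimal sketch, assuming the pinned pair suggested above (jax 0.2.12, jaxlib 0.1.67). The helper names `installed_version` and `pin_mismatches` are hypothetical. Version lookup uses the standard library's importlib.metadata, falling back to setuptools' pkg_resources on Python 3.7.

```python
"""Report whether the installed jax/jaxlib match a pinned combination (sketch)."""
from typing import Callable, Dict, Optional

try:  # Python 3.8+
    from importlib.metadata import PackageNotFoundError, version as _dist_version
except ImportError:  # Python 3.7: fall back to setuptools' pkg_resources
    import pkg_resources

    class PackageNotFoundError(Exception):
        """Raised when a distribution is not installed."""

    def _dist_version(name: str) -> str:
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound as exc:
            raise PackageNotFoundError(name) from exc


# Assumed pin, taken from the suggestion in this thread; edit for your setup
# (e.g. jaxlib 0.1.68 on some v2 configurations).
PINNED = {"jax": "0.2.12", "jaxlib": "0.1.67"}


def installed_version(name: str) -> Optional[str]:
    """Return the installed distribution's version string, or None if absent."""
    try:
        return _dist_version(name)
    except PackageNotFoundError:
        return None


def pin_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each pinned package whose installed version differs to what was found."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    bad = pin_mismatches(PINNED)
    if bad:
        for name, found in bad.items():
            print(f"{name}: installed {found}, expected {PINNED[name]}")
        print(f"Try: pip install jax=={PINNED['jax']} jaxlib=={PINNED['jaxlib']}")
    else:
        print("jax/jaxlib match the pinned versions")
```

Run it as a script (or paste it into a notebook cell) after installing. It prints every package whose installed version differs from the pin, with `None` meaning the package is not installed at all.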
I'm also using a TPU v2 setup and ran into this problem. I used the JAX TPU install instructions from their README and it worked for me.
Now I'm also getting "AttributeError: module 'jax' has no attribute 'version'", or alternatively:
AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'
I tried a couple of different Colab notebooks, and none of them work.
I fixed it by running this right after the install-dependencies section:
!pip install jaxlib==0.1.67
Then restart the runtime if Colab asks you to.
It feels very fragile, though, and I don't know why.
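One plausible reason the restart is needed: Python caches imported modules in sys.modules, and a pip install does not replace modules the running process has already imported, so an old jax/jaxlib stays loaded until the runtime restarts. The hypothetical helper below (the name `already_imported` is an assumption, not part of any library) checks this from a notebook cell.

```python
import sys
from typing import Iterable, List


def already_imported(packages: Iterable[str] = ("jax", "jaxlib")) -> List[str]:
    """Return the names from `packages` that this process has already imported.

    Anything listed was loaded before the latest `pip install`, so the running
    process keeps using that old copy until the runtime is restarted.
    """
    return [name for name in packages if name in sys.modules]


if __name__ == "__main__":
    stale = already_imported()
    if stale:
        print("Already imported; restart the runtime to load the reinstalled versions:",
              ", ".join(stale))
    else:
        print("jax/jaxlib not imported yet; the installed versions will load on first import.")
```

Running it right after `!pip install jaxlib==0.1.67` shows whether the current process still holds the previously loaded packages.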