mesh-transformer-jax

AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'

Open · umm-maybe opened this issue 2 years ago · 4 comments

Hello, I followed the (very much appreciated) howto_finetune.md guide, and when I ran the python device_train.py command I got the error above. The only Google search result that mentions something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515

The answer there suggests a version incompatibility between jax and jaxlib, but the fix it links to doesn't work here. Any tips or advice for working around this would be greatly appreciated!
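Since the linked answer points to a jax/jaxlib version mismatch, a useful first step is to see which versions are actually installed. The sketch below reads them from package metadata instead of importing jax, on the assumption that the import itself is what raises the AttributeError. It uses importlib.metadata, which is in the standard library from Python 3.8; on Python 3.7 the importlib_metadata backport provides the same function. The helper name installed_versions is hypothetical.

```python
# Report installed jax/jaxlib versions without importing jax.
# Assumption: `import jax` is where the pocketfft AttributeError surfaces,
# so reading package metadata is safer than importing the package.
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```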

umm-maybe, Jul 16 '22

Off the top of my head:

pip install jax==0.2.12 jaxlib==0.1.67

I can't try it right now, but that version combination should work on a TPU VM.

Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.
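To check an environment against the pins suggested in the comment above, a small helper can compare installed versions with an allowed set and build the matching pip line. This is a hypothetical sketch: SUGGESTED_PINS, mismatches and pip_command are illustrative names, and treating jaxlib 0.1.68 as compatible with jax 0.2.12 is an untested assumption drawn from that comment.

```python
# Hypothetical helper: compare installed versions to a pinned jax/jaxlib combo.
# The pins come from this thread and are NOT verified for every TPU/Python setup.
SUGGESTED_PINS = {"jax": "0.2.12", "jaxlib": ("0.1.67", "0.1.68")}


def _as_tuple(allowed):
    """Normalize a single version string or a sequence of them to a tuple."""
    return (allowed,) if isinstance(allowed, str) else tuple(allowed)


def mismatches(installed, pins=SUGGESTED_PINS):
    """Return {package: (installed_version, allowed_versions)} for off-pin packages."""
    out = {}
    for name, allowed in pins.items():
        allowed_set = _as_tuple(allowed)
        have = installed.get(name)  # None if the package is missing
        if have not in allowed_set:
            out[name] = (have, allowed_set)
    return out


def pip_command(pins=SUGGESTED_PINS):
    """Build a pip install line pinning the first allowed version of each package."""
    parts = [f"{name}=={_as_tuple(allowed)[0]}" for name, allowed in pins.items()]
    return "pip install " + " ".join(parts)
```

For example, mismatches({"jax": "0.3.14", "jaxlib": "0.1.67"}) returns {"jax": ("0.3.14", ("0.2.12",))}, and pip_command() returns "pip install jax==0.2.12 jaxlib==0.1.67", the same line suggested above.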

Ontopic, Aug 03 '22

I'm also using a TPU v2 setup and ran into this problem. Following the TPU install instructions in the JAX README worked for me.

dunstantom, Aug 16 '22

Now I'm also getting "AttributeError: module 'jax' has no attribute 'version'", or alternatively "AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'". I tried a couple of different Colab notebooks, and none of them work.

sxiii, Aug 18 '22

I fixed it by doing this right after the install dependencies section:

!pip install jaxlib==0.1.67

Then restart the runtime if prompted.

It feels fragile, though, and I don't know why it works.

musabgultekin, Oct 11 '22