Unexpected PJRT_Plugin_Attributes_Args size: expected 32, got 24.
Hello @alvarosg,
I've come across an error trying to run the graphcast_demo.ipynb notebook on a TPU v5litepod-1.
The following error is triggered when the rollout is performed, likely when the model is loaded or when the forward operator is called:
pjrt_c_api_helpers.cc:258: Unexpected PJRT_Plugin_Attributes_Args size: expected 32, got 24. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.75
This error ultimately causes an `Aborted (core dumped)` crash, which stops the rollout.
I suspect my TPU is not compatible with the versions of the `jax` or `libtpu` libraries installed in my Python venv.
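One way to check this suspicion is to print the versions that are actually installed in the venv. Below is a minimal sketch that uses only the standard library's `importlib.metadata`. The distribution names `jax`, `jaxlib` and `libtpu` are assumptions: older setups may ship the TPU runtime under a different package name, such as `libtpu-nightly`. The helper name `installed_versions` is hypothetical.

```python
from importlib import metadata


def installed_versions(dists=("jax", "jaxlib", "libtpu")):
    """Return {distribution name: installed version string, or None if absent}."""
    versions = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If `jax` and `jaxlib` report different versions, or `libtpu` is missing or far newer, the plugin/framework mismatch in the log above is a likely cause.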
Did you ever encounter this problem?
Thank you very much
Hi,
Found the bug.
It apparently comes from two different sources:
- A requirements mismatch between `jax` and the underlying `libtpu` and `jaxlib`:
  - The solution is to change the requirement from `jax` to `jax[tpu]`, so that `libtpu` is installed at a version aligned with `jax`. This should be mentioned in the README.md for TPU users.
- Breaking changes introduced in `jax-v0.7.0` (https://github.com/jax-ml/jax/releases/tag/jax-v0.7.0):
  - `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  - This change throws an error when the `autoregressive` module is imported. The most recent version of `jax` that runs `graphcast_demo.ipynb` correctly seems to be `jax[tpu]==0.6.2`. The `jax` version in the requirements should be lowered to this version until a fix that works with `jax >= v0.7.0` is available.
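The two workarounds above can be sketched in Python. This is a minimal sketch under stated assumptions. `parse_version`, `jax_needs_pin`, `out_info_type` and `KNOWN_GOOD_JAX` are hypothetical helper names. The 0.6.2 cut-off comes only from the report in this thread. The `OutInfo` fallback assumes `jax.ShapeDtypeStruct` can stand in wherever `jax.stages.OutInfo` was used, and that assumption has not been checked against graphcast's `autoregressive` module.

```python
import re

# Last jax version reported in this thread to run graphcast_demo.ipynb on TPU.
KNOWN_GOOD_JAX = (0, 6, 2)


def parse_version(version):
    """Turn '0.6.2' into (0, 6, 2); non-numeric suffixes such as 'rc1' are ignored."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def jax_needs_pin(version):
    """True if this jax is newer than the last version known to work here."""
    return parse_version(version) > KNOWN_GOOD_JAX


def out_info_type(jax_module):
    """Return jax.stages.OutInfo if it exists, else jax.ShapeDtypeStruct.

    Per the jax 0.7.0 release notes, OutInfo was replaced by ShapeDtypeStruct.
    """
    stages = getattr(jax_module, "stages", None)
    out_info = getattr(stages, "OutInfo", None)
    return out_info if out_info is not None else jax_module.ShapeDtypeStruct
```

Usage (assumed): after `import jax`, call `jax_needs_pin(jax.__version__)` to decide whether to fall back to `jax[tpu]==0.6.2` in the requirements, or call `OutInfo = out_info_type(jax)` to pick the type without hard-coding either name. Any upstream fix may take a different approach.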
Hey,
Apologies for the delayed response. I'm glad you found a solution.
Could you confirm whether applying this fix https://github.com/google-deepmind/graphcast/commit/7077d40a36db6541e3ed72ccaed1c0d202fa6014 in graphcast_demo.ipynb also solves the issue? If so, we will make the change there as well.
Best,
Andrew