graphcast
JAX cannot access the GPU when generating predictions
When generating predictions, this warning appears and JAX runs on the CPU instead:
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I'm using an A100 runtime on Colab with the stock notebook, except for the modifications mentioned here.
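One way to confirm the CPU fallback from inside the notebook, as a minimal sketch: `jax.default_backend()` and `jax.devices()` are standard JAX calls, while `is_accelerated` and `report_jax_backend` are hypothetical helper names. Setting `TF_CPP_MIN_LOG_LEVEL=0`, as the warning suggests, is assumed to take effect only if it happens before JAX is first imported in the process (for example, right after a runtime restart).

```python
import os


def is_accelerated(backend: str) -> bool:
    """Return True if JAX chose a GPU or TPU backend instead of the CPU fallback."""
    return backend in ("gpu", "tpu")


def report_jax_backend() -> None:
    """Print the backend and devices JAX is using (hypothetical helper)."""
    # Ask XLA for more verbose logs, as the warning suggests. Assumed to have
    # no effect if jax was already imported earlier in this process.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
    import jax  # imported lazily so the environment variable is set first

    backend = jax.default_backend()  # "cpu", "gpu" or "tpu"
    print("JAX backend:", backend)
    print("Devices:", jax.devices())
    if not is_accelerated(backend):
        print("JAX fell back to CPU; check the CUDA/cuDNN/jaxlib installation.")
```

Call `report_jax_backend()` in a cell; on a working A100 runtime the backend should be reported as `gpu` and the device list should include the A100.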
Any luck on this? I'm having a similar issue.
Maybe your CUDA and cuDNN versions don't match the jaxlib version? You can find the right versions here. Note that you need to uninstall the old jax and jaxlib first.
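A minimal sketch of the two steps suggested above, in Python, assuming a pip-managed environment such as Colab. The CUDA 12 package names (`jax-cuda12-plugin`, `nvidia-cudnn-cu12`) and the `jax[cuda12]` install target are assumptions that fit recent JAX releases; substitute whatever the JAX installation guide lists for your CUDA/cuDNN versions. `installed_version`, `installed_versions` and `reinstall_commands` are hypothetical helper names.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Optional

# Assumed package names for a pip-installed CUDA 12 JAX; adjust for your setup.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin", "nvidia-cudnn-cu12")


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version (None if not installed)."""
    return {name: installed_version(name) for name in packages}


def reinstall_commands(install_spec: str = "jax[cuda12]") -> List[List[str]]:
    """Build pip commands that remove the old jax/jaxlib before reinstalling.

    install_spec is an assumption; use the variant matching your CUDA/cuDNN.
    """
    pip = [sys.executable, "-m", "pip"]
    return [
        pip + ["uninstall", "-y", "jax", "jaxlib"],  # remove the old versions first
        pip + ["install", "-U", install_spec],       # then install the matching build
    ]
```

In a Colab cell, print `installed_versions()` to compare against the supported versions, then run each command from `reinstall_commands()` with `subprocess.run(cmd, check=True)` and restart the runtime so the new jaxlib is actually loaded.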