brax
TPU training colab is not working for me
The TPU training colab stopped working for me after one of the updates in recent months. It stalls when it reaches the training cell: it shows that some work is going on, but it never finishes, and no plots or trained policies are produced.
Hi, I tried this TPU training colab today, and it fails with this error:
AttributeError Traceback (most recent call last)
<ipython-input-3-9ce5fdb19302> in <module>
14
15 try:
---> 16 import brax
17 except ImportError:
18 get_ipython().system('pip install git+https://github.com/google/brax.git@main')
2 frames
/usr/local/lib/python3.9/dist-packages/brax/jumpy.py in <module>
504
505
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
507 y: jax.typing.ArrayLike) -> ndarray:
508 """Return elements chosen from `x` or `y` depending on `condition`."""
AttributeError: module 'jax' has no attribute 'typing'
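The error means the installed jax does not provide the `jax.typing` module that `brax/jumpy.py` uses in its type annotations. A minimal sketch for checking this before running `import brax`, assuming only the standard library; the helper name `jax_has_typing` is hypothetical and not part of brax or jax:

```python
import importlib
from typing import Optional


def jax_has_typing() -> Optional[bool]:
    """Return True if the installed jax exposes jax.typing.ArrayLike,
    False if it does not, or None if jax is not installed at all.
    (Hypothetical helper for diagnosing the AttributeError above.)"""
    try:
        importlib.import_module("jax")
    except ImportError:
        return None
    try:
        # brax/jumpy.py annotates with jax.typing.ArrayLike, so both the
        # submodule and that name need to exist.
        typing_mod = importlib.import_module("jax.typing")
    except ImportError:
        return False
    return hasattr(typing_mod, "ArrayLike")


status = jax_has_typing()
if status is None:
    print("jax is not installed")
elif status:
    print("jax.typing is available; this AttributeError should not occur")
else:
    print("jax.typing is missing; upgrade jax/jaxlib before importing brax")
```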
Hi @erwincoumans! I'm not able to reproduce the issue. Which jax version are you using?
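One way to answer that from inside the notebook is to ask the package metadata directly, which works even when `import brax` itself fails. A small sketch using only `importlib.metadata` from the standard library; the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "brax")):
    """Map each distribution name to its installed version string,
    or None if that distribution is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```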
I'm just running the public colab, following the links on the GitHub front page. Did you try training with a public TPU runtime? The other public colab (training using PyTorch) is also still broken; see the other issue.
data:image/s3,"s3://crabby-images/30144/301446fdb4e2382e72401090bc4310259981e0b6" alt="image"
https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb
Just tried it again, here is the output:
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-1-9ce5fdb19302> in <module>
15 try:
---> 16 import brax
17 except ImportError:
ModuleNotFoundError: No module named 'brax'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
3 frames
/usr/local/lib/python3.9/dist-packages/brax/jumpy.py in <module>
504
505
--> 506 def where(condition: jax.typing.ArrayLike, x: jax.typing.ArrayLike,
507 y: jax.typing.ArrayLike) -> ndarray:
508 """Return elements chosen from `x` or `y` depending on `condition`."""
AttributeError: module 'jax' has no attribute 'typing'
OK, thanks for the pointer! It turns out that jax>=0.4.6 is incompatible with public colab TPU runtimes (see https://stackoverflow.com/a/75734517). We're pinning the jax/jaxlib versions to >=0.4.6 now, so for the time being it's best to run in a GPU runtime until the colab issue is fixed. I've confirmed that training works on GPU in a public colab runtime.
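The workaround comes down to two checks: does the installed jax meet the >=0.4.6 pin, and is the notebook attached to a GPU rather than a TPU runtime? A rough sketch of both, assuming the standard library plus `jax.default_backend()`; the helper names are hypothetical, and the version parser is a simplification that ignores pre-release ordering (it treats `0.4.6rc1` the same as `0.4.6`):

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

MIN_JAX = (0, 4, 6)  # minimum jax version mentioned in this thread


def parse_release(ver: str) -> Tuple[int, ...]:
    """Extract the leading numeric release segment, e.g. '0.4.6.dev1' -> (0, 4, 6)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_meets_minimum(minimum: Tuple[int, ...] = MIN_JAX) -> Optional[bool]:
    """True/False for the installed jax, or None if jax is not installed."""
    try:
        return parse_release(version("jax")) >= minimum
    except PackageNotFoundError:
        return None


def current_backend() -> Optional[str]:
    """Return jax's default backend (e.g. 'cpu', 'gpu' or 'tpu'), or None without jax."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


print("jax >= 0.4.6:", jax_meets_minimum())
print("backend:", current_backend())  # expect 'gpu' on a colab GPU runtime
```

Tuples compare element by element, so `(0, 4, 10) >= (0, 4, 6)` holds while `(0, 3, 25)` does not, which avoids the pitfalls of comparing version strings directly.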