brax
'Brax Training with PyTorch on GPU' notebook fails to execute
The notebook fails with a "can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." error in pyplot, and fixing that raises another error.
Thanks for pointing this out. Both errors are addressed in a660b273f21f305b2d86e440f532a300775835f5 and the colab runs again.
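For anyone hitting the same pyplot error in their own training loop, here is a minimal sketch of the general remedy the error message points to: copy values to host memory before handing them to matplotlib. The helper name `to_host_floats` is hypothetical and not part of Brax or the notebook. It relies on `.item()`, which both `torch.Tensor` (CPU or CUDA, one element) and NumPy scalars provide.

```python
from typing import Any, Iterable, List


def to_host_floats(values: Iterable[Any]) -> List[float]:
    """Convert scalars that may live on a GPU into plain Python floats.

    Hypothetical helper: torch.Tensor (CPU or CUDA) and NumPy scalars both
    expose .item(), which copies a one-element value to host memory and
    returns a Python number. Plain ints/floats go straight through float().
    """
    host = []
    for v in values:
        if hasattr(v, "item"):
            v = v.item()  # device -> host copy for CUDA tensors
        host.append(float(v))
    return host
```

Usage would look like `plt.plot(xdata, to_host_floats(ydata))`, where `ydata` stands in for a list of per-evaluation rewards collected as tensors (the variable names are illustrative).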
Hi, I tried running the same 'PyTorch on GPU' notebook today and it fails with this message (by the way, pytinyrenderer is fixed now in 0.0.14):
```
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-1-57d223ce7d8e> in <module>
19 from brax.envs import to_torch
20 from brax.io import metrics
---> 21 from brax.training.agents.ppo import train as ppo
22 import gym
23 import matplotlib.pyplot as plt
3 frames
/usr/local/lib/python3.9/dist-packages/brax/v2/envs/env.py in <module>
21 from brax.v2 import base
22 from brax.v2.generalized import pipeline as g_pipeline
---> 23 from brax.v2.positional import pipeline as p_pipeline
24 from brax.v2.spring import pipeline as s_pipeline
25 from flax import struct
ModuleNotFoundError: No module named 'brax.v2.positional'
```
Thanks for pointing this out @erwincoumans! I've fixed the `brax.v2.positional` import.
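Before rerunning, it can help to confirm that the installed Brax build actually ships the submodule that failed to import. This is a minimal sketch using the standard library's `importlib.util.find_spec`. The helper name `module_available` is hypothetical. Note that `find_spec` imports the parent packages of a dotted name, so a missing or broken parent raises `ImportError`, which the sketch treats as "not available".

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a (possibly dotted) module name can be located.

    Hypothetical helper. importlib.util.find_spec imports the parents of a
    dotted name such as "brax.v2.positional"; if a parent is missing or
    fails to import, ImportError (incl. ModuleNotFoundError) is raised.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False
```

In a notebook cell, `module_available("brax.v2.positional")` returning `False` would indicate that the installed Brax predates the fix and needs to be reinstalled from a newer source.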
I just ran it again (using the public Colab GPU runtime) and got this error:
```
/usr/local/lib/python3.9/dist-packages/gym/envs/registration.py:440: UserWarning: WARN: The `registry.env_specs` property along with `EnvSpecTree` is deprecated. Please use `registry` directly as a dictionary instead.
logger.warn(
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
<ipython-input-5-4076a06b45b0> in <module>
21 plt.show()
22
---> 23 train(progress_fn=progress)
24
25 print(f'time to jit: {times[1] - times[0]}')
28 frames
/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
1034 # TODO(sharadmv): remove this fallback when all backends allow `compile`
1035 # to take in `host_callbacks`
-> 1036 return backend.compile(built_c, compile_options=options)
1037
1038 _ir_dump_counter = itertools.count()
XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:641) dnn != nullptr
```
Hi @erwincoumans @BenGutteridge, erikfrey recently pushed a new PyTorch colab: https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb. I just tested that it runs on GPU. Sorry this took a while!
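If the XlaRuntimeError above shows up again on a fresh runtime, one way to tell whether it comes from the runtime's JAX/GPU setup rather than from Brax is to compile something trivial first. This is a minimal sketch under that assumption; `jax_smoke_test` is a hypothetical name. It uses only `jax.default_backend()`, `jax.devices()` and `jax.jit`, and returns a report instead of raising so it can run in its own cell.

```python
def jax_smoke_test() -> dict:
    """Compile and run a trivial jitted function on JAX's default backend.

    Hypothetical diagnostic: if this alone raises the same XlaRuntimeError,
    the failure can be reproduced without Brax or PyTorch involved.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return {"jax_installed": False}

    report = {"jax_installed": True, "jax_version": jax.__version__}
    try:
        report["default_backend"] = jax.default_backend()  # "gpu" on a GPU runtime
        report["devices"] = [str(d) for d in jax.devices()]
        # jax.jit forces an XLA compile on the default backend.
        out = jax.jit(lambda x: x * 2.0)(jnp.arange(3.0))
        report["result"] = [float(v) for v in out]
        report["ok"] = True
    except Exception as exc:  # e.g. XlaRuntimeError raised during compilation
        report["ok"] = False
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report
```

Printing `jax_smoke_test()` before calling `train(...)` shows the backend and devices JAX picked up, plus any compile error, in one place.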