'Brax Training with PyTorch on GPU' notebook fails to execute

Open BenGutteridge opened this issue 3 years ago • 1 comments

The notebook fails with a "can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." error in pyplot, and fixing that raises another error.

BenGutteridge, Apr 07 '22 08:04
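
For context on the first error: matplotlib converts its inputs to NumPy arrays, and a tensor that lives on `cuda:0` cannot be converted directly, as the message says. Below is a minimal sketch of the usual workaround. The helper names `to_host` and `to_host_all` are hypothetical, the code relies only on the `detach()`, `cpu()` and `numpy()` tensor methods, and it is not necessarily the change that was made to the notebook.

```python
def to_host(value):
    """Return a NumPy copy of a tensor-like value, or the value unchanged."""
    # torch.Tensor provides detach(), cpu() and numpy(); plain Python
    # numbers and lists do not, so they are passed through untouched.
    if all(hasattr(value, name) for name in ("detach", "cpu", "numpy")):
        # detach() drops autograd tracking, cpu() copies to host memory,
        # numpy() then exposes the host copy as an array.
        return value.detach().cpu().numpy()
    return value


def to_host_all(values):
    """Apply to_host to every element of a sequence of logged values."""
    return [to_host(v) for v in values]
```

In a plotting callback, this would be applied to whatever data is handed to pyplot, for example `plt.plot(to_host_all(xdata), to_host_all(ydata))`, where `xdata` and `ydata` stand in for the logged values.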

Thanks for pointing this out. Both errors are addressed in a660b273f21f305b2d86e440f532a300775835f5 and the colab runs again.

erikfrey, May 06 '22 22:05

Hi, I tried running the same 'PyTorch on GPU' notebook today and it fails with this message (btw, pytinyrenderer is fixed now in 0.0.14):


---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-1-57d223ce7d8e> in <module>
     19 from brax.envs import to_torch
     20 from brax.io import metrics
---> 21 from brax.training.agents.ppo import train as ppo
     22 import gym
     23 import matplotlib.pyplot as plt

3 frames
/usr/local/lib/python3.9/dist-packages/brax/v2/envs/env.py in <module>
     21 from brax.v2 import base
     22 from brax.v2.generalized import pipeline as g_pipeline
---> 23 from brax.v2.positional import pipeline as p_pipeline
     24 from brax.v2.spring import pipeline as s_pipeline
     25 from flax import struct

ModuleNotFoundError: No module named 'brax.v2.positional'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

erwincoumans, Mar 20 '23 19:03
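
A `ModuleNotFoundError` for a submodule such as `brax.v2.positional` usually means the installed package does not ship that submodule, even though the parent package imports fine. The sketch below shows one way to check this before running the rest of a notebook. The helper name `module_available` is hypothetical; `importlib.util.find_spec` is part of the standard library, and note that it imports the parent packages of the name it is given.

```python
import importlib.util


def module_available(name):
    """Return True if `name` can be located by the import system."""
    try:
        # find_spec returns None when the parent package exists but the
        # submodule does not.
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package in the dotted name is itself missing.
        return False
```

Here `module_available("brax.v2.positional")` returning `False` would point at the installed brax version rather than at the notebook code.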

Thanks for pointing this out @erwincoumans! I've fixed the brax.v2.positional import.

btaba, Mar 20 '23 22:03

I just ran it again (using Colab public GPU runtime) and got this error:

/usr/local/lib/python3.9/dist-packages/gym/envs/registration.py:440: UserWarning: WARN: The `registry.env_specs` property along with `EnvSpecTree` is deprecated. Please use `registry` directly as a dictionary instead.
  logger.warn(
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-5-4076a06b45b0> in <module>
     21   plt.show()
     22 
---> 23 train(progress_fn=progress)
     24 
     25 print(f'time to jit: {times[1] - times[0]}')

28 frames
/usr/local/lib/python3.9/dist-packages/jax/_src/dispatch.py in backend_compile(backend, built_c, options, host_callbacks)
   1034   # TODO(sharadmv): remove this fallback when all backends allow `compile`
   1035   # to take in `host_callbacks`
-> 1036   return backend.compile(built_c, compile_options=options)
   1037 
   1038 _ir_dump_counter = itertools.count()

XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:641) dnn != nullptr

erwincoumans, Mar 21 '23 15:03
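
The `dnn != nullptr` check appears to fail when XLA cannot obtain cuDNN support on the GPU. One possible cause when PyTorch and JAX share a device, offered here as an assumption rather than a diagnosis of this report, is that one framework has already claimed most of the GPU memory; a mismatched CUDA or cuDNN install is another. The sketch below sets JAX's documented `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables. The helper name `configure_jax_gpu_memory` is hypothetical.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set JAX GPU allocator options; call before the first `import jax`."""
    # With preallocation disabled, JAX allocates GPU memory on demand
    # instead of reserving a large share of it at startup.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory JAX may reserve when preallocating.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)


# Assumed usage: run this in the first notebook cell, before importing jax or brax.
configure_jax_gpu_memory(preallocate=False)
```

These settings only take effect if they are in place before JAX initializes its GPU backend, so a runtime restart may be needed after changing them.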

Hi @erwincoumans @BenGutteridge, erikfrey recently pushed out a new PyTorch colab: https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb I just tested that it runs on GPU. Sorry this took a while!

btaba, May 02 '23 18:05