
JaxToTorchWrapper error with jax 0.4.1

Open · jypark0 opened this issue 2 years ago · 0 comments

Hello, I am trying to run some experiments in PyTorch using the JaxToTorchWrapper. I'm running the default "Training in Brax with PyTorch on GPUs" notebook on a local Jupyter instance, but it raises errors.

[Screenshot of the error output]

The error occurs with jax==0.4.1 and goes away when I downgrade to an earlier version. It seems to be related to the jax.Array type, which became the default array type in 0.4.1.
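One plausible mechanism, offered as an unverified sketch: if a JAX-to-Torch conversion helper chooses its code path by array type (for example with `functools.singledispatch`) and only registers the pre-0.4.1 array class, a value of a new array class that does not subclass it falls through to the fallback and fails. The classes `LegacyDeviceArray` and `NewArray` and the function `to_torch_like` below are hypothetical stand-ins, not brax or JAX APIs, and this does not claim to reproduce brax's actual code.

```python
from functools import singledispatch


class LegacyDeviceArray:
    """Hypothetical stand-in for the array class used before jax 0.4.1."""

    def __init__(self, data):
        self.data = list(data)


class NewArray:
    """Hypothetical stand-in for the default array class from 0.4.1.
    It deliberately does NOT subclass LegacyDeviceArray."""

    def __init__(self, data):
        self.data = list(data)


@singledispatch
def to_torch_like(value):
    # Fallback: reached when no registered type matches type(value).
    raise NotImplementedError(f"Cannot convert {type(value).__name__}")


@to_torch_like.register
def _(value: LegacyDeviceArray):
    # Stand-in for a real conversion (e.g. via DLPack); returns a tagged tuple.
    return ("tensor", value.data)


def converts(value) -> bool:
    """Return True if to_torch_like handles this value without raising."""
    try:
        to_torch_like(value)
        return True
    except NotImplementedError:
        return False


# The new array type is not registered, so dispatch hits the fallback.
before = converts(NewArray([1.0, 2.0]))

# Registering the new type as well restores conversion.
to_torch_like.register(NewArray)(lambda v: ("tensor", v.data))
after = converts(NewArray([1.0, 2.0]))

print(before, after)
```

Here `before` is `False` and `after` is `True`: type-based dispatch matches on the class (and its base classes) only, so a new, unrelated array class needs its own registration.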

Environment:

  • Python 3.10.7
  • CUDA 11.8
  • jax[cuda]
  • brax==0.0.16

Thanks!

jypark0 · Dec 19 '22 02:12