JaxToTorchWrapper error with jax 0.4.1
Hello,
I am trying to run some experiments with PyTorch using the JaxToTorchWrapper.
I'm running the default "Training in Brax with PyTorch on GPUs" notebook on a local Jupyter instance, but it fails with errors.
The error occurs with jax==0.4.1 and goes away when I downgrade to an earlier version.
It seems to be related to the new jax.Array type, which became the default array type in jax 0.4.1.
Environment:
- Python 3.10.7
- Cuda 11.8
- jax[cuda]
- brax==0.0.16
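The jax[cuda] entry in the environment list has no version pin. A small stdlib-only helper like the one below can print the exact installed versions for a report like this. It is only a sketch: the helper name `installed_versions` is hypothetical, and the set of packages it queries (jax, jaxlib, brax, torch) is an assumption about what is relevant to this error.

```python
import sys
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "brax", "torch")):
    """Return {package: version string, or None if it is not installed}.

    Hypothetical helper; the default package list is an assumption.
    """
    versions = {}
    for name in packages:
        try:
            # Reads the version from the installed distribution's metadata.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print lines in the same "- name==version" style as the environment list.
    print(f"- Python {sys.version.split()[0]}")
    for name, version in installed_versions().items():
        print(f"- {name}=={version}" if version else f"- {name}: not installed")
```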
Thanks!