TPU/GPU visualization for learning from pixels

Open erikfrey opened this issue 4 years ago • 6 comments
Brax's rendering is CPU-only at the moment. For agents with vision, it would be useful for rendering to happen on the accelerator so that training can stay fast.

Vispy may be useful, but a fast native JAX renderer would be even better for on-device data compatibility.
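To make the idea concrete, here is a minimal sketch of what accelerator-side rendering could look like in JAX. It is not Brax code: `render_depth`, its parameters, and the orthographic single-sphere scene are all assumptions chosen for illustration. The point is only that a renderer written in `jax.numpy` can be batched with `jax.vmap` and compiled with `jax.jit`, so the images never leave the device.

```python
import jax
import jax.numpy as jnp


def render_depth(center, radius=0.5, size=32):
    """Orthographic depth image of one sphere (hypothetical toy renderer).

    Rays start on the z = 0 plane and travel along +z.
    """
    # Pixel grid covering [-1, 1] x [-1, 1] in the image plane.
    xs = jnp.linspace(-1.0, 1.0, size)
    px, py = jnp.meshgrid(xs, xs, indexing="xy")
    # Squared in-plane distance from each pixel's ray to the sphere centre.
    d2 = (px - center[0]) ** 2 + (py - center[1]) ** 2
    hit = d2 <= radius ** 2
    # Depth of the nearest intersection along the ray; 0 where the ray misses.
    depth = center[2] - jnp.sqrt(jnp.maximum(radius ** 2 - d2, 0.0))
    return jnp.where(hit, depth, 0.0)


# Batch over environments and compile once; inputs and outputs stay on device.
render_batch = jax.jit(jax.vmap(render_depth))

centers = jnp.array([[0.0, 0.0, 2.0], [0.3, -0.2, 3.0]])
images = render_batch(centers)  # shape (2, 32, 32)
```

A real renderer would need meshes, a camera model, and shading, but the same `vmap` plus `jit` pattern should carry over.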

erikfrey • Nov 10 '21 23:11

Thanks for the great simulator. Any update on TPU/GPU visualization to enable learning from pixels?

masud99r • Jan 19 '22 18:01

Hi! I would be happy to try and help. Any ideas where to start? Could this or this be of any use?

I also like the idea of having a native JAX solution, as it would potentially enable differentiable rendering (analytical policy gradients from images, anyone?).
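As a rough illustration of what differentiable rendering buys you, here is a toy sketch. Everything in it is an assumption made for the example: `soft_silhouette`, the sigmoid-softened disc, and the learning rate are not from Brax or any existing renderer. Because the image is a smooth function of the object position, `jax.grad` can push an image-space loss back to that position.

```python
import jax
import jax.numpy as jnp


def soft_silhouette(center, radius=0.4, size=32, sharpness=20.0):
    """Soft-edged silhouette of a disc, so pixels vary smoothly with `center`."""
    xs = jnp.linspace(-1.0, 1.0, size)
    px, py = jnp.meshgrid(xs, xs, indexing="xy")
    # Small epsilon keeps the sqrt gradient finite at the disc centre.
    dist = jnp.sqrt((px - center[0]) ** 2 + (py - center[1]) ** 2 + 1e-8)
    return jax.nn.sigmoid(sharpness * (radius - dist))


def image_loss(center, target):
    """Mean squared pixel error between the rendered and target images."""
    return jnp.mean((soft_silhouette(center) - target) ** 2)


target = soft_silhouette(jnp.array([0.3, 0.0]))
grad_fn = jax.jit(jax.grad(image_loss))  # gradient w.r.t. `center`

# Plain gradient descent on the object position, driven only by pixels.
center = jnp.array([0.0, 0.0])
for _ in range(100):
    center = center - 0.3 * grad_fn(center, target)
```

With a hard-edged rasterizer the same gradient would be zero almost everywhere, which is why the sketch softens the edge.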

yardenas • Apr 15 '22 08:04

Hi @yardenas and @masud99r - JAX rendering is still an area of active interest for us! One of those libraries might be a good starting point, although you'll also need to generate meshes from our brax primitives like capsules.

You can find examples of how to do that in our current CPU-only renderer code: https://github.com/erwincoumans/tinyrenderer

If you get on-device rendering going via JAX, please share a Colab with us so we can check it out!

erikfrey • Apr 19 '22 04:04

Thanks @erikfrey for the hints!

So, just trying to get a better understanding: basically we would need to implement functions like create_capsule (e.g. here and its implementation here), which take in the radius, half_height, etc., but written in JAX instead of C++, right?
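For reference, this is roughly what I have in mind, as a minimal sketch only. The name `create_capsule_jax`, the `n_lat`/`n_lon` resolution parameters, and the z-aligned layout are my assumptions and are not taken from the tinyrenderer signature. It builds two hemispherical caps offset by `half_height` and joins neighbouring rings with triangles.

```python
import jax
import jax.numpy as jnp


def create_capsule_jax(radius, half_height, n_lat=8, n_lon=16):
    """Vertices (V, 3) and triangle indices (F, 3) of a z-aligned capsule."""
    # Polar angles of the two caps: pole -> equator, then equator -> pole.
    top = jnp.linspace(0.0, jnp.pi / 2, n_lat + 1)
    bottom = jnp.linspace(jnp.pi / 2, jnp.pi, n_lat + 1)
    theta = jnp.concatenate([top, bottom])
    # Shift the caps apart so a cylinder of height 2 * half_height joins them.
    sign = jnp.concatenate([jnp.ones(n_lat + 1), -jnp.ones(n_lat + 1)])
    z_ring = radius * jnp.cos(theta) + half_height * sign
    phi = jnp.linspace(0.0, 2 * jnp.pi, n_lon, endpoint=False)

    ring_r = radius * jnp.sin(theta)[:, None]            # (rings, 1)
    x = ring_r * jnp.cos(phi)[None, :]                   # (rings, n_lon)
    y = ring_r * jnp.sin(phi)[None, :]
    z = z_ring[:, None] * jnp.ones_like(x)
    vertices = jnp.stack([x, y, z], axis=-1).reshape(-1, 3)

    # Two triangles per quad between ring i and ring i + 1.
    # Quads touching a pole are degenerate, which is acceptable for a sketch.
    n_rings = 2 * (n_lat + 1)
    i = jnp.arange(n_rings - 1)[:, None]
    j = jnp.arange(n_lon)[None, :]
    j_next = (j + 1) % n_lon
    a, b = i * n_lon + j, i * n_lon + j_next
    c, d = (i + 1) * n_lon + j, (i + 1) * n_lon + j_next
    faces = jnp.concatenate([
        jnp.stack([a, c, b], axis=-1).reshape(-1, 3),
        jnp.stack([b, c, d], axis=-1).reshape(-1, 3),
    ])
    return vertices, faces


vertices, faces = create_capsule_jax(radius=0.1, half_height=0.3)
```

Since the mesh resolution fixes the array shapes, `n_lat` and `n_lon` would have to be static under `jax.jit` (e.g. via `static_argnames`), while `radius` and `half_height` can be traced.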

yardenas • Apr 19 '22 09:04

> Brax's rendering is CPU-only at the moment. For agents with vision, it would be useful for rendering to happen on the accelerator so that training can stay fast.
>
> Vispy may be useful, but a fast native JAX renderer would be even better for on-device data compatibility.

I am currently looking in this direction for my personal projects. I wrote up how one might do this efficiently in an issue over at VisPy; it might be interesting here too.

simon-bachhuber • Jan 31 '23 22:01

Hi, I have created a pure JAX implementation and an adapter layer mimicking the behaviour of the existing pytinyrenderer. The code will be released under the Apache-2.0 License here. I will be working on a PR to replace the current visualisation code in Brax soon.

A working example can be found under the root directory, e.g., this. A Colab is also available now, which reproduces the existing pytinyrenderer examples exactly.

Update: I have opened a draft PR #367. There are still minor issues (the plane may not be rendered correctly, and performance needs work), but it is working now. I will address the plane rendering issue first, then performance. The rendering issue is now fixed, and I'll be working on the performance improvements next.

JoeyTeng • May 29 '23 21:05