
mj_multiRay() using MJX

Open stm32f303ret6 opened this issue 1 year ago • 2 comments

Hi, I need to cast multiple rays to simulate lidars. Typically thousands of rays are cast per second, so the algorithm must be fast. Does mj_multiRay() work with MJX?

mj_multiRay() on the CPU is not fast enough; the GPU could speed up the simulation a lot.

stm32f303ret6, Nov 16 '23 03:11

Note that while ray casting is not yet supported in MJX (it is on our to-do list), MuJoCo's ray casting functions are completely thread-safe, so casting rays from multiple threads can give you a significant speedup on CPU.

yuvaltassa, Nov 22 '23 18:11

In case this is still helpful, I think mj_multiRay() can now be implemented by using jax.vmap on mjx.ray(). For example, please see this simple depth camera implementation.

Batching across environments, I'm getting more than 3 million FPS on 100x100 depth images, or about 3.3e10 ray casts per second, on an NVIDIA RTX 3060 Ti GPU. This is in a simple setting; see the depth renderings and ground-truth renderings below.

[Figures: MJX depth rendering (mjx_drendering); MuJoCo ground-truth rendering (mj_rendering)]

Andrew-Luo1, Mar 10 '24 08:03
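As a quick consistency check on the throughput figures in the last comment: a 100x100 depth image is 10^4 rays, so the reported ray rate corresponds to the reported frame rate.

$$
\frac{3.3 \times 10^{10}\ \text{rays/s}}{100 \times 100\ \text{rays/image}} = 3.3 \times 10^{6}\ \text{images/s}
$$

That is, a little over 3 million depth images per second, summed across all batched environments.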