mj_multiRay() using MJX
Hi, I need to cast multiple rays to simulate lidars. Typically thousands of rays are cast per second, so the algorithm must be fast. Does mj_multiRay() work with MJX?
mj_multiRay() on the CPU is not fast enough; a GPU could speed up the simulation a lot.
Note that while ray casting is not yet supported in MJX (it is on our to-do list), the ray-casting functions are completely thread-safe, which can give you a significant speedup on the CPU.
In case this is still helpful, I think mj_multiRay() can now be implemented by using jax.vmap on mjx.ray(). For example, please see this simple depth camera implementation.
Batching across environments, I'm getting more than 3 million FPS on 100x100 depth images (3.3e10 ray casts per second) on an NVIDIA RTX 3060 Ti GPU. This is a simple scene; see the depth renderings and ground-truth renderings below.