add is slow on GPU
The benchmarks show that adding to a buffer is very slow on GPU (13 ms vs. 0.4 ms for Reverb or Stable Baselines, over 30x slower). Has anyone filed a bug against JAX about this?
Hello, we haven't specifically asked the JAX maintainers about this issue. However, it's important to note that Reverb and Stable Baselines do not store their memory on the GPU, so it's not really a fair comparison; it would be better to compare against the CPU times. If you find the GPU add times too slow and you're not using a fully jitted training loop, you can keep the Flashbax buffer on the CPU instead.
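A minimal sketch of keeping buffer state in host memory with plain JAX, assuming a hypothetical flat ring buffer (`init_buffer` and `add` are illustrative names, not the Flashbax API). Arrays committed to a device with `jax.device_put` make jitted functions run on that device, so both the state and the incoming timestep are placed on the CPU.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]


def init_buffer(max_length, obs_dim):
    # Hypothetical flat ring buffer: fixed-size storage plus a write pointer
    return {
        "obs": jnp.zeros((max_length, obs_dim), dtype=jnp.float32),
        "ptr": jnp.array(0, dtype=jnp.int32),
    }


@jax.jit
def add(state, obs):
    # Write one timestep at the current slot, wrapping around at max_length
    max_length = state["obs"].shape[0]
    idx = state["ptr"] % max_length
    return {"obs": state["obs"].at[idx].set(obs), "ptr": state["ptr"] + 1}


# Commit the buffer state and the new timestep to host memory;
# jit follows the committed placement and runs `add` on the CPU
state = jax.device_put(init_buffer(1000, 4), cpu)
obs = jax.device_put(jnp.ones((4,)), cpu)
state = add(state, obs)
```

Mixing arrays committed to different devices in one jitted call raises an error, which is why the timestep is moved to the CPU as well.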
Thanks for the reply. Even compared to Flashbax on TPU, it's much slower on GPU, so it might be worth filing a bug with JAX about that? I'm assuming the source data is already on the GPU when you're adding?
Yes, I believe so. I ran the benchmarks a while ago, but I'm sure I would have created the data on device.
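For anyone re-running the comparison, here is a rough sketch of timing a single-timestep add with the data already on the device, assuming the same hypothetical ring-buffer layout as a stand-in for a real buffer. It excludes compilation with a warm-up call and uses `block_until_ready()` because JAX dispatches work asynchronously; without it the timer would stop before the GPU finishes.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def add(storage, ptr, obs):
    # Hypothetical single-timestep add into a flat ring buffer
    idx = ptr % storage.shape[0]
    return storage.at[idx].set(obs), ptr + 1


def time_add(n_iters=100, max_length=10_000, obs_dim=64):
    device = jax.devices()[0]  # default backend: the GPU if one is available
    # Create the source data on the device before timing
    storage = jax.device_put(jnp.zeros((max_length, obs_dim)), device)
    ptr = jax.device_put(jnp.array(0, dtype=jnp.int32), device)
    obs = jax.device_put(jnp.ones((obs_dim,)), device)

    # Warm-up call so compilation time is not included in the measurement
    storage, ptr = add(storage, ptr, obs)
    storage.block_until_ready()

    start = time.perf_counter()
    for _ in range(n_iters):
        storage, ptr = add(storage, ptr, obs)
    # Wait for the last queued add to finish before stopping the clock
    storage.block_until_ready()
    return (time.perf_counter() - start) / n_iters * 1e3  # milliseconds per add


print(f"{time_add():.3f} ms per add")
```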
Hi, I would like to ask for more details on where we stand on this issue.
- Currently, the GPU speed for adding single timesteps is bad? Can we pinpoint where the delay happens? From this discussion, I understand that it's simply the JAX operation being used?
- For adding a batch of timesteps, we don't have these delays, right?
- Is there anything we can do to improve the situation?
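One way to narrow down where the time goes is to look at the program XLA actually compiles for a single-timestep add. The sketch below uses the same hypothetical ring-buffer `add` (not the Flashbax code) and prints the optimized HLO for the current backend, where one can check which op the `.at[].set()` update lowers to and whether large copies of the buffer appear.

```python
import jax
import jax.numpy as jnp


def add(storage, ptr, obs):
    # Hypothetical single-timestep add into a flat ring buffer
    idx = ptr % storage.shape[0]
    return storage.at[idx].set(obs)


storage = jnp.zeros((10_000, 64))
ptr = jnp.array(0, dtype=jnp.int32)
obs = jnp.ones((64,))

# Lower and compile for the default backend without running the function
compiled = jax.jit(add).lower(storage, ptr, obs).compile()

# Optimized HLO text after XLA's passes for this backend
hlo_text = compiled.as_text()
print(hlo_text)
```

Comparing this output between the CPU and GPU backends (and against a profiler trace) would help tell whether the slowdown comes from the update op itself or from surrounding work.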
- Yes, it's likely the underlying XLA code that JAX compiles to.
- Fewer delays, but it could still likely be better.
- One possible improvement is to tell JAX that our indices are unique and sorted by passing `unique_indices` and `indices_are_sorted` to our `.at[].set()` calls, over here for example. I'm not sure this would help, but the docs imply that it might. Unfortunately, I don't have time to test this right now.
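A minimal sketch of what passing these hints could look like for a batched add, assuming a hypothetical ring buffer rather than the actual Flashbax code. Consecutive slots are unique as long as the batch is no larger than the buffer, but they are only sorted if the write does not wrap past the end, so `indices_are_sorted` is exposed as a static flag here instead of being hard-coded. Passing a hint that does not actually hold can give incorrect results.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("sorted_hint",))
def add_batch(storage, ptr, batch, sorted_hint=False):
    max_length = storage.shape[0]
    batch_size = batch.shape[0]
    # Consecutive ring-buffer slots: unique as long as batch_size <= max_length
    idx = (ptr + jnp.arange(batch_size)) % max_length
    storage = storage.at[idx].set(
        batch,
        unique_indices=True,
        # Only a valid promise when this write never wraps around the buffer end
        indices_are_sorted=sorted_hint,
    )
    return storage, (ptr + batch_size) % max_length


# With max_length a multiple of batch_size and ptr starting at 0,
# every write stays inside [ptr, ptr + batch_size), so the indices are sorted
storage = jnp.zeros((8, 2))
ptr = jnp.array(0, dtype=jnp.int32)
storage, ptr = add_batch(storage, ptr, jnp.ones((4, 2)), sorted_hint=True)
storage, ptr = add_batch(storage, ptr, 2 * jnp.ones((4, 2)), sorted_hint=True)
```

Whether either hint changes the compiled scatter on GPU would still need to be measured.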