dreamerv3-torch
Wallclock comparison for the benchmarks
It would be helpful to have wallclock comparisons for the benchmarks you posted. I think Danijar's JAX implementation relies heavily on `jax.lax.scan` to make the imagination/rollout loops efficient.
Hi sai, have you measured the wall-clock time of this implementation? It seems much slower than the paper reports.
@sai-prasanna
Hello,
Thank you for your suggestion regarding wallclock comparisons for my benchmarks. You're right about the efficiency of Danijar's JAX implementation: its heavy use of `jax.lax.scan` to compile the imagination and rollout loops into a single fused program is a significant factor in its speed.
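To illustrate the pattern, here is a minimal sketch of a scanned imagination rollout. This is not the actual DreamerV3 world model: the linear `step` dynamics, the shapes, and the horizon are toy assumptions chosen only to show how `jax.lax.scan` turns the whole loop into one compiled XLA program.

```python
# Hypothetical sketch: a toy imagination rollout with jax.lax.scan.
# The dynamics stand in for the RSSM; they are not Danijar's code.
import jax
import jax.numpy as jnp

params = {
    "w": jnp.eye(32) * 0.9,  # toy transition weights
    "b": jnp.zeros(32),
}

def step(state, _):
    # One imagined transition: carry the latent state forward and
    # also emit it as the per-step output.
    next_state = jnp.tanh(state @ params["w"] + params["b"])
    return next_state, next_state

@jax.jit
def imagine(state, horizon=15):
    # lax.scan unrolls the horizon inside XLA: one compiled program,
    # no Python-level loop overhead at run time.
    _, trajectory = jax.lax.scan(step, state, None, length=horizon)
    return trajectory

init_state = jnp.ones((16, 32))  # batch of 16 latent states
trajectory = imagine(init_state)
print(trajectory.shape)  # (15, 16, 32)
```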
Regarding my PyTorch-based implementation, one key reason it may not match the speed of the original JAX version is the absence of a comparably efficient scan function in PyTorch. Of course, writing a custom CUDA kernel is one option for recovering that efficiency.
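For contrast, here is the same toy rollout written the way PyTorch currently requires. Again, this is a hypothetical sketch mirroring the example above, not this repo's code: the Python-level `for` loop dispatches each step separately, which is exactly the per-step overhead a native scan primitive would remove.

```python
# Hypothetical sketch: the same toy rollout without a scan primitive.
# Each iteration launches its kernels separately from Python.
import torch

w = torch.eye(32) * 0.9  # toy transition weights
b = torch.zeros(32)

def imagine(state, horizon=15):
    trajectory = []
    for _ in range(horizon):  # Python-level loop: one dispatch per step
        state = torch.tanh(state @ w + b)
        trajectory.append(state)
    return torch.stack(trajectory)

trajectory = imagine(torch.ones(16, 32))
print(trajectory.shape)  # torch.Size([15, 16, 32])
```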
Fortunately, an efficient scan function for PyTorch is in development and should become available in the near future, as discussed in PyTorch GitHub issue 50688.
Once this function is released, I plan to integrate it into our implementation. In the meantime, for insights into the current efficiency of the code, you might find this discussion useful.