dreamerv3-torch

Wallclock comparison for the benchmarks

Open sai-prasanna opened this issue 1 year ago • 2 comments

It would be helpful to have wallclock comparisons for the benchmarks you posted. I think Danijar's JAX implementation uses jax.lax.scan heavily to make the imagination/rollout loops efficient.
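For reference, the pattern I mean looks roughly like this (a toy sketch with made-up networks and shapes, not Danijar's actual code): jax.lax.scan traces the step function once and compiles the whole horizon into a single fused loop, so there is no per-step Python or dispatch overhead.

```python
# Minimal sketch of a scan-based imagination rollout (hypothetical names,
# not the actual dreamerv3 code). jax.lax.scan compiles the whole horizon
# into a single fused XLA loop instead of stepping through Python.
import jax
import jax.numpy as jnp


def imagine(policy_params, dynamics_params, init_state, horizon=15):
    def step(state, _):
        # Stand-ins for the real actor and latent dynamics networks.
        action = jnp.tanh(state @ policy_params)
        next_state = jnp.tanh(state @ dynamics_params + action.sum(-1, keepdims=True))
        return next_state, (next_state, action)

    _, (states, actions) = jax.lax.scan(step, init_state, None, length=horizon)
    return states, actions


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
policy_params = jax.random.normal(k1, (32, 4))
dynamics_params = jax.random.normal(k2, (32, 32))
init_state = jnp.zeros((16, 32))  # batch of 16 latent states

states, actions = jax.jit(imagine)(policy_params, dynamics_params, init_state)
print(states.shape, actions.shape)  # (15, 16, 32) (15, 16, 4)
```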

sai-prasanna avatar Nov 22 '23 21:11 sai-prasanna

Hi sai, have you tested the wall clock time of this implementation? It seems much slower than the paper reports.

Caixy1113 avatar Nov 28 '23 10:11 Caixy1113

@sai-prasanna

Hello,

Thank you for your suggestion regarding wallclock comparisons for my benchmarks. You're absolutely right about the efficiency of Danijar's JAX implementation, particularly its use of jax.lax.scan to optimize the imagination and rollout loops. This is indeed a significant factor in the performance of that implementation.

Regarding my PyTorch-based implementation, one of the key reasons it might not match the speed of the original JAX version is the absence of a directly comparable, efficient scan function in PyTorch. Of course, writing a raw CUDA kernel is one option for optimizing efficiency.
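To illustrate, here is a simplified sketch of the eager Python-level loop PyTorch pushes us toward (hypothetical toy modules, not this repository's actual classes): every step goes through the interpreter and launches its kernels separately, which is where the overhead relative to a fused scan comes from.

```python
# Simplified sketch of an eager PyTorch imagination loop (hypothetical
# modules, not this repository's actual classes). Each iteration runs
# through the Python interpreter and dispatches its kernels one by one.
import torch
import torch.nn as nn


class ToyDynamics(nn.Module):
    def __init__(self, latent_dim=32, action_dim=4):
        super().__init__()
        self.net = nn.Linear(latent_dim + action_dim, latent_dim)

    def forward(self, state, action):
        return torch.tanh(self.net(torch.cat([state, action], dim=-1)))


def imagine(dynamics, actor, init_state, horizon=15):
    states, actions = [], []
    state = init_state
    for _ in range(horizon):  # plain Python loop, one step at a time
        action = torch.tanh(actor(state))
        state = dynamics(state, action)
        states.append(state)
        actions.append(action)
    return torch.stack(states), torch.stack(actions)


dynamics = ToyDynamics()
actor = nn.Linear(32, 4)
init_state = torch.zeros(16, 32)
with torch.no_grad():
    states, actions = imagine(dynamics, actor, init_state)
print(states.shape, actions.shape)  # torch.Size([15, 16, 32]) torch.Size([15, 16, 4])
```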

Fortunately, however, an efficient scan function is in development and seems likely to become available in the near future, as discussed in PyTorch GitHub issue #50688.

Once this function is released, I plan to integrate it into the implementation. In the meantime, for insights into the current efficiency of the code, you might find this discussion useful.
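For anyone who wants to check the wall clock on their own hardware, a simple way to time the rollout fairly is to synchronize the GPU before reading the clock. The helper below is just a sketch (reusing the toy names from the example above), not part of the repository or its training script.

```python
# Rough wallclock measurement of an imagination rollout (a sketch, not part
# of the repository). torch.cuda.synchronize ensures the timer measures the
# actual GPU work, not just the asynchronous kernel launches.
import time
import torch


def time_rollout(fn, warmup=3, iters=20):
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # warm up kernels and the allocator
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


# Example usage with the toy modules from the previous sketch:
# avg_s = time_rollout(lambda: imagine(dynamics, actor, init_state))
# print(f"avg imagination rollout: {avg_s * 1e3:.2f} ms")
```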

NM512 avatar Jan 08 '24 14:01 NM512