
Another implementation of PScan operation

Open maximzubkov opened this issue 1 year ago • 9 comments

Hello, @dvruette! Thank you for releasing your work on BarrelRec, it's very interesting. I have been working on a similar topic and recently released my O(T log T) implementation of PScan, which uses the Fast Fourier Transform under the hood. It might also be interesting to check this work by @fheinsen

maximzubkov avatar Dec 29 '23 00:12 maximzubkov

Awesome, thanks for the pointer! Have you compared the speed (forward/backward) of the reference version against the FFT implementation? Will definitely try it out.

dvruette avatar Dec 29 '23 00:12 dvruette

Not yet, unfortunately. I was looking for a codebase to benchmark the performance in a real-life setting, and that's how I discovered your repository, ahah. I hope to work on it in the next 3-4 days.

maximzubkov avatar Dec 29 '23 00:12 maximzubkov

Haha perfect! I’ll definitely try it out and report performance, but it would be nice to know how much of a difference it makes in isolation as well.

dvruette avatar Dec 29 '23 08:12 dvruette

Perfect, thanks! I'll do the isolated benchmarking

maximzubkov avatar Dec 29 '23 08:12 maximzubkov

Unfortunately, it seems like the FFT version is even slower, at least in its current form. I'm getting 1.5 it/s with pscan_fft_efficient and 3.8 it/s with the current pscan.

I haven't done any profiling yet, so I'm not sure whether it's the forward or the backward pass that's slow (or both), but FWIW the current pscan uses a custom backward pass for improved speed.

dvruette avatar Dec 29 '23 14:12 dvruette

I have a few ideas on how to make it more efficient. For instance, the tensors W = UA and V = -UA + A_log can be precomputed when the layer is initialized, so they are not recomputed on every forward pass. Could you please let me join your PR to run some experiments? I might have time later tonight.

maximzubkov avatar Dec 29 '23 14:12 maximzubkov

Which PR are you referring to? I usually just push to master, have uploaded the pscan_fft implementation now. Feel free to fork and open a PR yourself!

dvruette avatar Dec 29 '23 14:12 dvruette

Sorry for the confusion, I thought you were working in a private branch. It doesn't matter, though: I'll implement a faster version soon, thank you!

maximzubkov avatar Dec 29 '23 14:12 maximzubkov

Hey, @dvruette, Happy New Year! As promised, I updated the FFT implementation and added another one based on the cumsum operation, which turned out to be significantly more efficient than the FFT-based one. See the PR for more details. I also added benchmarks of the different approaches across sequence lengths $T$.

maximzubkov avatar Jan 02 '24 16:01 maximzubkov
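
For readers following along, here is a minimal sketch of how a cumsum-based scan can work. It is not the code from the PR. It assumes that PScan here means the first-order linear recurrence x[t] = a[t] * x[t-1] + b[t] with a zero initial state, and that every a[t] is strictly positive so its logarithm exists. The function names `scan_reference` and `scan_cumsum` are hypothetical, and plain Python lists are used to keep the example dependency-free; with tensors, the two running sums would typically be calls to `torch.cumsum` along the time dimension.

```python
import math
from itertools import accumulate


def scan_reference(a, b):
    """Sequential scan: x[t] = a[t] * x[t-1] + b[t], starting from x = 0."""
    x, out = 0.0, []
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        out.append(x)
    return out


def scan_cumsum(a, b):
    """Same recurrence computed with two cumulative sums.

    Assumption (not from the thread): every a[t] > 0.
    Uses x[t] = A[t] * sum_{s <= t} b[s] / A[s], where A[t] = a[0] * ... * a[t].
    """
    # Running sum of log a[t] gives log A[t].
    log_A = list(accumulate(math.log(a_t) for a_t in a))
    # Running sum of b[s] / A[s], with the division done as exp(-log A[s]).
    scaled = accumulate(b_t * math.exp(-l) for b_t, l in zip(b, log_A))
    # Rescale each partial sum by A[t].
    return [math.exp(l) * s for l, s in zip(log_A, scaled)]


a = [0.9, 0.5, 0.8, 0.7]
b = [1.0, -2.0, 0.5, 3.0]
ref = scan_reference(a, b)
fast = scan_cumsum(a, b)
```

Both functions should agree up to floating-point error on short inputs like this one. For long sequences the intermediate `exp(-log_A)` factors can overflow or underflow, so a practical version would likely need extra stabilization in log space.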