Another implementation of the PScan operation
Hello, @dvruette! Thank you for releasing your work on BarrelRec, it's very interesting. I have been working on a similar topic and recently released my O(T log T) implementation of PScan, which uses the Fast Fourier Transform under the hood. It might also be interesting to check this work by @fheinsen.
Awesome, thanks for the pointer! Have you tried comparing the speed (forward/backward) of the reference version against the FFT implementation? Will definitely try it out.
Not yet, unfortunately. I was looking for a codebase to benchmark the performance in a real-life setting, and that's how I discovered your repository, haha. I hope to work on it in the next 3-4 days.
Haha perfect! I’ll definitely try it out and report performance, but it would be nice to know how much of a difference it makes in isolation as well.
Perfect, thanks! I'll do the isolated benchmarking
Unfortunately, it seems like the FFT version is even slower, at least in its current form. I'm getting 1.5 it/s with pscan_fft_efficient and 3.8 it/s with the current pscan.
I haven't done any profiling yet, so I'm not sure whether it's the forward or the backward pass that's slow (or both), but FWIW the current pscan uses a custom backward pass for improved speed.
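A minimal sketch of what timing one function in isolation could look like, using only the Python standard library. The names `time_call` and `toy_scan` are hypothetical and not part of either repository; `toy_scan` is just a stand-in workload. To separate forward from backward cost, one could pass each as its own callable. On a GPU the device would also need to be synchronized before reading the clock, which this sketch does not do.

```python
import statistics
import time


def time_call(fn, repeats=20, warmup=3):
    """Return the median wall-clock seconds of fn() over `repeats` runs."""
    # Warm-up runs are executed but not measured.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def toy_scan(T=2000):
    """Hypothetical stand-in workload: a sequential scan h = 0.9 * h + 1."""
    h = 0.0
    for _ in range(T):
        h = 0.9 * h + 1.0
    return h


if __name__ == "__main__":
    print(f"{time_call(toy_scan):.6f} s per call (median)")
```

The median is used rather than the mean so that a single slow run has less influence on the reported number.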
I have a few ideas on how to make it more efficient. For instance, the tensors W = UA and V = -UA + A_log can be precomputed when the layer is initialized, to avoid recomputing them on every forward pass. Could you please let me join your PR to run some experiments? I might have time later tonight.
Which PR are you referring to? I usually just push to master, and I've now uploaded the pscan_fft implementation. Feel free to fork and open a PR yourself!
Sorry for the confusion, I thought you were working in a private branch. It doesn't matter, though; I'll implement a faster version soon. Thank you!
Hey, @dvruette, Happy New Year!
As promised, I updated the FFT implementation and added another one based on the cumsum operation, which turned out to be significantly more efficient than the FFT-based one. See the PR for more details. I also added benchmarks of the different approaches across sequence lengths $T$.
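For readers unfamiliar with the idea, here is a minimal sketch of how a scan can be expressed with cumulative sums. It assumes the scan computes the recurrence h_t = a_t * h_{t-1} + x_t with every a_t > 0, which this thread does not spell out. The function names are hypothetical, and this is not the code from the PR.

```python
import math
from itertools import accumulate


def scan_sequential(a, x):
    """Reference loop: h[t] = a[t] * h[t-1] + x[t], starting from h = 0."""
    h, out = 0.0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out


def scan_cumsum(a, x):
    """Same recurrence via two cumulative sums (assumes every a[t] > 0)."""
    # log_A[t] = log(a[0]) + ... + log(a[t]), so exp(log_A[t]) is the
    # cumulative product a[0] * ... * a[t].
    log_A = list(accumulate(math.log(a_t) for a_t in a))
    # Divide each input by the cumulative product, then take a running sum.
    scaled = accumulate(x_t * math.exp(-l) for x_t, l in zip(x, log_A))
    # Multiply the cumulative product back in to recover h[t].
    return [math.exp(l) * s for l, s in zip(log_A, scaled)]


if __name__ == "__main__":
    a = [0.9, 0.5, 0.8, 0.7]
    x = [1.0, 2.0, -1.0, 0.5]
    print(scan_sequential(a, x))
    print(scan_cumsum(a, x))
```

Both functions should agree up to floating-point error. Dividing by the cumulative product can overflow or underflow for long sequences, so a practical version would likely need extra care with numerical range.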