accelerated-scan
Training Fast Weight Programmers by backpropagating through the Delta Rule
Introducing a kernel, developed with @ischlag, for training a fast weight programmer by backpropagating through the delta rule (online linear regression).
Improving on first-order recurrences with scalar hidden states, this kernel uses vector-valued updates like the transformer, which lets it exploit matrix-multiplication hardware and, thanks to the delta rule, avoids saturating the capacity of the fast-weight network.
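To make the update rule concrete, here is a minimal numpy sketch of the sequential delta rule: the fast weight matrix `W` is trained by one step of online linear regression per token, so each key is nudged toward predicting its value. The function name, the per-token learning rates `beta`, and reading the output as `W @ k` after the update are illustrative assumptions, not the package's API.

```python
import numpy as np

def delta_rule_scan(keys, values, betas):
    """Sequential delta rule: one online linear-regression step per token.

    keys, values: sequences of d-dimensional vectors
    betas: per-token learning rates (assumed here; not the package API)
    """
    d = len(keys[0])
    W = np.zeros((d, d))            # fast weight matrix, starts empty
    outputs = []
    for k, v, beta in zip(keys, values, betas):
        pred = W @ k                # current prediction for this key
        W = W + beta * np.outer(v - pred, k)  # delta rule update
        outputs.append(W @ k)       # read out with the updated weights
    return np.stack(outputs)
```

With `beta = 1` and a unit-norm key, a single update makes `W` reproduce the value exactly, which is what distinguishes the delta rule from a purely additive (Hebbian) fast-weight update.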
The implementation uses the chunked formulation from @sustcsonglin's equation 6 and currently performs best at a head dimension of 32, which fits exactly into the registers of a 3090 warp. The tensor-core code uses ThunderKittens, which makes WMMA effortless.
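A hedged numpy sketch of the chunked idea: within each chunk, the token-by-token recurrence can be replaced by solving one small triangular system and a few matrix multiplications, which is what makes the kernel matmul-friendly. This is my reconstruction of a chunkwise delta rule, not the package's actual kernel; the chunk size and the sequential reference loop are for illustration only.

```python
import numpy as np

def delta_rule_sequential(K, V, beta):
    """Reference: token-by-token delta rule (online linear regression)."""
    L, d = K.shape
    W = np.zeros((d, d))
    out = np.empty((L, d))
    for t in range(L):
        W = W + beta[t] * np.outer(V[t] - W @ K[t], K[t])
        out[t] = W @ K[t]
    return out, W

def delta_rule_chunked(K, V, beta, chunk=4):
    """Chunkwise delta rule: matmuls plus one triangular solve per chunk.

    Writes W_t = W + sum_{s<=t} u_s k_s^T, where the pseudo-values u_s
    satisfy a lower-triangular linear system within the chunk.
    """
    L, d = K.shape
    W = np.zeros((d, d))
    out = np.empty((L, d))
    for s in range(0, L, chunk):
        Kc, Vc, bc = K[s:s+chunk], V[s:s+chunk], beta[s:s+chunk]
        n = len(Kc)
        # (I + strict_lower(diag(b) Kc Kc^T)) U = diag(b) (Vc - Kc W^T)
        A = np.eye(n) + np.tril(bc[:, None] * (Kc @ Kc.T), k=-1)
        U = np.linalg.solve(A, bc[:, None] * (Vc - Kc @ W.T))
        # o_t = W_t k_t with W_t = W + sum_{s<=t} u_s k_s^T
        out[s:s+chunk] = Kc @ W.T + np.tril(Kc @ Kc.T) @ U
        W = W + U.T @ Kc   # carry the fast weights into the next chunk
    return out, W
```

The chunked form reproduces the sequential recurrence exactly; the triangular solve handles within-chunk dependencies, while everything else is dense matmuls suited to tensor cores.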