
Training Fast Weight Programmers by backpropagating through the Delta Rule

proger opened this issue 1 year ago · 0 comments

Introducing a kernel, developed with @ischlag, for training a fast weight programmer by backpropagating through the delta rule (online linear regression).

Improving on the first-order recurrence with a scalar hidden state, this kernel uses vector-valued updates like the transformer. This allows it to use matrix-multiplication hardware and, thanks to the delta rule, avoids saturating the capacity of the fast weight network.
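To make the update concrete, here is a minimal numpy sketch of one delta-rule step (not the kernel itself; the function name and shapes are illustrative). The fast weight matrix is nudged so that its prediction for key `k` moves toward value `v`; with a learning rate of 1 and a unit-norm key, the pair is stored exactly, which is what makes the rule behave like online linear regression:

```python
import numpy as np

def delta_rule_step(W, k, v, beta):
    # Delta-rule update of the fast weight matrix W (d_v x d_k):
    # a rank-1 correction that moves W's prediction for key k toward value v.
    return W + beta * np.outer(v - W @ k, k)

rng = np.random.default_rng(0)
d_k, d_v = 4, 3
W = np.zeros((d_v, d_k))
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)          # unit-norm key
v = rng.standard_normal(d_v)

W = delta_rule_step(W, k, v, beta=1.0)
print(np.allclose(W @ k, v))    # True: (k, v) is stored exactly
```

Unlike a plain additive fast weight update, the `v - W @ k` term overwrites what the network already predicts for `k`, which is why capacity does not saturate as pairs accumulate.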

The implementation uses the chunking from @sustcsonglin's equation 6 and currently performs best at a head dimension of 32, which fits perfectly into the registers of a 3090 warp. The tensor-core code uses ThunderKittens, which makes WMMA effortless.
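The idea behind the chunked form can be sketched as follows: a numpy reference (assumptions of mine, not the kernel) where the sequential rank-1 updates over a chunk are rewritten as a single unit-lower-triangular solve for "pseudo-values" followed by matmuls, so the whole chunk becomes matrix-multiplication work suitable for tensor cores:

```python
import numpy as np

def delta_rule_sequential(W0, K, V, beta):
    # Step-by-step reference: one rank-1 delta-rule update per position.
    W = W0.copy()
    for k, v, b in zip(K, V, beta):
        W = W + b * np.outer(v - W @ k, k)
    return W

def delta_rule_chunked(W0, K, V, beta):
    # Closed form for a whole chunk of C positions:
    #   W_C = W_0 + U^T K,  with pseudo-values U satisfying
    #   (I + diag(beta) tril(K K^T, -1)) U = diag(beta) (V - K W_0^T),
    # i.e. one triangular solve plus matmuls instead of C sequential steps.
    C = len(K)
    A = np.eye(C) + beta[:, None] * np.tril(K @ K.T, -1)
    U = np.linalg.solve(A, beta[:, None] * (V - K @ W0.T))
    return W0 + U.T @ K

rng = np.random.default_rng(0)
C, d = 8, 4
K = rng.standard_normal((C, d))
V = rng.standard_normal((C, d))
beta = rng.uniform(0.1, 1.0, C)
W0 = np.zeros((d, d))
print(np.allclose(delta_rule_sequential(W0, K, V, beta),
                  delta_rule_chunked(W0, K, V, beta)))  # True
```

A chunk size matching the head dimension keeps the triangular system and the key/value tiles small enough to stay resident in a warp's registers.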

proger · Aug 20 '24 05:08