Further improvements to attention backward
Backward kernel where threads reuse data in registers to reduce memory transfers.
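The register-reuse idea can be sketched outside CUDA. Below is a hypothetical Python model (not the PR's actual kernel) that counts "global memory" loads: each simulated thread computes a 2x2 tile of outputs and keeps each loaded value in a local variable (a register), reusing it for two outputs instead of one, which halves the number of loads versus the naive version.

```python
# Hypothetical sketch, not the PR's kernel: count loads to show why
# register reuse reduces memory transfers.
loads = 0

def load(M, i, j):
    """Simulated global-memory load; increments the load counter."""
    global loads
    loads += 1
    return M[i][j]

def matmul_naive(A, B, n):
    # one output per "thread": 2 loads per multiply-accumulate
    C = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            acc = 0.0
            for k in range(n):
                acc += load(A, i, k) * load(B, k, j)
            C[i][j] = acc
    return C

def matmul_tiled(A, B, n):
    # 2x2 output tile per "thread": each loaded value sits in a
    # register and is reused for two outputs
    C = [[0.0] * n for _ in range(n)]
    for i in range(0, n, 2):
        for j in range(0, n, 2):
            c00 = c01 = c10 = c11 = 0.0
            for k in range(n):
                a0 = load(A, i, k); a1 = load(A, i + 1, k)
                b0 = load(B, k, j); b1 = load(B, k, j + 1)
                c00 += a0 * b0; c01 += a0 * b1
                c10 += a1 * b0; c11 += a1 * b1
            C[i][j], C[i][j + 1] = c00, c01
            C[i + 1][j], C[i + 1][j + 1] = c10, c11
    return C

n = 4
A = [[float(i + 2 * j) for j in range(n)] for i in range(n)]
B = [[float(i - j) for j in range(n)] for i in range(n)]

loads = 0
ref = matmul_naive(A, B, n)
naive_loads = loads

loads = 0
out = matmul_tiled(A, B, n)
tiled_loads = loads

assert out == ref
assert tiled_loads * 2 == naive_loads  # 4 loads per k-step instead of 8
```

Bigger tiles reuse each load even more, until register pressure becomes the limit; the same trade-off drives the real CUDA kernel.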
This PR is built on top of my previous PRs, which should be merged first; once that is done, I'll rebase and remove the draft status here. I need the changes to backward-pass memory allocation, since without them I hit OOMs when profiling the backward pass, and I also need to be able to assume that the kernel writes (=) its gradients instead of accumulating (+=) them.
I merged the previous PR, so this one should be ready.
ACK on using = instead of += in the backward pass. I didn't originally realize this would have a dramatic performance impact, but it makes sense in retrospect. There are only a few tensors in the graph where += is necessary, i.e. where gradients have to add: at the residuals, and for the wte tensor, which is used both for the token embeddings and for the final matmul due to the weight-sharing scheme. Otherwise it's okay to just set them, given the graph we have.
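The rule of thumb above can be checked with a small worked example (a hypothetical sketch, not the repo's code): a tensor with two consumers, like the residual stream, receives gradient from both paths and so must accumulate with +=, while a single-consumer tensor like a weight matrix can just be set with =.

```python
# Hypothetical sketch: += is needed exactly where a tensor feeds two
# consumers (residual skip connection here, or wte shared between the
# embedding and the final matmul in the real model).
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4)      # stands in for a residual-stream tensor
W = rng.standard_normal((4, 4))

# forward: x feeds both a matmul branch and the skip connection
y = W @ x + x
dy = np.ones_like(y)            # upstream gradient of sum(y)

# backward: x has two consumers, so its gradient must accumulate
dx = np.zeros_like(x)
dx += W.T @ dy                  # contribution through the matmul branch
dx += dy                        # contribution through the skip connection

# W has a single consumer, so its gradient can just be written with =
dW = np.outer(dy, x)

# sanity check against a finite-difference gradient for x[0]
eps = 1e-6
x2 = x.copy(); x2[0] += eps
num = ((W @ x2 + x2).sum() - (W @ x + x).sum()) / eps
assert abs(dx[0] - num) < 1e-4
```

Dropping either += line would silently lose one path's gradient, which is why the kernels that write (=) have to be the sole producer of their output gradient.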
Also, one possible request: I think a lot of people will come to dev/cuda to learn CUDA. If you're able to comment some of the kernels, I think it could be really valuable to a lot of people (including myself!).
So cool, I went from 400ms/iter down to 200ms/iter.