~2x perf improvement beating PyTorch (cublasLt, TF32, CUDA graphs, kernel fusion, etc…)
This improves performance on my local RTX 4090 from ~65ms to ~34ms per training step (while PyTorch takes ~36ms!)
ORIGINAL: step 1: train loss 4.406481 (took 64.890952 ms)
OPTIMISED: step 1: train loss 4.406351 (took 34.064025 ms)
PYTORCH: iteration 1, loss: 4.579084396362305, time: 36.545ms
Tested with PyTorch 2.2.2+cu121 and driver 550.54.14 (and nvcc 12.1 to be like-for-like; possibly faster with 12.4). TF32 is enabled for both my optimised code and PyTorch, using the command from README.md: python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 --tensorcores 1
When TF32 is enabled, I had to increase the tolerance in test_gpt2.cu from 1e-2 to 1.0f, which gives the following output:
-43.431705 -43.351101
-39.836426 -39.763416
-43.066010 -42.994701
OK (LOGITS)
LOSS OK: 5.269499 5.270009
The biggest performance gains come from:
- Using cuBLASLt for matmul_forward with the bias add and GELU fused into the matmul epilogue (algorithm selected with cublasLtMatmulAlgoGetHeuristic).
- Optional TF32 for cuBLASLt/cuBLAS to match PyTorch's precision (requires a looser threshold in test_gpt2.cu).
- Optimised softmax kernels: a first version that fuses the attention scaling into the softmax and uses a hardcoded block size of 512 threads, and a second version for very large C (many loop iterations) that uses some advanced loop-unrolling tricks.
- CUDA graphs with a non-default stream to maximise GPU/CPU parallelism (including cudaMemcpyAsync).
One possible issue with this commit is the large number of new global static variables for CUDA at the top of train_gpt2.cu. They avoid passing lots of new arguments all over the place: for example, every kernel launch now has to use a custom CUDA stream instead of the default one in order to be able to use CUDA graphs (which is also why the cuBLAS(Lt) handles can now only be created once).
I also haven't included the corresponding changes to the standalone .cu files yet, partly because they became a bit of a mess with the cuBLAS(Lt) handle problem described above, and partly because it depends on whether this needs to be refactored. I'm happy to provide them tomorrow if needed.
Not tested on A100/H100 or with other CUDA/PyTorch versions yet, so there's a very strong chance it doesn't match PyTorch on other configurations, but that doesn't sound as cool as just saying it beats PyTorch, so that's what I'm going with... ;)
Future work ideas:
- Benchmark on A100 and/or H100 using CUDA 12.4 and the latest PyTorch
- Check whether there's a faster way to do cublasSgemmStridedBatched with cuBLASLt
- Look into optimising away the permute/unpermute kernels (can it be "free" with the TMA on H100?)
- Investigate H100/AD102 lossless memory compression and/or cache residency controls (there was interesting info on this at GTC; it's extremely sensitive to access patterns, so it might not work in practice)
(sounds really great! processing through this now)
I think we'll want to break this up into chunks, a lot of really good stuff here.
One more data point for your PR: I ran it on my 4090 (torch==2.2.2, CUDA 12.1, NVIDIA driver 530.30.02):
- llm.c (main branch): step 1: train loss 4.406586 (took 37.426351 ms) with ./train_gpt2cu
- llm.c (PR #89): step 1: train loss 4.406351 (took 29.146125 ms) 🔥 with ./train_gpt2cu
- PyTorch: iteration 1, loss: 4.579084396362305, time: 33.262ms with python train_gpt2.py (as in the thread above)