
[DeepSeek][kernels] index select permute, cuda

Open lessw2020 opened this issue 7 months ago • 4 comments

This PR adds an index_select_permute operation named fast_permute_tokens. It performs an index select on the token tensor to gather the tokens into contiguous memory, grouped by expert.
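As a reference for the semantics only, here is a minimal pure-Python sketch. It assumes, hypothetically, that each token carries one expert ID and that the gather indices come from a stable sort of tokens by expert ID. The PR text does not show the kernel's interface, so `permute_indices_by_expert`, `index_select_rows`, and `unpermute_rows` are illustrative names, not the PR's API.

```python
from typing import List, Sequence

Row = List[float]

def permute_indices_by_expert(expert_ids: Sequence[int]) -> List[int]:
    """Hypothetical helper: token indices ordered so tokens of the same expert are adjacent.

    A stable sort keeps the original token order within each expert.
    """
    return sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])

def index_select_rows(tokens: Sequence[Row], indices: Sequence[int]) -> List[Row]:
    """Gather rows: out[i] = tokens[indices[i]] (index select along dim 0)."""
    return [list(tokens[i]) for i in indices]

def unpermute_rows(permuted: Sequence[Row], indices: Sequence[int]) -> List[Row]:
    """Scatter rows back to their original positions.

    Only a true inverse when `indices` is a full permutation of the tokens.
    """
    out: List[Row] = [[] for _ in indices]
    for dst, src in enumerate(indices):
        out[src] = list(permuted[dst])
    return out

# Usage: four tokens routed to two experts.
tokens = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
idx = permute_indices_by_expert([1, 0, 1, 0])   # expert 0 tokens first, then expert 1
grouped = index_select_rows(tokens, idx)
```

The gather itself works for any index list, including the case where fewer indices than tokens are selected (as in the small benchmark below, 512 of 1024).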

I found that, depending on the problem size, no single kernel beats PyTorch across the board. So this features an adaptive kernel that is really four kernels, one of which is selected based on problem size.

Each kernel is small, and each one progressively adds the improvements needed to scale with the problem size:
- Small: stages the transfer through shared memory, one thread per element.
- Medium: multiple tokens per thread block.
- Large: 2D grid, where each thread handles a tile of (tokens x features).
- XL: 2D grid with templated vectorized memory access.
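A size-based dispatch like this can be sketched as a small launch planner. Everything here is an assumption for illustration: the PR does not publish its cutoffs, tile sizes, or thread counts, so `LaunchPlan`, `plan_permute`, the `*_MAX` thresholds, and the default parameters are hypothetical placeholders. The sketch only shows the shape of the decision: pick a variant by total elements moved, and fall back from XL to Large when `hidden_dim` cannot be split into whole vectors.

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical cutoffs on elements moved (n_indices * hidden_dim); not the PR's values.
SMALL_MAX = 1 << 21
MEDIUM_MAX = 1 << 24
LARGE_MAX = 1 << 25

@dataclass(frozen=True)
class LaunchPlan:
    variant: str            # "small" | "medium" | "large" | "xl"
    grid: Tuple[int, int]   # (grid_x, grid_y) in thread blocks
    threads: int            # threads per block

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def plan_permute(n_indices: int, hidden_dim: int, elem_bytes: int = 2,
                 threads: int = 256, tokens_per_block: int = 4,
                 tile_tokens: int = 32, vec_bytes: int = 16) -> LaunchPlan:
    total = n_indices * hidden_dim
    vec = max(1, vec_bytes // elem_bytes)  # elements per wide access (8 for 2-byte dtypes)
    if total <= SMALL_MAX:
        # Small: one thread per element on a 1D grid.
        return LaunchPlan("small", (ceil_div(total, threads), 1), threads)
    if total <= MEDIUM_MAX:
        # Medium: each block copies several whole token rows.
        return LaunchPlan("medium", (ceil_div(n_indices, tokens_per_block), 1), threads)
    if total > LARGE_MAX and hidden_dim % vec == 0:
        # XL: 2D grid; each thread moves `vec` contiguous elements per access.
        # hidden_dim % vec == 0 keeps every vector inside a single row.
        return LaunchPlan("xl", (ceil_div(hidden_dim // vec, threads),
                                 ceil_div(n_indices, tile_tokens)), threads)
    # Large (also the fallback when vectorization is not possible):
    # 2D grid of tiles covering tile_tokens rows by `threads` features.
    return LaunchPlan("large", (ceil_div(hidden_dim, threads),
                                ceil_div(n_indices, tile_tokens)), threads)
```

Keying the choice on total elements rather than on `n_indices` alone is one reasonable design; a real dispatcher would likely tune these thresholds per GPU.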

Initial perf and verification results:

=== Small Configuration ===
Benchmarking with batch_size=1024, hidden_dim=4096, n_indices=512
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.020 ± 0.001 ms
CUDA:     0.011 ± 0.000 ms
Speedup:  1.84x

=== Medium Configuration ===
Benchmarking with batch_size=4096, hidden_dim=4096, n_indices=4096
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.085 ± 0.002 ms
CUDA:     0.042 ± 0.001 ms
Speedup:  2.01x

=== Large Configuration ===
Benchmarking with batch_size=8192, hidden_dim=4096, n_indices=8192
Verifying CUDA implementation matches PyTorch...
✓ Results match!
PyTorch:  0.163 ± 0.001 ms
CUDA:     0.068 ± 0.001 ms
Speedup:  2.38x
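To relate these timings to hardware limits, one can convert them to effective memory bandwidth. The thread does not state the dtype or the GPU, so `elem_bytes=2` (bf16/fp16) is an assumption, and `bytes_moved` / `effective_bandwidth_tbs` are illustrative helpers. The model counts one read and one write per selected element and ignores the index reads.

```python
# Timings copied from the benchmark output above (milliseconds).
RESULTS = {
    "small":  dict(n_indices=512,  hidden_dim=4096, pytorch_ms=0.020, cuda_ms=0.011),
    "medium": dict(n_indices=4096, hidden_dim=4096, pytorch_ms=0.085, cuda_ms=0.042),
    "large":  dict(n_indices=8192, hidden_dim=4096, pytorch_ms=0.163, cuda_ms=0.068),
}

def bytes_moved(n_indices: int, hidden_dim: int, elem_bytes: int) -> int:
    # Each selected row is read once and written once; index reads are ignored.
    return 2 * n_indices * hidden_dim * elem_bytes

def effective_bandwidth_tbs(n_indices: int, hidden_dim: int, time_ms: float,
                            elem_bytes: int = 2) -> float:
    """Effective bandwidth in TB/s (1 TB = 1e12 bytes); elem_bytes=2 assumes bf16/fp16."""
    return bytes_moved(n_indices, hidden_dim, elem_bytes) / (time_ms * 1e-3) / 1e12

if __name__ == "__main__":
    for name, r in RESULTS.items():
        for impl in ("pytorch", "cuda"):
            bw = effective_bandwidth_tbs(r["n_indices"], r["hidden_dim"], r[f"{impl}_ms"])
            print(f"{name:6s} {impl:7s} {bw:5.2f} TB/s")
```

Comparing these figures against the GPU's peak memory bandwidth shows how much headroom remains for a memory-bound gather.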

lessw2020 · Apr 09 '25 21:04

General comment: maybe we can use unrolling to speed up the kernels too. These kernels mainly do memory accesses, so unrolling lets each thread issue more independent loads before it has to wait, which helps hide the memory round-trip latency.
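The unrolled access pattern can be modeled sequentially to check its index arithmetic. This is a hypothetical sketch, not code from the PR: `unrolled_offsets`, `gather_rows_unrolled`, `block_threads`, and `unroll` are illustrative. Python cannot show the latency benefit itself; it only shows that on each unrolled step adjacent threads touch adjacent output addresses (coalesced), that one thread's `unroll` loads are independent, and that every output is covered exactly once.

```python
from typing import List, Sequence

def unrolled_offsets(block_id: int, thread_id: int, block_threads: int, unroll: int) -> List[int]:
    """Output offsets one thread handles in a hypothetical unrolled gather kernel.

    A block owns block_threads * unroll consecutive outputs. On unrolled step k,
    thread t handles base + k * block_threads + t.
    """
    base = block_id * block_threads * unroll
    return [base + k * block_threads + thread_id for k in range(unroll)]

def gather_rows_unrolled(src: Sequence[float], indices: Sequence[int], hidden_dim: int,
                         block_threads: int = 4, unroll: int = 4) -> List[float]:
    """Sequential model of out[r, c] = src[indices[r], c] on flat row-major buffers."""
    n_out = len(indices) * hidden_dim
    out = [0.0] * n_out
    num_blocks = -(-n_out // (block_threads * unroll))  # ceil division
    for b in range(num_blocks):
        for t in range(block_threads):
            offs = [o for o in unrolled_offsets(b, t, block_threads, unroll) if o < n_out]
            # Step 1: issue every load first; on a GPU these can all be in flight at once.
            vals = [src[indices[o // hidden_dim] * hidden_dim + o % hidden_dim] for o in offs]
            # Step 2: store the results.
            for o, v in zip(offs, vals):
                out[o] = v
    return out
```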

kwen2501 · Apr 10 '25 16:04

@ngimel - thanks for the info above! To your questions:

1 - "This optimizes performance for an extremely common function, and as such should go into pytorch core and not into torchtitan." This was started because PyTorch eager was slow here and showed up as a hotspot in initial profiling, and it was easiest to develop right in place. If it proved out, we could then look at moving it to core; doing it in reverse would be too much overhead.

2 - "Additionally, torch.compile already provides performance better than this kernels, so given that torch.titan relies on torch.compile already, the benefit of adding precompiled implementations is unclear" To clarify: torchtitan main relies on torch.compile, and there you are right: we use torch.compile for regional compilation of the transformer blocks in the 'main' models (llama3, etc.). However, we are not using torch.compile generally in experimental to start, at least not until things are in good enough shape to consider moving into main, at which point we can add compile.

Does it make sense to just upgrade the PyTorch eager core code with the compile-generated Triton kernel as a generic fix, so that folks who don't use compile get a faster experience?

lessw2020 · Apr 16 '25 02:04

We still try to avoid calling Triton kernels in eager, since they come with unpredictable recompilation. However, the Triton kernel proves that simple vectorized loads/stores with 1D blocking are enough to get good performance, so we can just write this simple kernel in eager.
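What "vectorized loads/stores with simple 1D blocking" means for this gather can be modeled sequentially. This is a hypothetical sketch, not the kernel from the linked PyTorch PR: `vectorized_gather_1d`, `vec`, and `block_threads` are illustrative, and `vec=8` assumes 16-byte accesses on a 2-byte dtype. The output is flattened into vectors, each thread moves one vector, and blocks tile that flat range on a 1D grid.

```python
from typing import List, Sequence

def vectorized_gather_1d(src: Sequence[float], indices: Sequence[int], hidden_dim: int,
                         vec: int = 8, block_threads: int = 128) -> List[float]:
    """Sequential model of a 1D-blocked gather where each thread moves `vec` elements.

    Requires hidden_dim % vec == 0 so a vector never crosses a row boundary and
    every vector start stays aligned to `vec` elements.
    """
    if hidden_dim % vec != 0:
        raise ValueError("hidden_dim must be a multiple of vec")
    n_vec = len(indices) * hidden_dim // vec
    num_blocks = -(-n_vec // block_threads)              # 1D grid, ceil division
    out = [0.0] * (n_vec * vec)
    for b in range(num_blocks):
        for t in range(block_threads):
            v = b * block_threads + t                    # global vector id
            if v >= n_vec:
                break                                    # tail of the last block
            dst = v * vec
            row, col = divmod(dst, hidden_dim)
            start = indices[row] * hidden_dim + col
            out[dst:dst + vec] = src[start:start + vec]  # one wide load + one wide store
    return out
```

Because the blocking is purely 1D over the output, the same launch shape covers every problem size, which is the simplicity argued for above.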

ngimel · Apr 16 '25 05:04

https://github.com/pytorch/pytorch/pull/151490

ngimel · Apr 16 '25 23:04