[BOUNTY] Optimized Triton Kernels for full fine tunes
🔖 Feature description
We've seen marketing from Unsloth claiming that optimized Triton kernels for various operations can significantly improve both the speed and memory efficiency of fine-tuning LoRA adapters as well as full fine tunes. However, only the LoRA Triton kernels are open source. We are awarding a bounty of up to $600 each for optimized Triton kernels that are compatible with FlashAttention 2.0 for the following model architectures:
- Llama - $350
- Mistral (w/ Sliding Window Attention) - $250 (should be tackled after Llama, since the only change from Llama is SWA)
- Mixtral MoE - $600
> [!IMPORTANT]
> EDIT: the bounties have been doubled to $700, $500, and $1200 respectively thanks to a bounty match.
To be eligible for a bounty, the submission into axolotl must:
- be open sourced under Apache 2.0,
- support single- and multi-GPU finetuning,
- include unit tests, and
- support both regular full finetuning and full finetuning with multipack.

Kernels should include the forward and backward passes for the MLP and attention modules. For Mixtral, additional kernels are required for both sparse and grouped permute_and_compute, as well as kernels for gating experts.
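As a rough illustration of the expected shape of a submission, here is a minimal sketch of the pattern most Triton training kernels follow: a forward kernel plus a hand-written backward kernel, wrapped in a `torch.autograd.Function`. It fuses only the SiLU(gate) * up elementwise step of a Llama-style MLP; all names are illustrative, and this is nowhere near a qualifying kernel, just the forward/backward wiring a submission would need.

```python
# Minimal sketch (illustrative, not a submission): a fused SiLU(gate) * up
# elementwise kernel with a hand-written backward, wrapped in an autograd
# Function. Assumes contiguous CUDA tensors of the same shape and dtype.
import torch
import triton
import triton.language as tl


@triton.jit
def _silu_mul_fwd(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(gate_ptr + offs, mask=mask)
    u = tl.load(up_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, g * tl.sigmoid(g) * u, mask=mask)


@triton.jit
def _silu_mul_bwd(gate_ptr, up_ptr, dout_ptr, dgate_ptr, dup_ptr,
                  n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(gate_ptr + offs, mask=mask)
    u = tl.load(up_ptr + offs, mask=mask)
    do = tl.load(dout_ptr + offs, mask=mask)
    sig = tl.sigmoid(g)
    # d/dg[SiLU(g) * u] = u * sig * (1 + g * (1 - sig)); d/du = SiLU(g)
    tl.store(dgate_ptr + offs, do * u * sig * (1 + g * (1 - sig)), mask=mask)
    tl.store(dup_ptr + offs, do * g * sig, mask=mask)


class FusedSiLUMul(torch.autograd.Function):
    BLOCK = 1024

    @staticmethod
    def forward(ctx, gate, up):
        out = torch.empty_like(gate)
        n = gate.numel()
        grid = (triton.cdiv(n, FusedSiLUMul.BLOCK),)
        _silu_mul_fwd[grid](gate, up, out, n, BLOCK=FusedSiLUMul.BLOCK)
        ctx.save_for_backward(gate, up)
        return out

    @staticmethod
    def backward(ctx, dout):
        gate, up = ctx.saved_tensors
        dgate, dup = torch.empty_like(gate), torch.empty_like(up)
        n = gate.numel()
        grid = (triton.cdiv(n, FusedSiLUMul.BLOCK),)
        _silu_mul_bwd[grid](gate, up, dout.contiguous(), dgate, dup, n,
                            BLOCK=FusedSiLUMul.BLOCK)
        return dgate, dup
```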
> [!IMPORTANT]
> EDIT 2024-01-03: "Optimized" is defined as at least a 15% time improvement and a 25% memory improvement over the current Flash Attention implementation.
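For reference, one way to check a candidate against these thresholds is to compare wall-clock time per training step and peak CUDA memory against the FA2 baseline. The harness below is an assumed sketch, not part of axolotl; `baseline_step` and `candidate_step` are placeholders for one training step under each implementation.

```python
# Assumed measurement harness (not part of axolotl): compare average step time
# and peak CUDA memory for a candidate kernel path versus the FA2 baseline.
import time
import torch


def profile(run_step, warmup: int = 3, iters: int = 10):
    for _ in range(warmup):
        run_step()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        run_step()
    torch.cuda.synchronize()
    step_time = (time.perf_counter() - start) / iters
    peak_mem = torch.cuda.max_memory_allocated()
    return step_time, peak_mem


def meets_bounty_criteria(baseline_step, candidate_step):
    base_t, base_m = profile(baseline_step)
    cand_t, cand_m = profile(candidate_step)
    # >=15% faster and >=25% lower peak memory than the FA2 baseline.
    return cand_t <= 0.85 * base_t and cand_m <= 0.75 * base_m
```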
For those looking for inspiration to claim the $1200 for Mixtral:
- MegaBlocks (Triton): A good baseline for communication / various distributed operations.
- vLLM Expert Parallelism PR (Triton): An implementation of expert sharding across the available GPUs.
- Grouped GEMM (CUDA): 15-25% faster than MegaBlocks' `sparse_permute_and_compute` on H100s according to the PR. This is the `grouped_permute_and_compute` part in MegaBlocks. (Looks like this is only good for H100s.)
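To make the target concrete, the kernels referenced above replace something like the following naive PyTorch version of permute-and-compute: route each token to its top-k experts, gather the tokens per expert, run one GEMM per expert, and scatter the weighted results back. Shapes, names, and the single weight per expert are illustrative only; this is not the MegaBlocks or vLLM API.

```python
# Naive PyTorch reference (illustrative only) of MoE permute-and-compute; the
# Triton/CUDA kernels above replace this per-expert loop with sparse or grouped GEMMs.
import torch


def permute_and_compute(x, expert_weights, router_logits, top_k: int = 2):
    # x: (tokens, hidden), expert_weights: (experts, hidden, ffn),
    # router_logits: (tokens, experts) -> returns (tokens, ffn)
    num_experts = expert_weights.shape[0]
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_p, topk_idx = probs.topk(top_k, dim=-1)                 # (tokens, top_k)
    topk_p = (topk_p / topk_p.sum(dim=-1, keepdim=True)).to(x.dtype)
    out = x.new_zeros(x.shape[0], expert_weights.shape[-1])
    for e in range(num_experts):
        token_idx, slot = (topk_idx == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        # "Permute": gather this expert's tokens; "compute": one GEMM per expert.
        expert_out = x[token_idx] @ expert_weights[e]            # (n_e, ffn)
        out.index_add_(0, token_idx, expert_out * topk_p[token_idx, slot, None])
    return out
```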
@winglian I suggest you put a targeted speedup on what qualifies as "optimized". Who knows, maybe `torch.compile` used the right way can meet your definition of "optimized" :) and someone from the PyTorch community could attempt something like that (similar to the gpt-fast work we've been doing for inference).
Thanks, I've added additional clarification on that to the original post.
I'm not saying this task is easy or the goals simple, but if the training-time speedups and VRAM reductions that Unsloth promises in its paid plans are real (let's assume they are), the conditions for receiving the reward remain low. Why don't we aim for a higher bar, at least half of the speedup and memory-reduction rates that Unsloth promises: a 25% time improvement and a 30% memory improvement over the current Flash Attention implementation?
Hi, I think this is a great initiative! When you talk about the "current flash attention implementation", could you perhaps specify the exact tech stack and version that you have in mind? In fact, it might also be useful to specify the desired hardware. I think this would make the rules of the competition really clear-cut.
Hi! I think this is a good opportunity for those trying to get deep into LLMs. It would be really helpful if you could explain, at a high level, how to get started. Thanks!
Also, for those who land directly on this page, please specify that this is about the Triton language, not the Triton Inference Server. Is there a preferred GPU architecture to target (A100 / H100)? Are there benchmarks of the current kernels' speed? (We should create those to establish a baseline.)
> For Mixtral, additional kernels are required for both sparse and grouped permute_and_compute, as well as kernels for gating experts.
Here is my answer specific to Mixtral. Solutions that achieve a speedup on both A100 and H100 should be accepted. You would have to implement a sparse kernel on A100 and a grouped kernel on H100.
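If it helps, a submission could pick between the two variants at runtime from the detected compute capability, roughly as sketched below; both callables are placeholders for the kernels being discussed, not existing axolotl functions.

```python
# Hypothetical dispatch sketch: prefer the grouped-GEMM path on Hopper (SM 9.x,
# e.g. H100) and fall back to the sparse MegaBlocks-style path on Ampere (SM 8.x).
import torch


def select_moe_kernel(sparse_permute_and_compute, grouped_permute_and_compute):
    major, _ = torch.cuda.get_device_capability()
    return grouped_permute_and_compute if major >= 9 else sparse_permute_and_compute
```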
I think @winglian should provide a baseline axolotl config. Perhaps one for short-context and one for long-context datasets.
Triton kernel for expert computation in MoE, compatible with float16 and bfloat16, with a 2.3-5x speedup depending on batch size. You would just need to make it compatible with axolotl and implement the backward pass.
If this can be implemented in axolotl for Mixtral, you could likely claim the $1200.
https://github.com/vllm-project/vllm/pull/2453
@unslothai @danielhanchen would you open-source your kernels to claim the bounty?