
[BOUNTY] Optimized Triton Kernels for full fine-tunes

Open · winglian opened this issue Jan 03 '24 · 13 comments

🔖 Feature description

We've seen marketing from Unsloth claiming that optimized Triton kernels for various operations can significantly improve both the speed and memory efficiency of fine-tuning, for LoRA adapters as well as full fine-tunes. However, only the LoRA Triton kernels are open source. We are awarding bounties of up to $600 each for optimized Triton kernels, compatible with FlashAttention 2.0, for the following model architectures:

  1. Llama - $350
  2. Mistral (with Sliding Window Attention) - $250 (best tackled after Llama, since the only change from Llama is SWA)
  3. Mixtral MoE - $600

[!IMPORTANT] EDIT: the bounties have been doubled to $700, $500, and $1200 respectively, thanks to a bounty match

To be eligible for a bounty, the submission to axolotl must be open-sourced under Apache 2.0, support single- and multi-GPU fine-tuning, include unit tests, and support both regular full fine-tuning and full fine-tuning with multipack. Kernels should include the forward and backward passes for the MLP and attention modules. For Mixtral, additional kernels are required: both sparse and grouped permute_and_compute, as well as kernels for gating experts.
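
For anyone scoping the work, here is a minimal, untested sketch of the kind of building block we mean: a fused SwiGLU activation (the core of the Llama/Mistral MLP) with forward and backward wired through torch.autograd.Function. The kernel names, block size, and class name are illustrative only; a real submission would fuse more of the MLP, cover attention as well, and ship unit tests.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _swiglu_fwd(G, U, OUT, n_elements, BLOCK: tl.constexpr):
    # out = silu(gate) * up, computed elementwise in fp32
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(G + offs, mask=mask, other=0.0).to(tl.float32)
    u = tl.load(U + offs, mask=mask, other=0.0).to(tl.float32)
    tl.store(OUT + offs, (g * tl.sigmoid(g) * u).to(OUT.dtype.element_ty), mask=mask)


@triton.jit
def _swiglu_bwd(G, U, DOUT, DG, DU, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(G + offs, mask=mask, other=0.0).to(tl.float32)
    u = tl.load(U + offs, mask=mask, other=0.0).to(tl.float32)
    do = tl.load(DOUT + offs, mask=mask, other=0.0).to(tl.float32)
    sig = tl.sigmoid(g)
    silu = g * sig
    # d/dg silu(g) = sig + silu * (1 - sig)
    tl.store(DG + offs, (do * u * (sig + silu * (1 - sig))).to(DG.dtype.element_ty), mask=mask)
    tl.store(DU + offs, (do * silu).to(DU.dtype.element_ty), mask=mask)


class FusedSwiGLU(torch.autograd.Function):
    BLOCK = 1024

    @staticmethod
    def forward(ctx, gate, up):
        gate, up = gate.contiguous(), up.contiguous()
        out = torch.empty_like(gate)
        n = gate.numel()
        _swiglu_fwd[(triton.cdiv(n, FusedSwiGLU.BLOCK),)](gate, up, out, n, BLOCK=FusedSwiGLU.BLOCK)
        ctx.save_for_backward(gate, up)
        return out

    @staticmethod
    def backward(ctx, dout):
        gate, up = ctx.saved_tensors
        dg, du = torch.empty_like(gate), torch.empty_like(up)
        n = gate.numel()
        _swiglu_bwd[(triton.cdiv(n, FusedSwiGLU.BLOCK),)](
            gate, up, dout.contiguous(), dg, du, n, BLOCK=FusedSwiGLU.BLOCK
        )
        return dg, du
```

Usage would be FusedSwiGLU.apply(gate_proj(x), up_proj(x)), with a unit test checking it against torch.nn.functional.silu(gate) * up.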

[!IMPORTANT] EDIT 2024-01-03: Optimized is defined as at least a 15% time improvement and a 25% memory improvement over the current Flash Attention implementation.
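
As a rough yardstick for those thresholds, something along these lines could be run once against the current Flash Attention path and once against the candidate kernels (benchmark_step is a hypothetical helper, and it assumes an HF-style model where model(**batch).loss is the training loss):

```python
import time
import torch

def benchmark_step(model, batch, n_warmup=3, n_iters=10):
    """Return (mean seconds per step, peak GiB) for fwd + bwd + optimizer step."""
    opt = torch.optim.SGD(model.parameters(), lr=1e-5)
    for _ in range(n_warmup):                      # warm up kernels / autotuning
        model(**batch).loss.backward()
        opt.step(); opt.zero_grad()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(**batch).loss.backward()
        opt.step(); opt.zero_grad()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters, torch.cuda.max_memory_allocated() / 2**30
```

In other words, the candidate kernels should come in at or below 0.85x the baseline step time and 0.75x the baseline peak memory.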

winglian · Jan 03 '24

Parlance Labs is matching @winglian's bounty. So it's

  1. Llama $700
  2. Mistral $500
  3. Mixtral MoE $1200

hamelsmu · Jan 03 '24

For those looking for inspiration to claim the $1200 for Mixtral (a plain-PyTorch reference for permute_and_compute follows the list):

  • MegaBlocks (Triton): A good baseline for communication / various distributed operations.
  • vLLM Expert Parallelism PR (Triton): An implementation of expert sharding across the available GPUs.
  • Grouped GEMM (CUDA): 15-25% faster than MegaBlocks' sparse_permute_and_compute on H100s, according to the PR. This is the grouped_permute_and_compute part of MegaBlocks (and appears to only pay off on H100s).
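
To make the terminology concrete, here is a minimal, unfused PyTorch reference for the sparse-style permute_and_compute path (loosely following Mixtral's per-expert loop; the function name and the top_k=2 default are only illustrative). The bounty kernels would replace exactly this gather / expert-MLP / scatter pattern:

```python
import torch

def permute_and_compute(x, router_logits, experts, top_k=2):
    """x: (tokens, hidden); router_logits: (tokens, num_experts);
    experts: list of per-expert MLP modules (8 for Mixtral)."""
    weights, selected = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)    # renormalize over the top-k
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        token_idx, slot = torch.where(selected == e)          # tokens routed to expert e
        if token_idx.numel():
            contrib = expert(x[token_idx]) * weights[token_idx, slot, None]
            out.index_add_(0, token_idx, contrib.to(x.dtype))
    return out
```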

casper-hansen · Jan 03 '24

@winglian I suggest you specify a target speedup for what qualifies as "optimized". Who knows, maybe torch.compile used the right way can meet your definition of "optimized" :) and someone from the PyTorch community could attempt something like that (similar to the gpt-fast work we've been doing for inference).
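
For context, the kind of thing being suggested is roughly the following (model and batch are whatever axolotl has loaded; whether this clears the thresholds above is exactly what would need measuring):

```python
import torch

# Inductor generates fused Triton kernels for both the forward and backward
# graphs; "max-autotune" spends extra compile time searching for faster kernels.
model = torch.compile(model, mode="max-autotune")

# The training loop is unchanged; the first few steps pay the compilation cost.
loss = model(**batch).loss
loss.backward()
```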

soumith · Jan 03 '24

> @winglian I suggest you specify a target speedup for what qualifies as "optimized". Who knows, maybe torch.compile used the right way can meet your definition of "optimized" :) and someone from the PyTorch community could attempt something like that (similar to the gpt-fast work we've been doing for inference).

Thanks, I've added additional clarification on that to the original post.

winglian · Jan 03 '24

I am not saying this task is easy or that the goals are trivial, but if the training speedups and VRAM reductions that Unsloth promises in its paid plans are real (let's assume they are), the bar for receiving the reward remains low. Why don't we aim for a higher level of requirements, say at least half the speedup and memory savings that Unsloth promises: a 25% time improvement and a 30% memory improvement over the current Flash Attention implementation?

kostum123 · Jan 04 '24

Hi, I think this is a great initiative! When you talk about the "current Flash Attention implementation", could you specify the exact tech stack and version you have in mind? It might also be useful to specify the target hardware. I think this would make the rules of the competition really clear-cut.

jedreky · Jan 04 '24

Hi! I think this is a good opportunity for those trying to get deep into LLMs. It would be really helpful if you could explain, at a high level, what needs to be done to get started. Thanks!

Itssshikhar · Jan 04 '24

Also, for those who land directly on this page, it would help to specify that this is about the Triton language, not the Triton Inference Server. Is there a particular GPU architecture to target as a preference (A100 / H100)? Are there benchmarks of the current kernel speed? (We should create those to establish a baseline.)

Mistobaan · Jan 04 '24

> For Mixtral, additional kernels are required: both sparse and grouped permute_and_compute, as well as kernels for gating experts.

Here is my answer specific to Mixtral: solutions that achieve a speedup on both A100 and H100 should be accepted. You would have to implement a sparse kernel for A100 and a grouped kernel for H100.

I think @winglian should provide a baseline axolotl config, perhaps one each for short-context and long-context datasets.
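
To illustrate the grouped path in plain PyTorch (a sketch only: a simplified two-matrix MLP instead of Mixtral's SwiGLU, and padding to the largest group rather than a real grouped GEMM), assuming tokens have already been permuted so each expert's tokens are contiguous:

```python
import torch
import torch.nn.functional as F

def grouped_compute(x_sorted, group_sizes, w1, w2):
    """x_sorted: (tokens, hidden), permuted so each expert's tokens are contiguous;
    group_sizes: (num_experts,) token counts per expert;
    w1: (num_experts, hidden, ffn), w2: (num_experts, ffn, hidden) stacked expert weights."""
    num_experts = w1.shape[0]
    cap = int(group_sizes.max())
    starts = torch.cumsum(group_sizes, 0) - group_sizes
    padded = x_sorted.new_zeros(num_experts, cap, x_sorted.shape[-1])
    for e in range(num_experts):                       # pack each expert's tokens
        padded[e, : group_sizes[e]] = x_sorted[starts[e] : starts[e] + group_sizes[e]]
    h = torch.bmm(F.silu(torch.bmm(padded, w1)), w2)   # one batched GEMM per layer
    return torch.cat([h[e, : group_sizes[e]] for e in range(num_experts)])
```

A real grouped kernel (e.g. the Grouped GEMM work linked above) avoids both the padding and the Python loop.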

casper-hansen · Jan 04 '24

The PR below adds a Triton kernel for expert computation in MoE that is compatible with float16 and bfloat16, with a 2.3-5x speedup depending on batch size. You would just need to make it compatible with axolotl and implement the backward pass.

If this can be implemented in axolotl for Mixtral, you could likely claim the $1200.

https://github.com/vllm-project/vllm/pull/2453
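
One way to bootstrap the missing backward while dedicated backward kernels are being written (a sketch only: fused_moe_forward stands in for the PR's Triton entry point, and reference_moe for a slow-but-differentiable PyTorch implementation such as the permute_and_compute sketch earlier in the thread):

```python
import torch

class FusedMoE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, router_logits, w1, w2):
        ctx.save_for_backward(x, router_logits, w1, w2)
        # Fast forward-only Triton kernel (hypothetical entry point).
        return fused_moe_forward(x, router_logits, w1, w2)

    @staticmethod
    def backward(ctx, grad_out):
        # Recompute with a differentiable PyTorch reference and let autograd
        # produce the gradients. Correct but slow; replace with backward kernels.
        inputs = [t.detach().requires_grad_(True) for t in ctx.saved_tensors]
        with torch.enable_grad():
            out = reference_moe(*inputs)
        return torch.autograd.grad(out, inputs, grad_out)
```

The point is only that the PR's forward kernel could be dropped into a training graph today; claiming the bounty would still require a proper fused backward, multipack support, and unit tests.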

casper-hansen · Jan 20 '24

@unslothai @danielhanchen would you open-source your kernels to claim the bounty?

kno10 · Apr 18 '24