
Faster Inference & Training Roadmap

Open jeromeku opened this issue 5 months ago • 27 comments

@danielhanchen

In the unsloth Gemma intro blogpost, you mention a VRAM increase due to the larger MLP size in Gemma compared to Llama and Mistral, and show a graph demonstrating decreased memory usage when running unsloth vs. HF and FA2:

  • How does unsloth reduce memory usage?
  • What are the model and runtime configs used to generate the HF vs FA2 vs unsloth graph? Is it inference or training?

Curious what optimizations are leading to memory decrease -- quantization, autograd efficiency, etc.

jeromeku avatar Mar 07 '24 16:03 jeromeku

@jeromeku I will get to reviewing GPTQ - sorry for the delay!!

  • The VRAM reductions are from Unsloth's optims :) i.e. Triton kernels, making memory copies go away, FA2, autograd tricks, etc.
  • Oh training! All runs use 4-bit quantization, gradient accumulation (ga) is set to 1, and the batch size (bsz) and sequence length move dynamically so I can measure VRAM usage.

danielhanchen avatar Mar 08 '24 02:03 danielhanchen

@danielhanchen

Thanks -- would be helpful to have a step-by-step breakdown of where the memory savings are coming from, i.e., an ablation study.

Is there interest in faster inference kernels, or is the focus primarily on the training side?

jeromeku avatar Mar 08 '24 03:03 jeromeku

@jeromeku For Mistral itself: https://unsloth.ai/blog/mistral-benchmark (benchmark image not preserved)

Gemma's VRAM reduction should be similar to our breakdown for Mistral.

For inference for Gemma - I did make it 2x faster, but it's mainly cobbling together ideas from vLLM and other packages, so I only spent 1 week on it :) The goal is to merge GPT Fast and other ideas like EAGLE to make inference faster :)

danielhanchen avatar Mar 08 '24 08:03 danielhanchen

@danielhanchen

I'd be interested in contributing on the inference front -- shall we create a priority list of ideas for implementation?

jeromeku avatar Mar 08 '24 15:03 jeromeku

@jeromeku That'll be cool!! :) We can collab either via Github or async on our Discord - whatever suits you :)

danielhanchen avatar Mar 09 '24 03:03 danielhanchen

@danielhanchen

Looking forward to it!

What's top of mind currently? Perhaps we can draw up a roadmap (if one doesn't already exist).

jeromeku avatar Mar 09 '24 04:03 jeromeku

@jeromeku Oh ye a roadmap would be nice - don't actually have one for inference specifically :)

danielhanchen avatar Mar 09 '24 05:03 danielhanchen

@danielhanchen

You mentioned integrating ideas from GPT Fast and EAGLE - what others did you have in mind?

What's on the roadmap for fine-tuning / training -- architectures, algorithms, etc.? Asking so I know what literature / code to review.

jeromeku avatar Mar 09 '24 15:03 jeromeku

@jeromeku In terms of inference specifically:

  1. GPT Fast
  2. Speculative Decoding (use a small model to generate tokens, then run the large model in 1 forward pass and see if the argmax of the logits match - see the sketch after this list)
  3. EAGLE (Speculative Decoding but only Word2Vec style ie lm_head -> embeddings)
  4. All quant methods - HQQ, AWQ, Exllama etc
  5. vLLM's Paged Attention
  6. Full 1 singular Triton kernel fusion - ie can we write 1 forward pass in 1 humongous Triton kernel? Very hard, since there are synchronizations which have to be done
  7. Using float8 like Fire Attention. cuDNN has float8 flash attention I think as well.
  8. Rewriting matrix-vector multiplication in Triton exactly (like what you were trying to do with GPTQ, but matvec instead of matmul)
  9. Torch export

I might have more, but those are off the top of my head.
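For item 2, here's a minimal greedy-verification sketch (not Unsloth code) - it assumes hypothetical HF-style `draft_model` and `target_model` causal LMs returning `.logits`, batch size 1, and skips sampling and KV-cache handling:

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, input_ids, k=4):
    # 1) Let the small draft model propose k greedy tokens.
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        next_tok = logits.argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

    # 2) Score all proposed tokens with the large model in ONE forward pass.
    target_logits = target_model(draft_ids).logits

    # 3) Accept proposals while the target's argmax agrees with the draft;
    #    on the first mismatch, keep the target's own token and stop.
    n_prompt = input_ids.shape[1]
    accepted = input_ids
    for i in range(k):
        target_tok = target_logits[:, n_prompt + i - 1, :].argmax(dim=-1, keepdim=True)
        accepted = torch.cat([accepted, target_tok], dim=-1)
        if not torch.equal(target_tok, draft_ids[:, n_prompt + i : n_prompt + i + 1]):
            break
    return accepted
```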

For training / finetuning:

  1. Fast MoE matmul kernel https://github.com/vllm-project/vllm/pull/2453 but for training - much more complex than inference on batch sizes of 1. Mixtral selects the top 2 experts, which can easily be done in Triton. However, when you have bsz>1, we have issues. One has to do dynamic compressed packing then call torch.bmm. The backward pass is even more problematic, since it requires a reversed packing then calling torch.bmm, then deconstructing it. A nightmare.
  2. Galore - extremely fascinating: projecting gradients to a small low-rank matrix, then periodically using SVD to update the projectors. It's not Galore itself that I was fascinated by, but rather Lomo, which does gradient updates dynamically, and this can save 20GB of VRAM during pretraining.
  3. 1.58bit - I recently wrote on HN about how 1.58bit allows one to skip multiplications entirely, since a weight in (-1, 0, 1) becomes a simple sign flip and the rest is just additions (aligning exponents and adding mantissas). Compared to 8-bit floats, 1.58bit uses 2x less space, which makes it possible to cram in 2x as many weights. Writing it in Triton can be more complex. (A tiny sketch of the no-multiplication trick is below.)

Just a brain dump!
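On the 1.58bit point (item 3), a tiny illustration of why no multiplications are needed - plain PyTorch with a dense ternary weight for clarity, rather than a packed format:

```python
import torch

def ternary_matvec(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """y = w @ x with w in {-1, 0, +1}, computed without multiplies."""
    plus = torch.where(w == 1, x, torch.zeros_like(x))    # select x where w == +1
    minus = torch.where(w == -1, x, torch.zeros_like(x))  # select x where w == -1
    return plus.sum(dim=-1) - minus.sum(dim=-1)           # sign flips + adds only

w = torch.randint(-1, 2, (8, 16)).float()
x = torch.randn(16)
assert torch.allclose(ternary_matvec(w, x), w @ x, atol=1e-5)
```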

danielhanchen avatar Mar 09 '24 16:03 danielhanchen

@danielhanchen

Love it.

Inference:

  • vLLM paged attention - happy to look into this as well as KV cache quantization.
  • Triton GEMV - seems pretty straightforward -- can prototype an implementation in Triton -- I believe torch.compile can already generate such a kernel with the proper inductor settings (it effectively decomposes to a vectorized mul + add). Can also adapt existing GEMV CUDA kernels for quantized weights.
  • Torch export - can look into it. I've done some work on decoupling Triton kernels from the Triton runtime.

Training:

  • I've been playing around with implementing a Cutlass kernel for MoE matmul which could help with larger batch sizes.
  • Galore - top of my list of papers to read
  • 1.58 bit - also been looking into Cutlass for optimizing custom quantized ops.

jeromeku avatar Mar 09 '24 18:03 jeromeku

  • Oh ye KV cache quant is cool! One issue I have with it is that dynamically quantizing the KV cache will cause overhead issues - a super fast quantization method will have to be deployed.
  • Triton GEMV: Ye the kernel is fine to create - one question is whether we can fold GEMVs, layernorms and everything else into 1 large kernel
  • CUTLASS is good :) Looking forward to it - my main view is we need to use as much Triton as possible for device-agnostic purposes :)
  • Ye Galore and 1.58bit :) 1.58bit actually can be very doable in Triton. Galore is very cool.

danielhanchen avatar Mar 10 '24 02:03 danielhanchen

Let me know what I should prioritize.

Also, can you expand more on Triton GEMV? What kind of horizontal / vertical fusions to target?

jeromeku avatar Mar 10 '24 05:03 jeromeku

Oh so GEMV itself is generally OK I guess - the issue is merging the dequant step in (i.e. what you were doing with GPTQ, except it's matrix-vector mult rather than matrix-matrix mult). This allows different optimizations - i.e. is blocked MM better, or is column-wise or row-wise MV better? It depends on the cache footprint.

But the goal is: can we somehow merge X @ Wq, X @ Wk, X @ Wv together with RoPE and attention and everything into 1 large kernel?
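This doesn't get anywhere near the full fusion (RoPE and attention stay separate here), but one small first piece is concatenating the Q/K/V weights so the three projections become a single matmul / GEMV launch - a minimal sketch, with illustrative sizes:

```python
import torch

hidden, n_heads, head_dim = 1024, 8, 128
Wq = torch.randn(hidden, n_heads * head_dim)
Wk = torch.randn(hidden, n_heads * head_dim)
Wv = torch.randn(hidden, n_heads * head_dim)
W_qkv = torch.cat([Wq, Wk, Wv], dim=1)   # one (hidden, 3 * n_heads * head_dim) weight

x = torch.randn(1, hidden)               # batch size 1, as in decoding
q, k, v = (x @ W_qkv).split(n_heads * head_dim, dim=-1)  # one launch, three outputs
```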

danielhanchen avatar Mar 10 '24 11:03 danielhanchen

If I understand correctly:

  • Separate the dequant step from matmul
  • Fuse as much of the forward pass into a single kernel for Llama, Mistral, and Gemma architectures

Can you point me to the current GEMV implementation? Need a minimal implementation / testbed for benchmarking purposes.

jeromeku avatar Mar 10 '24 15:03 jeromeku

Oh for inference, your method of fusing the dequant step inside the kernel is actually ideal! For training it's not, since cuBLAS is relatively smart about data movements.

An ideal kernel for GEMV, i.e. a vector * matrix kernel, is normally done via the scheme in the attached image (not preserved).

However, a more optimal procedure is to split the reduction into blocks using atomic_add - it can in fact be, say, reduction columns of 4 but 24 blocks, cycling through them using the modulus function (image not preserved).

A final reduction will need to be made at the end.
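A rough Triton sketch of that split-reduction idea (not Unsloth's kernel, and block sizes are untuned): each program covers a tile of output rows by reduction columns, computes partial dot products, and tl.atomic_add's them into the output, with the atomics playing the role of the final reduction.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def gemv_split_kernel(A, X, Y, M, N, stride_am,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)                      # block of output rows
    pid_n = tl.program_id(1)                      # slice of the reduction dim
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_m = rm < M
    mask_n = rn < N

    # Load a (BLOCK_M, BLOCK_N) tile of A and the matching slice of x.
    a = tl.load(A + rm[:, None] * stride_am + rn[None, :],
                mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    x = tl.load(X + rn, mask=mask_n, other=0.0)

    # Partial sums for this reduction slice; atomics reduce across pid_n.
    partial = tl.sum(a.to(tl.float32) * x.to(tl.float32)[None, :], axis=1)
    tl.atomic_add(Y + rm, partial, mask=mask_m)

def gemv(A: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    M, N = A.shape
    y = torch.zeros(M, device=A.device, dtype=torch.float32)  # must start at zero
    BLOCK_M, BLOCK_N = 64, 256
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemv_split_kernel[grid](A, x, y, M, N, A.stride(0),
                            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y
```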

The current GEMV implementation will probably be the one in GPT Fast, although I haven't inspected it myself yet.

The hardest part is folding in the Bitsandbytes int4 format, since the blocksize is lopsided, i.e. not a whole integer multiple, which is a nightmare for cache optimality.

danielhanchen avatar Mar 10 '24 15:03 danielhanchen

Another approach people use is row-wise (image not preserved), which again can be done in parallel with a reduction as I described above.

danielhanchen avatar Mar 10 '24 15:03 danielhanchen

@danielhanchen Ok - so I'm clear on objectives:

  • a reasonable first pass is a Triton kernel that fuses bitsandbytes 4-bit dequant with an efficient GEMV
  • further iterations would then fold in additional prologue / epilogue ops such as positional encodings, activations, etc.
  • ultimate goal would be fusing as much of the forward pass as possible into a single kernel.

jeromeku avatar Mar 10 '24 17:03 jeromeku

For training / finetuning:

@danielhanchen Obligatory request for Multi GPU XD

nivibilla avatar Mar 10 '24 18:03 nivibilla

@jeromeku Extremely sorry for the delay - yep, sounds right! :) @nivibilla Yep!

danielhanchen avatar Mar 15 '24 12:03 danielhanchen

@danielhanchen

Is the issue with the existing bitsandbytes gemv the fact that it's CUDA only?

jeromeku avatar Mar 15 '24 14:03 jeromeku

@jeromeku Yes, that can be one of the main issues - the other is folding it inside other kernels, i.e. a single fused kernel can become too complex to do.

The main issue I still see with 1 kernel (so maybe I'm overthinking) is that every new op requires synchronization, so maybe we should rather rely on torch.compile with CUDAGraphs to reduce the CPU overhead in between.
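For reference, a minimal sketch of that route: mode="reduce-overhead" is the torch.compile setting that captures CUDA graphs, so per-op launch overhead between kernels mostly disappears. The toy layer here is just a stand-in for a real decoder.

```python
import torch

# CUDA graphs want static shapes and CUDA tensors.
model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True).cuda().eval()
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)

x = torch.randn(1, 128, 512, device="cuda")
with torch.no_grad():
    for _ in range(3):   # first calls compile and record the graph, later ones replay it
        out = compiled(x)
```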

danielhanchen avatar Mar 15 '24 14:03 danielhanchen

I'd imagine there is an optimization spectrum:

  • torch.compile the entire graph with appropriate inductor settings to maximize fusion / reduce overhead
  • manually fuse kernels and use torch cudagraph APIs to glue things together

Will make a quick pass at implementing bnb dequant gemv in triton to see how performance compares.

Cutlass also enables some flexibility with bespoke gemm and fusions but is again cuda only. Let me know if this is of interest.

jeromeku avatar Mar 15 '24 15:03 jeromeku

@jeromeku Oh ye let's try to be device agnostic :)) compile is OK, but I guess handwriting is best :) We can then use CUDAGraphs manually

danielhanchen avatar Mar 15 '24 17:03 danielhanchen

@danielhanchen

A few updates:

  • GaLore -- ran some initial experiments to fuse the GaLore Adam update step -- see PR (rough sketch of the projected update below)
  • Is a Triton 4-bit bnb dequant kernel of interest?
  • Going to start working on implementing a fused backward pass for Mixtral.
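Not the PR's actual code - just a rough sketch of what a GaLore-style projected Adam step looks like: the projector comes from an SVD of the gradient (refreshed periodically), the Adam state stays low-rank, and only the projected-back update touches the full weight in place.

```python
import torch

def galore_adam_step(W, grad, P, m, v, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8, step=1):
    g_lr = P.T @ grad                               # project (out, in) grad down to (rank, in)
    m.mul_(b1).add_(g_lr, alpha=1 - b1)             # Adam first moment, low rank
    v.mul_(b2).addcmul_(g_lr, g_lr, value=1 - b2)   # Adam second moment, low rank
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)
    update = P @ (m_hat / (v_hat.sqrt() + eps))     # project back to (out, in)
    W.add_(update, alpha=-lr)                       # in-place weight update

out_dim, in_dim, rank = 1024, 1024, 64
W = torch.randn(out_dim, in_dim)
grad = torch.randn(out_dim, in_dim)
P, _, _ = torch.linalg.svd(grad, full_matrices=False)  # refresh projector from the gradient
P = P[:, :rank]                                        # (out, rank)
m = torch.zeros(rank, in_dim)
v = torch.zeros(rank, in_dim)
galore_adam_step(W, grad, P, m, v)
```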

jeromeku avatar Mar 22 '24 04:03 jeromeku

@jeromeku Fantastic work as always!! very very cool on fusing Adam and Galore!! Love this!

Oh on Mixtral - https://github.com/shawntan/scattermoe/tree/main/scattermoe :) Was reading up on this as well :)

On BnB dequant - I'll have a look first at it :) But you're more than happy to do it if you want :)

danielhanchen avatar Mar 22 '24 06:03 danielhanchen

@danielhanchen

  • Are you planning on integrating GaLore into unsloth? Planning on working on an Adam8bit version.
  • Will take a quick crack at bnb dequant

jeromeku avatar Mar 22 '24 12:03 jeromeku

Really excited about optimized kernels for inference!

Worth looking at https://github.com/zeux/calm - where the forward pass is implemented as a single cuda kernel

Uses fp8 rather than int4/8 quantization.

pHaeusler avatar Mar 24 '24 02:03 pHaeusler