
Case Study: Instruction Tuning on AQLM Models

Open · hiyouga opened this issue Mar 05 '24 · 6 comments

Hi, we have performed a small experiment on fine-tuning the Llama-2-70B-AQLM-2Bit model with the PEFT QLoRA method. We used the Alpaca and Glaive datasets for instruction tuning, and the fine-tuned version demonstrates preliminary conversation and tool-use abilities. Training requires only 24GB of GPU RAM, while inference needs just 20GB, so fine-tuning a 70B model on consumer devices is feasible. However, we found that AQLM significantly increases the overall training time; it would be great if the training speed could be improved. Thanks again for your excellent work!

Adapter weights: https://huggingface.co/hiyouga/Llama-2-70b-AQLM-2Bit-QLoRA-function-calling
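For anyone trying to reproduce a similar setup, below is a minimal sketch of attaching a LoRA adapter to an AQLM-quantized checkpoint with PEFT; the checkpoint id and LoRA hyperparameters are illustrative, not the exact configuration we used.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Illustrative checkpoint name; substitute the 2-bit AQLM model you want to tune.
model_id = "ISTA-DASLab/Llama-2-70b-AQLM-2Bit-1x16-hf"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Attach trainable LoRA adapters on top of the frozen quantized weights.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable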

(Attached screenshots: examples, train_loss, gpu)

hiyouga · Mar 05 '24 06:03

Hi @hiyouga! Sadly, it's not properly documented yet, but you should do:

import aqlm
from transformers import AutoModelForCausalLM

with aqlm.optimize_for_training():
    model = AutoModelForCausalLM.from_pretrained(...)

The thing is, there are a few ways to compute a forward pass: some of them work better for a very small number of tokens (e.g. generation), and some are optimized for large batch sizes (e.g. training). We're hoping to determine which kernels to use dynamically in later versions of aqlm, but, for now, please add that wrapper explicitly. Also, keep in mind that a model loaded under that wrapper will be very slow at generation. We're working on making this a more pleasant experience!
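To make the trade-off concrete, here is a sketch of keeping separate loading paths for training and generation under the behaviour described above; the helper names are just illustrative.

import aqlm
from transformers import AutoModelForCausalLM

def load_for_training(model_id):
    # Training-friendly kernels: large constant overhead per pass, fast at big batch sizes.
    with aqlm.optimize_for_training():
        return AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

def load_for_generation(model_id):
    # Default kernels: fast for single-token decoding, slow for large batches.
    return AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")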

BlackSamorez · Mar 05 '24 09:03

@BlackSamorez Indeed! This is very important to me; I will try to fine-tune the model again following your advice. Thanks for pointing it out!

hiyouga · Mar 05 '24 10:03

A bit more context: below are the speeds for a typical layer on an RTX 3090 GPU. We have a kernel for single-token passes (generation), which is slightly faster than fp16, and a kernel that introduces a large but constant overhead over fp16, meaning it is asymptotically as fast as fp16.

| num_tokens (batch_size × seq_len) | with optimize_for_training, ms/pass | without optimize_for_training, ms/pass | fp16 baseline, ms/pass |
| --- | --- | --- | --- |
| 1 | 4.71 | 0.18 | 0.14 |
| 4 | 4.69 | 0.53 | 0.14 |
| 16 | 4.70 | 1.91 | 0.14 |
| 64 | 4.72 | 7.43 | 0.16 |
| 256 | 5.02 | too slow | 0.46 |
| 1024 | 6.14 | too slow | 1.57 |
| 4096 | 10.04 | too slow | 5.54 |
| 16384 | 25.68 | too slow | 21.15 |

As of now, we don't have a good enough kernel for anything between 4 and 4000 tokens processed in a pass. We're hoping to implement one someday.
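To check which regime a given workload falls into, one can time full forward passes at different token counts; a rough sketch (the checkpoint id is illustrative, and a CUDA GPU is assumed):

import time
import torch
from transformers import AutoModelForCausalLM

model_id = "ISTA-DASLab/Llama-2-70b-AQLM-2Bit-1x16-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

for num_tokens in (1, 64, 1024, 4096):
    input_ids = torch.randint(
        0, model.config.vocab_size, (1, num_tokens), device=model.device
    )
    with torch.no_grad():
        model(input_ids)  # warm-up pass
        torch.cuda.synchronize()
        start = time.perf_counter()
        model(input_ids)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    print(f"{num_tokens:>6} tokens: {elapsed * 1000:.1f} ms/pass")

Note that these numbers cover the whole model rather than a single layer, so they are not directly comparable to the table above, but the scaling trend should be similar.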

BlackSamorez · Mar 05 '24 11:03

I see. Generation is much faster than training, and the gap might also be related to the gradient checkpointing used during training.

hiyouga · Mar 05 '24 12:03

I've merged #39 and released aqlm==1.1.0, which removes the need for aqlm.optimize_for_training(). The appropriate kernel is now determined automatically.
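With aqlm>=1.1.0, the earlier snippet simplifies to the usual transformers call (a sketch; the pip extra is an assumption):

# pip install "aqlm[gpu]>=1.1.0"  # assumed install command / extra
from transformers import AutoModelForCausalLM

# No aqlm.optimize_for_training() context manager is needed anymore;
# the appropriate kernel is determined automatically.
model = AutoModelForCausalLM.from_pretrained(...)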

BlackSamorez · Mar 07 '24 20:03

Sounds great! We will instruct users to use the latest AQLM in our training framework.

hiyouga · Mar 08 '24 15:03

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] · Apr 08 '24 01:04