torchtune

How can I query the distilled model through an interface similar to the OpenAI API?

Open lingq1 opened this issue 1 year ago • 13 comments

I distilled the model with tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed, and the results were saved to /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed. How can I run my 1B model through something like the OpenAI API to verify its effectiveness?

lingq1, Dec 11 '24 10:12

Congrats on the successful fine-tune @lingq1!

There are a couple of ways to determine the effectiveness of your model.

  1. Benchmarks - These are standardized datasets where your model's output is compared against the desired output for a given task. A common example is MMMU, which collects questions from subjects such as science, anthropology, and algebra and asks the model a series of multiple-choice questions to see how often it picks the right answer. The advantage of benchmarks is that they make it easy to compare against other models' performance. The disadvantage is that they can be fairly easy to game, so the longer a benchmark has been around, the less useful it tends to be.
  2. Vibes - Another popular way of measuring the effectiveness of your model is just to play around with it for a while! Ask it some questions, try to stump it, and see what interesting behavior emerges.
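To make the benchmark idea concrete: for many multiple-choice tasks, an evaluation harness scores each answer choice by the log-probability the model assigns to it and counts the question as correct when the highest-scoring choice is the gold answer. The sketch below shows that scoring rule in plain Python; pick_answer and mc_accuracy are hypothetical helpers (not torchtune or lm-eval APIs), and the log-probabilities are made-up numbers.

```python
from typing import Sequence


def pick_answer(choice_logprobs: Sequence[float]) -> int:
    """Return the index of the answer choice the model finds most likely."""
    return max(range(len(choice_logprobs)), key=lambda i: choice_logprobs[i])


def mc_accuracy(all_logprobs: Sequence[Sequence[float]], gold: Sequence[int]) -> float:
    """Fraction of questions where the most likely choice is the gold answer."""
    correct = sum(pick_answer(lps) == g for lps, g in zip(all_logprobs, gold))
    return correct / len(gold)


# Hypothetical log-probs for two 4-way questions.
logprobs = [
    [-4.1, -1.2, -3.3, -5.0],  # model prefers choice 1 (gold is 1)
    [-2.0, -2.5, -0.7, -3.1],  # model prefers choice 2 (gold is 0)
]
print(mc_accuracy(logprobs, gold=[1, 0]))  # 0.5
```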

I believe both approaches are important for determining whether your fine-tuned model is "good". So how can you implement them in torchtune?

For benchmarking, torchtune provides an integration with EleutherAI's LM Evaluation Harness. Simply create a config similar to the following example:

```yaml
# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/epoch_X # substitute with your final epoch number
  checkpoint_files: [model.safetensors]
  output_dir: ./ # Not needed
  model_type: LLAMA3_2

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/base_model/original/tokenizer.model

# Environment
device: cuda
dtype: bf16
seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed

# EleutherAI specific eval args
tasks: ["truthfulqa_mc2"]
limit: null
max_seq_length: 4096
batch_size: 8
enable_kv_cache: True

# Quantization specific args
quantizer: null
```

Then run this config with:

```bash
tune run eleuther_eval --config PATH/TO/YOUR/CONFIG.YAML
```
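The truthfulqa_mc2 task in this config reports a probability-weighted score rather than plain accuracy. A minimal sketch of that metric, assuming it matches my reading of the harness (the probability mass the model puts on the true answers, normalized over all answer choices); truthfulqa_mc2 here is a hypothetical function name and the log-probabilities would come from the model.

```python
import math
from typing import Sequence


def truthfulqa_mc2(true_logprobs: Sequence[float], false_logprobs: Sequence[float]) -> float:
    """Share of probability mass on true answers, normalized over all answers."""
    p_true = [math.exp(lp) for lp in true_logprobs]
    p_false = [math.exp(lp) for lp in false_logprobs]
    return sum(p_true) / (sum(p_true) + sum(p_false))


# Equal likelihood for one true and one false answer gives a score of 0.5.
print(truthfulqa_mc2([-1.0], [-1.0]))  # 0.5
```

Unlike argmax accuracy, this rewards a model for shifting probability toward true answers even when it does not flip its top choice.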

Now, for playing around with the model, you have a few options. torchtune provides a generation script that works out of the box; however, it is very simple and only lets you pass a fixed prompt and get back some output. There is no chat functionality.

```yaml
# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/epoch_X # substitute with your final epoch number
  checkpoint_files: [model.safetensors]
  output_dir: ./ # Not needed
  model_type: LLAMA3_2

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/base_model/original/tokenizer.model
device: cuda
dtype: bf16

seed: 1234

# Generation arguments; defaults taken from gpt-fast
prompt:
  system: null
  user: "Tell me a joke."
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

enable_kv_cache: True

quantizer: null
```

Launch it with:

```bash
tune run generate --config PATH/TO/YOUR/CONFIG.YAML
```
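The temperature and top_k settings in the generation config control how each next token is drawn. The sketch below shows the general technique (a temperature-scaled softmax restricted to the k highest-scoring tokens) in plain Python; it is an illustration under that assumption, not torchtune's generation code, and sample_top_k is a hypothetical name.

```python
import math
import random
from typing import Optional, Sequence


def sample_top_k(logits: Sequence[float], temperature: float = 0.6, top_k: int = 300,
                 rng: Optional[random.Random] = None) -> int:
    """Sample a token id from logits using temperature scaling and top-k filtering."""
    rng = rng or random.Random()
    # Keep only the k highest-scoring token ids.
    k = min(top_k, len(logits))
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature < 1 sharpens the distribution, > 1 flattens it (must be > 0 here).
    scaled = [logits[i] / temperature for i in top_ids]
    # Softmax over the surviving tokens, subtracting the max for numerical stability.
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top_ids, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.1, -3.0]
print(sample_top_k(logits, top_k=1))  # 0, since only the best token survives
```

With a Llama 3 vocabulary of roughly 128k tokens, top_k: 300 still keeps a fairly wide candidate set, so temperature does most of the work.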

However, thanks to some awesome work from @felipemello1, it's also really easy to chat with your model through torchtune's integration with vLLM!! Follow the steps outlined in this PR description to get started. Once everything is set up, you can run vLLM locally:

```python
from vllm import LLM, SamplingParams

def print_outputs(outputs):
    # Each output holds the rendered prompt and one or more completions.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)

llm = LLM(
    model="/tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/epoch_X", # substitute with your final epoch number
    load_format="safetensors",
    kv_cache_dtype="auto",
)
# max_tokens=16 keeps the demo short; raise it for longer answers such as the essay below.
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)

# llm.chat applies the model's chat template to this list of messages.
conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)
```
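llm.chat does not feed the list of dicts to the model as-is; vLLM renders it into a single prompt string using the chat template shipped with the tokenizer. The sketch below shows roughly what that string looks like for the basic Llama 3 instruct layout. It assumes that layout (the official 3.1/3.2 templates may add extra lines to the system turn), and render_llama3_chat is a hypothetical helper for illustration only.

```python
from typing import Dict, List


def render_llama3_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render messages in the basic Llama 3 instruct format (assumed layout)."""
    text = "<|begin_of_text|>"
    for msg in messages:
        # Each turn: role header, blank line, content, end-of-turn token.
        text += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
    if add_generation_prompt:
        # Leave an open assistant header so the model writes the next turn.
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text


print(render_llama3_chat([{"role": "user", "content": "Hello"}]))
```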

Lastly, I know you mentioned OpenAI API-compatible evaluation. This probably isn't the fastest way to get up and running, but because torchtune integrates with vLLM and vLLM provides an OpenAI-compatible interface, it's totally possible! You can read more about setting this up here.
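As a rough sketch of the OpenAI-compatible route: recent vLLM versions can serve the checkpoint with `vllm serve /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/epoch_X` (port 8000 by default), which exposes a /v1/chat/completions endpoint. The client below uses only the Python standard library so it does not depend on the openai package. The URL, port, and helper names are assumptions for illustration, and the model field must match the name the server reports under /v1/models (by default, the path passed to vllm serve).

```python
import json
import urllib.request
from typing import Dict, List


def build_chat_request(base_url: str, model: str, messages: List[Dict[str, str]],
                       max_tokens: int = 256, temperature: float = 0.6) -> urllib.request.Request:
    """Build a POST request for an OpenAI-style /v1/chat/completions endpoint."""
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat(req: urllib.request.Request, timeout: float = 60.0) -> str:
    """Send the request and return the first choice's message content."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Usage once the server is running (not executed here):
# req = build_chat_request(
#     "http://localhost:8000",
#     "/tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed/epoch_X",
#     [{"role": "user", "content": "Tell me a joke."}],
# )
# print(send_chat(req))
```

Because the request shape is the standard chat-completions one, any existing OpenAI-style evaluation script should be able to point at this base URL instead.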

Feel free to post some results when you're done evaluating!

joecummings, Dec 11 '24 14:12

Thank you for your answer, it helped me a lot.

  1. I used llama-factory to fine-tune the llama3-8B model on a custom dataset in the following format (2,000 samples like this one):
```json
{
   "instruction": "analyze whether the following mtr data is normal",
   "input": "mtr=Start: Mon Sep 30 09:28:35 2024\nHOST: VM-I944SQ1D  Loss%   Snt   Last   Avg  Best  Wrst StDev\n 1.|-- 149.76.179.202   58.7%    50    19.6  97.7  15.7  29.2  52.9\n 2.|-- 108.215.47.202   72.1%    50    23.3  21.0  70.0  90.0  40.9\n 3.|-- 114.129.40.151   55.8%    50    84.5  17.8  50.5  88.0  93.1",
   "output": "{\"last_percent\": \"55.8%\", \"status\": \"abnormal\"}"
}
```
  2. With vLLM inference, the fine-tuned llama3-8B model fully met expectations. However, after distilling 8B into 1B and running the 1B model with vLLM (the inference code was based on vLLM's examples), the output did not quite meet expectations, although it was much better than before distillation. My setup has two 24 GB A10 GPUs, and I can't increase batch_size any further. Is there any room to optimize the following distillation command?
```bash
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed \
teacher_checkpointer.checkpoint_files='[
   "model-00001-of-00009.safetensors",
   "model-00002-of-00009.safetensors",
   "model-00003-of-00009.safetensors",
   "model-00004-of-00009.safetensors",
   "model-00005-of-00009.safetensors",
   "model-00006-of-00009.safetensors",
   "model-00007-of-00009.safetensors",
   "model-00008-of-00009.safetensors",
   "model-00009-of-00009.safetensors"
]' \
optimizer.lr=1e-5 \
kd_ratio=0.5 \
dataset=torchtune.datasets.alpaca_dataset \
dataset.source=json \
dataset.data_files=/usr/local/xxx.json \
batch_size=2
```
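With dataset=torchtune.datasets.alpaca_dataset, each record's instruction/input/output fields become a prompt plus a target. A sketch of that mapping, assuming the standard Alpaca template (torchtune's exact wording may differ, so check the alpaca_dataset source); to_prompt_and_target is a hypothetical helper. The practical point is that the prompts you evaluate with should be formatted the same way as the training data.

```python
from typing import Dict, Tuple

# Standard Alpaca templates (assumed; torchtune's wording may differ).
WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def to_prompt_and_target(record: Dict[str, str]) -> Tuple[str, str]:
    """Map an Alpaca-style record to (prompt, target); extra format kwargs are ignored."""
    template = WITH_INPUT if record.get("input") else NO_INPUT
    prompt = template.format(instruction=record["instruction"], input=record.get("input", ""))
    return prompt, record["output"]


prompt, target = to_prompt_and_target({
    "instruction": "analyze whether the following mtr data is normal",
    "input": "mtr=Start: Mon Sep 30 09:28:35 2024 ...",
    "output": "{\"last_percent\": \"55.8%\", \"status\": \"abnormal\"}",
})
print(prompt + target)
```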

lingq1, Dec 12 '24 15:12

@joecummings So is my use case above simply a poor fit for distillation, or is there room to tune the parameters?

lingq1, Dec 13 '24 02:12

IIUC, your model performed better on your benchmark after distillation than before? In that case, it seems that distillation worked?

To make sure this is working properly, though, can you try directly finetuning your small 1B model and comparing it to the distilled model? The distilled model should perform slightly better than direct finetuning b/c it incorporates the distribution from the much larger model.
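For a task like the mtr example, where the target is a small JSON object, one simple way to compare the directly finetuned and distilled 1B models is to parse each generated answer and score it field by field on a held-out split. A minimal sketch, assuming the generations have already been collected (e.g. with the vLLM snippet above); score_predictions is a hypothetical helper, and malformed JSON counts as a miss.

```python
import json
from typing import Dict, Sequence


def score_predictions(predictions: Sequence[str], references: Sequence[str],
                      fields: Sequence[str] = ("status", "last_percent")) -> Dict[str, float]:
    """Per-field exact-match accuracy; unparseable predictions count as misses."""
    hits = {f: 0 for f in fields}
    valid_json = 0
    for pred, ref in zip(predictions, references):
        gold = json.loads(ref)
        try:
            guess = json.loads(pred)
        except json.JSONDecodeError:
            continue  # the model did not produce valid JSON
        if not isinstance(guess, dict):
            continue
        valid_json += 1
        for f in fields:
            hits[f] += int(guess.get(f) == gold.get(f))
    n = len(references)
    scores = {f: hits[f] / n for f in fields}
    scores["valid_json"] = valid_json / n
    return scores
```

Running the same held-out set through both models and comparing these numbers gives a more direct signal than comparing training losses across frameworks.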

joecummings, Dec 13 '24 12:12

I will try

lingq1, Dec 16 '24 01:12

How can I apply LoRA to all modules during distillation? The following override results in an error:

```bash
model.lora_attn_modules=['q_proj', 'v_proj', 'output_proj', 'k_proj', 'gate_proj', 'up_proj', 'down_proj'];
```

Error:

```
[rank0]: ] = convert_weights.tune_to_peft_adapter_config(
[rank0]: File "/usr/local/src/torchtune/torchtune/models/convert_weights.py", line 242, in tune_to_peft_adapter_config
raise ValueError(f"Unknown target module {k}")
```

lingq1, Dec 16 '24 03:12

Our LoRA recipe does not allow you to "LoRA-fy" gate_proj, up_proj, or down_proj via lora_attn_modules, b/c they are part of the MLP layer, not self-attention! You can read more about it in our LoRA tutorial. Acceptable options for lora_attn_modules are 'q_proj', 'k_proj', 'v_proj', and 'output_proj'.

Instead, if you want to "LoRA-fy" the gate, up, and down projections, just set apply_lora_to_mlp=True.
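Splitting a desired set of projections into those two knobs can be sketched like this. The module names and override syntax come from this thread; the helper itself is hypothetical, and it assumes (as I understand it) that apply_lora_to_mlp covers all three MLP projections together rather than individually.

```python
from typing import Iterable, List

ATTN_MODULES = ("q_proj", "k_proj", "v_proj", "output_proj")
MLP_MODULES = ("gate_proj", "up_proj", "down_proj")


def lora_overrides(targets: Iterable[str]) -> List[str]:
    """Turn a set of projection names into tune-run config overrides."""
    targets = set(targets)
    unknown = targets - set(ATTN_MODULES) - set(MLP_MODULES)
    if unknown:
        raise ValueError(f"Unknown target modules: {sorted(unknown)}")
    attn = [m for m in ATTN_MODULES if m in targets]
    mlp = targets & set(MLP_MODULES)
    if mlp and mlp != set(MLP_MODULES):
        # Assumed all-or-nothing: the flag toggles gate, up and down projections together.
        raise ValueError("apply_lora_to_mlp covers gate_proj, up_proj and down_proj together")
    return [
        f'model.lora_attn_modules="{attn!r}"',
        f"model.apply_lora_to_mlp={bool(mlp)}",
    ]


print(" ".join(lora_overrides(
    ["q_proj", "v_proj", "output_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
)))
```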

joecummings, Dec 16 '24 13:12

@joecummings Thank you for your patient response. I modified the configuration again, and this time training and inference both ran successfully. However, the result is not as good as when I previously fine-tuned the 1B model with Llama-Factory. The logs show that the loss during distillation stays quite high, whereas the loss got much lower when fine-tuning with Llama-Factory. Here is my configuration. Is there a good way to bring the loss down?

```bash
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed \
teacher_checkpointer.checkpoint_files='[
    "model-00001-of-00009.safetensors",
    "model-00002-of-00009.safetensors",
    "model-00003-of-00009.safetensors",
    "model-00004-of-00009.safetensors",
    "model-00005-of-00009.safetensors",
    "model-00006-of-00009.safetensors",
    "model-00007-of-00009.safetensors",
    "model-00008-of-00009.safetensors",
    "model-00009-of-00009.safetensors"
]' \
optimizer.lr=1e-5 \
kd_ratio=1.0 \
dataset=torchtune.datasets.alpaca_dataset \
dataset.source=json \
dataset.data_files=/usr/local/alpaca_zh_demo.json \
batch_size=2 \
output_dir=/usr/local/torchtune/llama3_2_8B_to_1B/KD_lora_distributed \
tokenizer.path=/usr/local/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct/original/tokenizer.model \
checkpointer.checkpoint_dir=/usr/local/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct/ \
teacher_checkpointer.checkpoint_dir=/usr/local/Meta-Llama-3.1-8B-Instruct/ \
teacher_checkpointer.output_dir=/usr/local/Meta-Llama-3.1-8B-Instruct \
model.lora_rank=8 \
model.lora_alpha=16 \
model.lora_attn_modules="['q_proj', 'v_proj', 'output_proj', 'k_proj']" \
epochs=3 \
warmup_steps=0
```
```text
3|141|Loss: 1.7566685676574707:   7%|▋         | 5/68 [03:32<46:14, 44.04s/it]
3|141|Loss: 1.7566685676574707:   9%|▉         | 6/68 [04:07<42:18, 40.95s/it]
3|142|Loss: 2.570791006088257:   9%|▉         | 6/68 [04:07<42:18, 40.95s/it]
3|142|Loss: 2.570791006088257:  10%|█         | 7/68 [04:51<42:39, 41.95s/it]
3|143|Loss: 1.8999344110488892:  10%|█         | 7/68 [04:51<42:39, 41.95s/it]
3|143|Loss: 1.8999344110488892:  12%|█▏        | 8/68 [05:48<46:51, 46.86s/it]
3|144|Loss: 1.9488004446029663:  12%|█▏        | 8/68 [05:48<46:51, 46.86s/it]
3|144|Loss: 1.9488004446029663:  13%|█▎        | 9/68 [06:31<44:58, 45.73s/it]
3|145|Loss: 1.926500380039215:  13%|█▎        | 9/68 [06:31<44:58, 45.73s/it]
3|145|Loss: 1.926500380039215:  15%|█▍        | 10/68 [07:02<39:39, 41.03s/it]
3|146|Loss: 1.98761785030365:  15%|█▍        | 10/68 [07:02<39:39, 41.03s/it]
3|146|Loss: 1.98761785030365:  16%|█▌        | 11/68 [07:41<38:32, 40.56s/it]
3|147|Loss: 1.945314347743988:  16%|█▌        | 11/68 [07:41<38:32, 40.56s/it]
3|147|Loss: 1.945314347743988:  18%|█▊        | 12/68 [08:38<42:29, 45.53s/it]
3|148|Loss: 1.8299177289009094:  18%|█▊        | 12/68 [08:38<42:29, 45.53s/it]
3|148|Loss: 1.8299177289009094:  19%|█▉        | 13/68 [09:12<38:27, 41.96s/it]
3|149|Loss: 1.9277714490890503:  19%|█▉        | 13/68 [09:12<38:27, 41.96s/it]
......
3|200|Loss: 1.7644927501678467:  96%|█████████▌| 65/68 [59:47<02:42, 54.13s/it]
3|201|Loss: 1.7249903678894043:  96%|█████████▌| 65/68 [59:47<02:42, 54.13s/it]
3|201|Loss: 1.7249903678894043:  97%|█████████▋| 66/68 [1:00:34<01:43, 51.96s/it]
3|202|Loss: 1.9537839889526367:  97%|█████████▋| 66/68 [1:00:34<01:43, 51.96s/it]
3|202|Loss: 1.9537839889526367:  99%|█████████▊| 67/68 [1:01:34<00:54, 54.40s/it]
3|203|Loss: 1.8339385390281677:  99%|█████████▊| 67/68 [1:01:34<00:54, 54.40s/it]
3|203|Loss: 1.8339385390281677: 100%|██████████| 68/68 [1:02:21<00:00, 52.28s/it]
3|204|Loss: 2.169794023036957: 100%|██████████| 68/68 [1:02:21<00:00, 52.28s/
```
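When comparing runs, it helps to average the per-step loss rather than eyeball individual lines. The progress bar prints each step more than once in logs like the one above, so this sketch keys losses by step number. The regex assumes the `epoch|step|Loss: value` prefix seen here; it is a hypothetical helper, not a torchtune utility.

```python
import re
from statistics import mean
from typing import Dict, Iterable

# Matches the "<epoch>|<step>|Loss: <value>" prefix of each progress-bar line.
LOSS_LINE = re.compile(r"^(\d+)\|(\d+)\|Loss: ([0-9.]+(?:[eE][-+]?\d+)?)")


def losses_by_step(lines: Iterable[str]) -> Dict[int, float]:
    """Collect one loss per global step, ignoring repeated progress-bar refreshes."""
    losses: Dict[int, float] = {}
    for line in lines:
        match = LOSS_LINE.match(line.strip())
        if match:
            losses[int(match.group(2))] = float(match.group(3))
    return losses


def mean_loss(lines: Iterable[str]) -> float:
    """Average loss over the distinct steps found in the log lines."""
    return mean(losses_by_step(lines).values())


# Usage: mean_loss(open("train.log", encoding="utf-8"))  # hypothetical log file name
```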

lingq1, Dec 17 '24 10:12

Interesting - that loss is definitely a little high. At a glance, I'd recommend trying out the following things:

  1. Reducing kd_ratio, following the advice here.
  2. Increasing batch_size, or using gradient accumulation if your memory doesn't allow a larger batch_size directly.
  3. Using a larger LoRA rank and alpha. Read more about that here.
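For reference on what rank and alpha control: LoRA adds a low-rank update B·A to a frozen weight W and scales it by alpha/rank, so y = W x + (alpha / rank) B (A x). The toy pure-Python sketch below shows that forward pass with made-up shapes and values; it is the general LoRA formulation, not torchtune's LoRALinear code.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float) -> Vector:
    """y = W x + (alpha / rank) * B (A x); W is frozen, only A and B are trained."""
    rank = len(A)  # A has shape (rank, in_dim), B has shape (out_dim, rank)
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]
print(lora_forward(W, A=[[1.0, 1.0]], B=[[1.0], [0.0]], x=[1.0, 2.0], alpha=2.0))  # [7.0, 2.0]
```

Raising the rank adds trainable capacity; keeping alpha at 2 × rank (as with rank 8 / alpha 16 in this thread) keeps the update's scale constant at 2.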

cc @lindawangg to see if she has any suggestions as well

Could you also share your setup for when you ran with Llama-Factory?

joecummings, Dec 17 '24 13:12

Your lr is also very small. Try increasing it 10x, to 1e-4.
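Besides the peak value, the schedule shapes how much of that learning rate is actually used; the Llama-Factory run in this thread uses --lr_scheduler_type cosine with --warmup_steps 0. A sketch of linear warmup followed by cosine decay to zero, the shape commonly used for LoRA finetuning (I believe torchtune's LoRA configs set up something similar under lr_scheduler, but check your config). lr_at is a hypothetical name.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float, warmup_steps: int = 0) -> float:
    """Linear warmup to peak_lr, then cosine decay to 0 at total_steps."""
    if warmup_steps and step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


print(lr_at(0, 100, 1e-4))  # 0.0001
```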

felipemello1, Dec 17 '24 21:12

@joecummings This is the Llama-Factory configuration I use. How does it differ from my torchtune configuration?

```bash
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path  /usr/local/Llama-3.2-1B-Instruct \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template default \
    --flash_attn auto \
    --dataset_dir data \
    --dataset mllm_demo \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/LLaMA3-8B/lora/train_2024-12-17-11-45-21 \
    --fp16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all
```
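Visible differences between the two setups include the learning rate (5e-05 vs 1e-5), --lora_target all vs attention-only LoRA, and the effective batch size per optimizer step: Llama-Factory multiplies the per-device batch of 2 by --gradient_accumulation_steps 8, while the torchtune command does not override gradient accumulation and so uses whatever the default config sets. A quick arithmetic sketch; the GPU count for the Llama-Factory run isn't stated, so 1 is an assumption, and the gradient_accumulation_steps=4 case is hypothetical.

```python
import math


def effective_batch(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Samples contributing to each optimizer step."""
    return per_device * grad_accum * num_gpus


def optimizer_steps_per_epoch(num_samples: int, per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Optimizer updates needed to see every sample once."""
    return math.ceil(num_samples / effective_batch(per_device, grad_accum, num_gpus))


# Llama-Factory run from this thread, assuming a single GPU.
print(effective_batch(2, 8, 1))  # 16
# torchtune on 2 GPUs with a hypothetical gradient_accumulation_steps=4.
print(effective_batch(2, 4, 2))  # 16
# 2,000 training samples at an effective batch of 16.
print(optimizer_steps_per_epoch(2000, 2, 8, 1))  # 125
```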

lingq1, Dec 18 '24 02:12

@felipemello1 I'll try raising the lr to 1e-4.

lingq1, Dec 18 '24 02:12

@joecummings After tuning the configuration, the loss dropped to 1.3. However, there is still a large gap compared to Llama-Factory's loss of 0.0001. Looking at Llama-Factory's logs, it seems Llama-Factory might be vectorizing the input data, whereas torchtune might not be processing the data the same way. Could that be why the loss remains relatively high in my case?
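One likely reason the two numbers aren't directly comparable, sketched under my reading of torchtune's KD recipe (please verify against the recipe source): the total loss is (1 - kd_ratio) * CE(labels) + kd_ratio * KD, and the KD term is the cross-entropy between the teacher's softmax and the student's log-softmax, averaged over non-padding tokens. Cross-entropy against a soft target equals KL(teacher || student) plus the teacher's entropy, so even a student that matches the teacher exactly cannot push it toward 0.0001 the way a hard-label SFT loss can; with kd_ratio=1.0, the whole logged loss is this term. The toy single-token version below uses made-up logits, and all names are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def softmax(logits: List[float]) -> List[float]:
    return [math.exp(lp) for lp in log_softmax(logits)]


def kd_term(teacher_logits: List[float], student_logits: List[float]) -> float:
    """-sum_v p_teacher(v) * log p_student(v) for a single token position."""
    p_t = softmax(teacher_logits)
    log_p_s = log_softmax(student_logits)
    return -sum(pt * lps for pt, lps in zip(p_t, log_p_s))


def entropy(probs: List[float]) -> float:
    return -sum(p * math.log(p) for p in probs if p > 0)


def combined_loss(ce_loss: float, kd_loss: float, kd_ratio: float) -> float:
    """Blend of hard-label cross-entropy and the distillation term."""
    return (1.0 - kd_ratio) * ce_loss + kd_ratio * kd_loss


teacher = [2.0, 1.0, 0.5]
# A student that copies the teacher exactly still pays the teacher's entropy.
print(kd_term(teacher, teacher), entropy(softmax(teacher)))
```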

lingq1, Dec 18 '24 06:12