
Sanity check: How to run inference after fine-tuning Llama-3 on chat data?

Open julianstastny opened this issue 9 months ago • 6 comments

Hi! I hope I'm doing this correctly, but I'm not completely sure and don't know how best to verify it. I've fine-tuned Llama-3 8B (using QLoRA) following this tutorial: https://pytorch.org/torchtune/main/tutorials/chat.html

Now I want to run inference by adapting the generation template (ideally there would be a simple way for me to run inference on a validation dataset formatted like my training dataset, but in the absence of that, single-prompt generation will do):

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [
    meta_model_0.pt
  ]
  output_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: <directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and harmless assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive three tips for staying healthy.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null
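
(For context, I'm assuming the entry point here is the torchtune CLI, something like `tune run generate --config <path to this yaml>`, which is what I understand the generation docs use; please correct me if that's not the intended way to launch this config.)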

Is it correct to specify the prompt in that format, or am I making a mistake because the tokenizer takes care of the special tokens somehow? Thanks!
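
In case it helps, here is the kind of sanity check I had in mind. The exact APIs (`Message`, `tokenize_messages`, the `encode` signature) are my assumption from reading the torchtune source and may differ between versions, so treat this as a sketch rather than the recipe's actual code path:

```python
# Sketch of a tokenization sanity check. Assumed APIs: torchtune.data.Message,
# Llama3Tokenizer.encode(text, add_bos, add_eos), and tokenize_messages();
# names/signatures may differ between torchtune versions.
from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(
    "<directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model"
)

# Path A: the prompt string from the generation config, encoded as raw text.
raw_prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful and harmless assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Give three tips for staying healthy.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
raw_ids = tokenizer.encode(raw_prompt, add_bos=True, add_eos=False)

# Path B: let the tokenizer insert the special tokens itself from Messages
# (empty assistant turn so the assistant header is appended for generation).
messages = [
    Message(role="system", content="You are a helpful and harmless assistant."),
    Message(role="user", content="Give three tips for staying healthy."),
    Message(role="assistant", content=""),
]
msg_ids, _mask = tokenizer.tokenize_messages(messages)

# If the two ID sequences differ, e.g. because the special tokens in the raw
# string get split into ordinary text pieces, then the prompt-string format in
# the config is not equivalent to the chat template the model was trained on.
print(raw_ids[:32])
print(msg_ids[:32])
```

If the two token streams disagree, I'd take that as a sign that the raw prompt string in the config isn't being interpreted the way the chat template expects.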

julianstastny · May 17 '24 22:05