Sanity check: How to run inference after fine-tuning Llama-3 on chat data?
Hi! I hope I'm doing this correctly, but I'm not completely sure and don't know how best to verify it. I've fine-tuned Llama-3 8B (using QLoRA) following this tutorial: https://pytorch.org/torchtune/main/tutorials/chat.html
Now I want to run inference by adapting the generation config template (ideally there would be a simple way for me to run inference over a validation dataset formatted like my training dataset, but in the absence of that, this is what I have):
# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [
    meta_model_0.pt
  ]
  output_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: <directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and harmless assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive three tips for staying healthy.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null
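For reference, I'm launching this with the generate recipe via the tune CLI, i.e. something along the lines of

tune run generate --config ./my_generation_config.yaml

where my_generation_config.yaml is just the file above (assuming that's the intended way to point the recipe at a custom config).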
Is it correct to specify the prompt in that format, or am I making a mistake because the tokenizer already takes care of the special tokens somehow? For what it's worth, below is a rough sketch of how I was planning to check this myself. Thanks!
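This is only a sketch -- I'm assuming Message / tokenize_messages / encode are the right APIs to compare the chat-style tokenization against encoding the raw prompt string, so please correct me if not:

from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(
    path="<directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model"
)

# Variant 1: let the tokenizer build the chat structure from Message objects,
# which is (I think) what happens with chat data during fine-tuning.
messages = [
    Message(role="system", content="You are a helpful and harmless assistant."),
    Message(role="user", content="Give three tips for staying healthy."),
]
chat_tokens, _mask = tokenizer.tokenize_messages(messages)

# Variant 2: encode the raw prompt string with the special tokens written out,
# the way I put it in the generation config above.
raw_prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful and harmless assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Give three tips for staying healthy.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
raw_tokens = tokenizer.encode(raw_prompt, add_bos=True, add_eos=False)

# If these two differ, I'd take it to mean the config prompt isn't being
# tokenized the same way my chat data was during fine-tuning.
print(chat_tokens)
print(raw_tokens)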