
Sanity check: How to run inference after fine-tuning Llama-3 on chat data?

Open · julianstastny opened this issue 1 year ago · 6 comments

Hi! I'm hopeful that I'm doing this correctly, but I'm not completely sure and don't know how best to verify it. I've fine-tuned Llama-3 8B (using QLoRA) following this tutorial: https://pytorch.org/torchtune/main/tutorials/chat.html

Now I want to run inference (ideally there would be a simple way to run inference on a validation dataset formatted like my training dataset, but in the absence of that) by adapting the generation template:

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [
    meta_model_0.pt
  ]
  output_dir: <directory path>/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: <directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and harmless assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive three tips for staying healthy.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null
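
For reference, I'm running this through the generate recipe, with something like the following (the config filename is just a placeholder for wherever the config above is saved):

tune run generate --config ./my_generation_config.yaml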

Is it correct that I specify the prompt in that format, or am I making a mistake because the tokenizer takes care of the special tokens somehow? Thanks!

julianstastny avatar May 17 '24 22:05 julianstastny

Hi @julianstastny, good question. The tokenizer should handle the template for you. You just need to specify the raw string in your config. You might get unexpected behavior by putting in the special tokens manually.

cc @joecummings on running inference on a validation dataset, I think that would be a good use case to support

RdoubleA avatar May 18 '24 17:05 RdoubleA

The generate recipe doesn't handle the chat format automatically, but it does add the BOS token when encoding. See: https://github.com/pytorch/torchtune/blob/eb66b627221a770f32f8517a91178bf1ccac4a3c/recipes/generate.py#L86

So you need to build the template manually, like you did in your config, at least for now until this is implemented.
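
For illustration, a minimal sketch of that manual templating (build_llama3_prompt is a hypothetical helper, not part of torchtune; the special tokens are the standard Llama-3 header/eot markers, matching the prompt string in the config above):

def build_llama3_prompt(system: str, user: str) -> str:
    # <|begin_of_text|> is intentionally omitted: the recipe's encode call
    # adds the BOS token itself (see the generate.py line linked above).
    return (
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

prompt = build_llama3_prompt(
    "You are a helpful and harmless assistant.",
    "Give three tips for staying healthy.",
)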

musabgultekin avatar May 19 '24 15:05 musabgultekin

@musabgultekin Thanks for correcting me; I overlooked the fact that it wasn't using tokenize_messages. @ebsmothers is this something we should update?

RdoubleA avatar May 19 '24 15:05 RdoubleA

Is it possible to run inference on a CPU? I've tried running my QLoRA-finetuned Llama3-8B with device: cpu and it seems to be stuck forever, with no errors in the terminal. Is it just my hardware, or does it silently fail?

rokasgie avatar May 20 '24 06:05 rokasgie

How about adding an option to the generate recipe (e.g. chat_format) to determine whether to wrap the prompt in a chat template? Something like below:

# <|system|>
# You are a friendly chatbot who always responds in the style of a pirate.</s>
# <|user|>
# cfg.prompt</s>
# <|assistant|>

This might need some modifications in the tokenizer. tokenizer.tokenize_messages should call the corresponding chat template in data/_chat_formats.py to format the prompt.

It's similar to the tokenizer.apply_chat_template function in transformers.
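
A hedged sketch of what that flow could look like, using the Message and tokenize_messages names already referenced in this thread (exact signatures may differ across torchtune versions):

from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

# Build the conversation as structured messages instead of a raw string.
tokenizer = llama3_tokenizer(
    path="<directory path>/Meta-Llama-3-8B-Instruct/original/tokenizer.model"
)
messages = [
    Message(role="system", content="You are a friendly chatbot who always responds in the style of a pirate."),
    Message(role="user", content="cfg.prompt goes here"),
]
# A chat format from data/_chat_formats.py could be applied to `messages`
# first (e.g. messages = SomeChatFormat.format(messages), where SomeChatFormat
# is a placeholder name) before tokenizing.
tokens, mask = tokenizer.tokenize_messages(messages)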

mantle2048 avatar May 23 '24 14:05 mantle2048

@julianstastny (and others) apologies for the delay here, but I've put up a draft PR (#1019) that I hope will make it easier to pass the chat format. Example usage would be something like:

prompt:
  system: "You are a kind and helpful assistant. Respond to the following request."
  user: "Write a generation recipe for torchtune."
chat_format: MyChatFormat  # optional

And under the hood it'd call tokenizer.tokenize_messages to take care of the Llama3-style formatting (assuming you're using the Llama3 tokenizer). Alternatively, you can still pass a string as in the current flow (with an optional template), which will still hit tokenizer.encode. I'd welcome any comments here or on the PR letting me know whether this change would make things easier for you.
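
A rough sketch of that under-the-hood logic (not the actual PR code; prompt_to_tokens is a hypothetical helper shown only to illustrate the two paths):

from torchtune.data import Message

def prompt_to_tokens(tokenizer, prompt, chat_format=None):
    # Dict prompts go through the message-based path...
    if isinstance(prompt, dict):
        messages = [Message(role=role, content=content) for role, content in prompt.items()]
        if chat_format is not None:
            messages = chat_format.format(messages)
        tokens, _ = tokenizer.tokenize_messages(messages)
    # ...while plain string prompts keep the current encode-based flow.
    else:
        tokens = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    return tokens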

@rokasgie if you're still seeing problems running inference on CPU, please feel free to open a separate issue so that we can track it properly there.

ebsmothers avatar May 24 '24 21:05 ebsmothers

prompt:
  system: "You are a kind and helpful assistant. Respond to the following request."
  user: "Write a generation recipe for torchtune."

This logic is not working for Llama 3.1 @ebsmothers @RdoubleA

apthagowda97 avatar Sep 18 '24 07:09 apthagowda97