Finetune meta-llama/Llama-Guard-3-1B
I don't see Llama Guard models listed in the output of tune ls. Since meta-llama/Llama-Guard-3-1B is "a fine-tuned Llama-3.2-1B pretrained model", I wonder if I can use one of the existing recipes, like llama3_2/1B_full, and derive my own template to fine-tune the Llama Guard models. In my own template, I can follow the instructions here to generate the prompts.
I would appreciate any suggestions on whether I am on the right track before I spend more time on it.
I had no luck and ran into the following error:
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 803, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 797, in recipe_main
    recipe.setup(cfg=cfg)
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 268, in setup
    self._model = self._setup_model(
                  ^^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 432, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Unexpected key(s) in state_dict: "output.weight".
Hey @jingzhaoou! Could you share the config you're using here, please?
Will share my config soon. I do need to make some changes to the torchtune source code for this to work.
Look at https://github.com/pytorch/torchtune/blob/27fd3a14b04b5c3d428c723ef4a3a27e1595102b/torchtune/data/_prompt_templates.py#L116-L117
prepend_tag and append_tag are added with type text. When the template is expanded, I see
https://github.com/pytorch/torchtune/blob/27fd3a14b04b5c3d428c723ef4a3a27e1595102b/torchtune/models/llama3/_tokenizer.py#L222-L224
There is an extra .strip() there, which causes issues when I create a template for Llama Guard. Here is a snippet of a sample Llama Guard template:
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User: What is the recipe for mayonnaise?
<END CONVERSATION>
With the existing torchtune code, it expands to something like
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User:What is the recipe for mayonnaise?<END CONVERSATION>
The text before "What is" and after "mayonnaise?" is added through prepend_tag and append_tag in my custom template.
To get the output I want, I commented out the .strip() and installed torchtune from source.
I wonder whether this can be addressed in a better way. Thanks.
Beyond a custom template for Llama Guard, wouldn't the extra .strip() also cause issues with some predefined templates, like
https://github.com/pytorch/torchtune/blob/27fd3a14b04b5c3d428c723ef4a3a27e1595102b/torchtune/data/_prompt_templates.py#L248
"Question: " effectively becomes "Question:", and "\n\nAnswer: " effectively becomes "Answer:". These may cause subtle issues during fine-tuning, IMO.
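As a quick plain-Python illustration (mimicking the effect described above, not actual torchtune internals), here is how a prepend tag with a trailing space and an append tag with leading newlines collapse once .strip() is applied to them:

```python
# Tag values mimic the custom Llama Guard template in this thread.
prepend_tag = "User: "                  # trailing space is intentional
append_tag = "\n\n<END CONVERSATION>"   # leading newlines are intentional
content = "What is the recipe for mayonnaise?"

intended = prepend_tag + content + append_tag
actual = prepend_tag.strip() + content + append_tag.strip()

print(repr(intended))  # 'User: What is the recipe for mayonnaise?\n\n<END CONVERSATION>'
print(repr(actual))    # 'User:What is the recipe for mayonnaise?<END CONVERSATION>'
```

The stripped version is exactly the mangled expansion shown in the mayonnaise example above.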
Please find my custom template Python file and config YAML file in the following zip file. Sorry, GitHub does not allow me to upload them directly.
cc @RdoubleA
@jingzhaoou I've pasted your files here for ease of access. Hope that's okay.
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama3.2 1B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config llama3_2/1B_full_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config llama3_2/1B_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.
output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  # path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  # max_seq_len: null
  path: /srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model
  max_seq_len: 8192
  prompt_template: my_custom_guard_template.my_custom_guard_template

# Dataset
# dataset:
#   _component_: torchtune.datasets.alpaca_dataset
#   packed: False # True increases speed
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: csv
  data_files: /srv/data/llama-guard/llama_guard_1b_wrong_polites.csv
  column_map:
    input: prompt
    output: ground_truth
  train_on_input: False
  packed: False
  split: train
seed: null
shuffle: True

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /srv/models/meta-llama/Llama-Guard-3-1B
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs
  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True
  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False
  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
from typing import List
from pathlib import Path

from torchtune.data import Message
from torchtune.data import PromptTemplate
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.datasets import instruct_dataset


class MyPromptTemplate(PromptTemplate):
    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        messages = super().__call__(messages, inference)
        return messages


def my_custom_guard_template() -> MyPromptTemplate:
    return MyPromptTemplate(
        template={
            "user": (
                """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User: """,
                """\n\n<END CONVERSATION>
Provide your safety assessment for ONLY THE LAST User message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. """,
            ),
        },
    )


if __name__ == '__main__':
    msgs = [
        Message(role="user", content="Emily is sitting next to me."),
        Message(role="assistant", content="safe"),
    ]
    prompt_template = my_custom_guard_template()
    templated_msgs = prompt_template(msgs)

    tokenizer_path = Path("/srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model")
    tokenizer = llama3_tokenizer(
        path=str(tokenizer_path),
        prompt_template="my_custom_guard_template.my_custom_guard_template",
        max_seq_len=8192,
    )
    dataset = instruct_dataset(
        tokenizer=tokenizer,
        source="csv",
        data_files="data/llama-guard/llama_guard_1b_wrong_polites.csv",
        column_map={
            "input": "prompt",
            "output": "ground_truth",
        },
        train_on_input=False,
        packed=False,
        split="train",
    )
    tokens = dataset[0]["tokens"]
    print(tokenizer.decode(token_ids=tokens, skip_special_tokens=False))
Beyond a custom template for Llama Guard, wouldn't the extra .strip() also cause issues with some predefined templates?
I have been skeptical of this for a long time, but the reference code we used for the llama models at the time included this. I wanted to revisit this but as you said, changing this will have a lot of implications and will affect our regression tests for model correctness. We could take a look again at the llama repos to see if they still do something similar...
@ebsmothers @joecummings
@RdoubleA Can I assign you this issue to double check both the official llama repos AND Hugging Face llama implementation? If there is discrepancy between the two, we should surface to someone at either team.
sure sure. let's nip this in the bud once and for all
I had no luck and ran into the following error:
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 803, in <module>
    sys.exit(recipe_main())
  File "/srv/source_code/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 797, in recipe_main
    recipe.setup(cfg=cfg)
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 268, in setup
    self._model = self._setup_model(
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 432, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Unexpected key(s) in state_dict: "output.weight".
I tried fine-tuning the meta-llama/Llama-Guard-3-8B model, which does not have the above error. When I looked more carefully at the meta-llama/Llama-Guard-3-1B model card, it mentions "To reduce the number of model parameters, we prune the model along two dimensions: number of layers and MLP hidden dimension". I suspect that is the root cause of my errors.
@RdoubleA I fine-tuned meta-llama/Llama-Guard-3-8B without the .strip() and got excellent results. Thus, IMO, that .strip() should not be there, or at least should not be applied to prepend_tag and append_tag.
I still hope to fine-tune the meta-llama/Llama-Guard-3-1B model, since a 1B model runs much faster than an 8B model. Can someone guide me on how to fix the following error, specifically for meta-llama/Llama-Guard-3-1B? Due to pruning, the vanilla Llama 3.2 model at
https://github.com/pytorch/torchtune/blob/6764618d2f148b28d2cb506f5ce70bf213fa1c3a/torchtune/models/llama3_2/_component_builders.py#L43-L55
won't work. How can I derive a model just for meta-llama/Llama-Guard-3-1B? I checked the Llama Cookbook, and it only works for meta-llama/Llama-Guard-3-8B, not for meta-llama/Llama-Guard-3-1B. In fact, I don't see any recipe or code that works for meta-llama/Llama-Guard-3-1B.
Really appreciate your help!
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 803, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 797, in recipe_main
    recipe.setup(cfg=cfg)
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 268, in setup
    self._model = self._setup_model(
                  ^^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 432, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Unexpected key(s) in state_dict: "output.weight".
hey @jingzhaoou, if the issue were pruning, you probably would see some shape mismatch, e.g. "Tried to load param of shape (xxx,yyy) but found (aaa,bbb)"
What you are seeing is "Unexpected key(s) in state_dict: "output.weight"."
This means that when you try to load the weights, either your state dict or your model definition is missing the weight for the "output" layer, which has nothing to do with the transformer layers. This layer is the very last one, mapping embeddings -> next-token prediction.
A few things you can do to debug:
- Load the checkpoint weights using torch.load and check the keys there. See what they have instead of "output". Something like:
ckpt = torch.load(ckpt_path)
print(ckpt.keys())
- Instantiate the model with random weights and check whether "output" is in the model. Something like:
from torchtune.models.llama3_2 import llama3_2_1b
model = llama3_2_1b()
for name, param in model.named_parameters():
    print(name)
- Confirm that the checkpoint is from Hugging Face, since in your config you are probably using the Hugging Face checkpointer. If you got it from Meta, you need to switch to the Meta checkpointer; Meta and Hugging Face use different naming conventions.
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
@felipemello1 I really appreciate your help!
My Llama-Guard-3-1B model was downloaded from Hugging Face here.
- Since the model is in safetensors format, I did the following:
from pathlib import Path
import safetensors.torch

guard_model_path = Path('/srv/models/meta-llama/Llama-Guard-3-1B/model.safetensors')
guard_model = safetensors.torch.load_file(guard_model_path)
llama_model_path = Path('/srv/models/meta-llama/Llama-3.2-1B-Instruct/model.safetensors')
llama_model = safetensors.torch.load_file(llama_model_path)

guard_set = set(guard_model.keys())
llama_set = set(llama_model.keys())
print(guard_set - llama_set)
# output
# {'lm_head.weight'}
Thus, Llama-Guard-3-1B has all the layers of Llama-3.2-1B-Instruct plus one extra layer, lm_head.weight. This is quite unexpected.
- I then checked torchtune:
from torchtune.models.llama3_2 import llama3_2_1b
model = llama3_2_1b()
print(len(list(model.named_parameters())))
for name, param in model.named_parameters():
print(name)
# output
# 146
The torchtune llama3_2_1b has the same number of parameters as Llama-3.2-1B-Instruct. But that means it is different from Llama-Guard-3-1B.
- Since my models are from Hugging Face, I have the following in my custom config file, which should be fine.
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
I suspect the transformer layer is probably fine. As you said "... is missing the weight for the "output" layer, which has nothing to do with the transformer layers. This layer is the very last one that maps embeddings -> next token prediction". I will dig further and see what happens to the "output" layer handling.
@jingzhaoou
ok, nice debugging! We do map lm_head -> output here: https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/models/convert_weights.py#L43
And it gets called here: https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/training/checkpointing/_checkpointer.py#L608
I am not sure why it wouldn't be mapping correctly from HF -> torchtune. Maybe you can add some breakpoints in checkpointer.py and see what is happening to the state dict right before/after.
Let me know if you have trouble debugging it.
@felipemello1 thanks for the help! I will debug this further.
I just checked Llama-Guard-3-8B, which also has the lm_head.weight layer. I don't see the same error there though.
@felipemello1 I debugged this further. The conversion occurs as expected. I changed
https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/recipes/full_finetune_single_device.py#L432
to be like
print(f'output.weight: {model_state_dict["output.weight"]}')
for name, _ in model.named_parameters():
    print(f'name: {name}')
model.load_state_dict(model_state_dict)
# output
# output.weight: tensor([[ 0.0041, 0.0159, 0.0199, ..., -0.0053, -0.0422, -0.0317],
# [ 0.0205, -0.0251, 0.0200, ..., -0.0092, -0.0013, -0.0376],
# [ 0.0125, 0.0097, 0.0115, ..., 0.0081, -0.0127, 0.0042],
# ...,
# dtype=torch.bfloat16)
I am puzzled that the model has no parameter named "output.weight" or "lm_head.weight". Given the following line of code,
https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/modules/transformer.py#L385
shouldn't "output.weight" be defined? I definitely need some help here. Thanks a lot!
@felipemello1 The component builders are different between Llama 3.1 (which Llama-Guard-3-8B is based on) and Llama-3.2 (which Llama-Guard-3-1B is based on).
- for Llama 3.1 https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/models/llama3_1/_component_builders.py#L113
- for Llama 3.2 https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/models/llama3_2/_component_builders.py#L115
i see, thanks @jingzhaoou !
The 1B and 3B models have tied embeddings, meaning that output.weight is just embedding.weight. They are shared, which saves memory. TiedLinear is a class that holds the tok_embeddings, but it's not an nn.Module. I did it this way so it wouldn't confuse FSDP (distributed), for example by trying to reshard the weights again.
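A rough sketch of the idea (assumed shapes and names, not the actual torchtune TiedLinear): the final projection simply reuses the embedding matrix, so no separate output.weight parameter ever exists in the model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedOutput:
    """Plain object (not an nn.Module), so the shared weight is registered only once."""
    def __init__(self, tok_embeddings: nn.Embedding):
        self.tok_embeddings = tok_embeddings

    def __call__(self, hidden: torch.Tensor) -> torch.Tensor:
        # [batch, seq, embed_dim] @ embedding_matrix.T -> [batch, seq, vocab_size]
        return F.linear(hidden, self.tok_embeddings.weight)

emb = nn.Embedding(128, 16)  # toy vocab_size / embed_dim
output = TiedOutput(emb)
logits = output(torch.randn(2, 5, 16))
print(logits.shape)  # torch.Size([2, 5, 128])
```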
My guess is that one of these two is happening:
- Llama Guard removed tied embeddings and trained output.weight != embedding.weight. To confirm this, compare the two weights in the checkpoint and check whether they are the same.
- The model has tied embeddings but saved a copy of output.weight anyway, and in torchtune we didn't write code to handle that.
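To make that first check concrete, here's a hedged sketch; the toy tensors and HF key names stand in for the real checkpoint (in practice you would load the state dict from model.safetensors first):

```python
import torch

def embeddings_tied(sd: dict, head_key: str = "lm_head.weight",
                    emb_key: str = "model.embed_tokens.weight") -> bool:
    # True only when the checkpoint stores an output head identical to the embedding.
    return head_key in sd and torch.equal(sd[head_key], sd[emb_key])

# Toy tensors standing in for real checkpoint weights:
w = torch.randn(8, 4)
tied_sd = {"model.embed_tokens.weight": w, "lm_head.weight": w.clone()}
untied_sd = {"model.embed_tokens.weight": w, "lm_head.weight": w + 1.0}

print(embeddings_tied(tied_sd))    # True
print(embeddings_tied(untied_sd))  # False
```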
We might have to update the 1b/3b constructor to take a flag like "use_tied_embedding", or handle it in the checkpointing, doing something like below (not this exactly though):
if "output.weight" in state_dict:
    if torch.equal(state_dict["output.weight"], embedding.weight):
        # redundant copy of a tied weight; drop it
        del state_dict["output.weight"]
Would you like to investigate it a bit and propose a PR? Otherwise i can look into it. Either way, thank you for raising this bug and helping with investigating it!
@felipemello1 thanks for the additional info. I will investigate more and try to propose a PR soon.
I did one experiment today. I changed
https://github.com/pytorch/torchtune/blob/e6b90646e9e6f1a19337e2e0cdfdf3fd496b15a9/torchtune/models/llama3_2/_component_builders.py#L115
to be the same as Llama 3.1
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
and did a full fine-tune of Llama-Guard-3-1B. I did not encounter any errors, and the resulting model seems to work well. This is very exciting. I will do more tests and report back.
Awesome!! Glad to hear it :)
"llama guard removed tied embeddings and trained output.weight != embedding.weight. To confirm this, one can compare these two weights and check if they are the same in the checkpoint".
At
https://github.com/pytorch/torchtune/blob/64870293c7e2c195166a85e179638cb41d848d9a/torchtune/models/llama3_2/_component_builders.py#L114
I added a line to print embedding.weight.
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
print(f'===> tok_embeddings: {tok_embeddings.weight}')
# output
# ===> tok_embeddings: Parameter containing:
# tensor([[-0.1543, 1.0078, 0.6328, ..., -0.1035, 0.9023, -0.1504],
# [-1.7422, 0.9688, -1.5703, ..., 0.1523, 0.0342, 0.2949],
# [ 0.3809, -2.2344, 1.4297, ..., 0.9883, -0.9414, -2.2969],
# ...,
# [-0.7930, 0.5547, 0.5391, ..., -0.8867, -0.0148, -0.9531],
# [-0.4512, 0.9023, -0.4453, ..., 2.3281, -0.1357, -0.2031],
# [-0.2734, 0.4395, 0.9023, ..., 2.6250, 0.8984, -0.6367]],
# device='cuda:0', requires_grad=True)
As shown above, "output.weight" is
# output.weight: tensor([[ 0.0041, 0.0159, 0.0199, ..., -0.0053, -0.0422, -0.0317],
# [ 0.0205, -0.0251, 0.0200, ..., -0.0092, -0.0013, -0.0376],
# [ 0.0125, 0.0097, 0.0115, ..., 0.0081, -0.0127, 0.0042],
# ...,
# dtype=torch.bfloat16)
It looks to me like these two weights are different in the checkpoint. Let me know if I missed anything.
yep, they do look different. If you would like to submit a PR, you can follow this for 1b/3b model builders: https://github.com/pytorch/torchtune/blob/64870293c7e2c195166a85e179638cb41d848d9a/torchtune/models/qwen2/_model_builders.py#L68
Then in your config, you can do:
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b
  tie_word_embeddings: False
And it should work :)
@felipemello1 thanks a lot! I will create a PR soon. However, the sad news is that while no errors are thrown and full/LoRA fine-tuned models are produced for Llama-Guard-3-1B, the resulting fine-tuned 1B models do not work properly. No matter what content is used, a fine-tuned 1B model gives the same output every time. That is, either always
safe
or always
unsafe
S10
Some help is definitely needed. Maybe we need to find out more details about how Llama-Guard-3-1B was trained? Please advise.
@felipemello1 I created a PR here. Please review and give me feedback.
@felipemello1 I looked at the config.json file of the HF meta-llama/Llama-3.2-1B-Instruct model which has
"tie_word_embeddings": true,
I then checked the config.json file of the HF meta-llama/Llama-Guard-3-1B model which has
"tie_word_embeddings": false,
I feel that torchtune could detect the tie_word_embeddings setting automatically in these cases. Any thoughts?
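For reference, reading the flag is a one-liner on the HF config; the JSON strings below are toy stand-ins for the two models' config.json contents (note that transformers defaults tie_word_embeddings to True when the key is absent):

```python
import json

def tie_flag(config_text: str) -> bool:
    # HF defaults tie_word_embeddings to True when the key is absent.
    return json.loads(config_text).get("tie_word_embeddings", True)

instruct_cfg = '{"tie_word_embeddings": true}'   # as in Llama-3.2-1B-Instruct
guard_cfg = '{"tie_word_embeddings": false}'     # as in Llama-Guard-3-1B
print(tie_flag(instruct_cfg))  # True
print(tie_flag(guard_cfg))     # False
```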
the resulting fine-tuned 1B models are not working properly. No matter what content is used, the fine-tuned 1B models give the same output consistently.
Can you clarify? What do you mean by "not working properly"? Is it during training or inference? If it's inference, how are you testing it?
I created a PR at https://github.com/pytorch/torchtune/pull/2331. Please review and give me feedback.
nice! Will review
I feel that torchtune may be able to detect the tie_word_embeddings setting automatically for these cases. Any thoughts?
We don't use config.json from Hugging Face. The model builders contain what is necessary to build the model (e.g. number of layers, etc.). If someone wants to build a different type of model, they should define their own model builders.
I see that in the PR it seems that training worked, but generation is bad. Do you mind trying it with HF generation?
https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#use-with-hugging-face-from-pretrained
you can try either the merged model or the base model + lora weights
@felipemello1 this is the code I used to test the merged model, which should be the same as in the link you shared with me.
import torch
from pathlib import Path
from typing import Any

import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

base_model_path = '/srv/finetune-output/llama_guard_3_1b/lora_single_device_3epoch/epoch_0'
device = 'cuda'
dtype = torch.bfloat16
test_base_model = False  # True when evaluating the base model instead of the LoRA model

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=dtype, device_map=device)


def moderate(model: Any, content: str) -> str:
    conversation = [
        dict(role='user', content=content)
    ]
    input_ids = tokenizer.apply_chat_template(conversation=conversation, return_tensors="pt").to(device)
    output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)


df = pd.read_csv('/srv/data/llama-guard/llama_guard_1b_wrong_tune_combined.csv')


def main() -> None:
    for id, row in tqdm(df.iterrows(), total=len(df)):
        output = moderate(base_model, row['prompt']).strip()
        df.loc[id, 'lora_output'] = output
    output_json_name = 'results_base.json' if test_base_model else 'results_lora.json'
    output_json_name = 'llama_guard_1b_wrong_tune_combined_' + output_json_name
    results = Path('results')
    results.mkdir(parents=True, exist_ok=True)
    df.to_json(results / output_json_name)


if __name__ == '__main__':
    main()
I have 1115 data points, which were used to LoRA fine-tune the base Llama-Guard-3-1B model. I then used the above code to compare the accuracy of the base model and the LoRA fine-tuned model. For 700 unsafe items, the base model gets 394 correct. However, the LoRA fine-tuned model always outputs "safe", meaning it does not recognize any of the unsafe items. In fact, no matter what is in the content, the LoRA fine-tuned model always outputs "safe". That is why I said the LoRA fine-tune is not working properly even though no error was thrown during fine-tuning.
I am not sure what else to check at this moment and would appreciate any input.
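One thing worth ruling out (an assumption on my part, not a confirmed diagnosis): training used the custom torchtune template while the evaluation above uses the HF chat template, so the train and inference prompts may not match. A generic helper for locating where two templated prompts diverge, with toy strings standing in for the real templated outputs:

```python
def first_divergence(a: str, b: str) -> int:
    """Index of the first differing character, or -1 if the strings match."""
    for i, (ca, cb) in enumerate(zip(a, b)):
        if ca != cb:
            return i
    return -1 if len(a) == len(b) else min(len(a), len(b))

# Toy prompts standing in for the real templated outputs:
train_prompt = "User: What is the recipe for mayonnaise?"  # custom torchtune template
infer_prompt = "User:What is the recipe for mayonnaise?"   # HF chat template
i = first_divergence(train_prompt, infer_prompt)
print(i, repr(train_prompt[i:i + 10]))  # 5 ' What is t'
```

Running the same comparison on the real decoded training sample versus the real apply_chat_template output would show whether the model sees an unfamiliar prompt at inference time.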
I am sharing my LoRA fine-tuning configuration. It uses the new flag from my PR.
output_dir: /srv/finetune-output/llama_guard_3_1b/lora_single_device_3epoch

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  tie_word_embeddings: False
  lora_rank: 64 # higher increases accuracy and memory
  lora_alpha: 128 # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model
  max_seq_len: 8192
  prompt_template: my_custom_guard_template.my_custom_guard_template

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /srv/models/meta-llama/Llama-Guard-3-1B/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: csv
  data_files: /srv/data/llama-guard/llama_guard_1b_wrong_tune_combined.csv
  column_map:
    input: prompt
    output: guard_output
  train_on_input: False
  packed: False
  split: train
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs
  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True
  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False
  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1