FSDP QLORA doesn't work with multiple adapters
System Info
- Accelerate version: 1.10.1
- Platform: Linux-5.15.0-157-generic-x86_64-with-glibc2.39
- accelerate bash location: /opt/uv/venv/bin/accelerate
- Python version: 3.12.11
- Numpy version: 2.2.6
- PyTorch version: 2.9.0.dev20250825+cu128
- PyTorch accelerator: CUDA
- System RAM: 1003.13 GB
- GPU type: NVIDIA H100 80GB HBM3
- Accelerate default config: Not found
- peft version: 0.15.2
- transformers version: 4.56.1
Who can help?
@benjaminbossan @githubnemo
Reproduction
"""Based on peft/examples/sft/run_peft_qlora_fsdp.sh
Launch command:
{
    "name": "Accelerate Launch - Minimal FSDP QLoRA Training",
    "type": "debugpy",
    "request": "launch",
    "module": "accelerate.commands.launch",
    "args": [
        "--config_file",
        "scripts/fsdp_config_qlora.yaml",
        "--num_processes",
        "2",
        "scripts/20251008_fsdp_qlora_sft_custom.py",
        "--seed",
        "100",
        "--model_name_or_path",
        "meta-llama/Llama-3.1-8B-Instruct",
        "--dataset_name",
        "smangrul/ultrachat-10k-chatml",
        "--add_special_tokens",
        "False",
        "--append_concat_token",
        "False",
        "--splits",
        "train,test",
        "--max_seq_len",
        "2048",
        "--num_train_epochs",
        "1",
        "--logging_steps",
        "5",
        "--log_level",
        "info",
        "--logging_strategy",
        "steps",
        "--learning_rate",
        "1e-4",
        "--lr_scheduler_type",
        "cosine",
        "--weight_decay",
        "1e-4",
        "--warmup_ratio",
        "0.0",
        "--max_grad_norm",
        "1.0",
        "--output_dir",
        "llama-sft-qlora-fsdp",
        "--per_device_train_batch_size",
        "2",
        "--per_device_eval_batch_size",
        "2",
        "--gradient_accumulation_steps",
        "2",
        "--gradient_checkpointing",
        "True",
        "--lora_r",
        "8",
        "--lora_alpha",
        "16",
        "--lora_dropout",
        "0.1",
        "--lora_target_modules",
        "all-linear",
        "--max_steps",
        "2",
    ],
    "console": "integratedTerminal",
    "justMyCode": false,
    "cwd": "${workspaceFolder}"
}
"""
import os
import sys
from dataclasses import dataclass, field
import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from peft import LoraConfig, PeftConfig, PeftModel
from peft.utils.other import fsdp_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainingArguments,
    get_scheduler,
    set_seed,
)
from transformers.data.data_collator import DataCollatorWithPadding
class MinimalSFTTrainer:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        peft_config: PeftConfig,
        train_dataset,
        args: TrainingArguments,
    ):
        self.args = args
        self.train_dataset = train_dataset
        self.tokenizer = tokenizer

        # Initialize accelerator with FSDP
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision="bf16",
        )

        # Prepare PEFT model
        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

        # Create PEFT model
        self.model = PeftModel.from_pretrained(
            model,
            "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference",
            autocast_adapter_dtype=False,
            adapter_name="reference",
        )
        self.model.load_adapter(
            "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference",
            adapter_name="policy",
            autocast_adapter_dtype=False,
        )

        # Critical: Update FSDP plugin for QLoRA
        if self.accelerator.state.fsdp_plugin is not None:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            # Set auto wrap policy for PEFT
            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
            quant_storage = model.hf_quantizer.quantization_config.bnb_4bit_quant_storage
            if quant_storage.is_floating_point:
                fsdp_plugin.set_mixed_precision(quant_storage, override=True)

        # Create dataloader
        formatted_ds = self.train_dataset.map(
            lambda x: {"content": tokenizer.apply_chat_template(x["messages"], tokenize=False)},
            batched=False,
            remove_columns=self.train_dataset.column_names,
        )
        tokenized_ds = formatted_ds.map(
            lambda x: self.tokenizer(x["content"], truncation=True),
            batched=True,
            remove_columns=formatted_ds.column_names,
        )
        self.train_dataloader = DataLoader(
            tokenized_ds,
            batch_size=args.per_device_train_batch_size,
            collate_fn=DataCollatorWithPadding(self.tokenizer),
            shuffle=True,
        )

        # Create optimizer - only optimize trainable parameters
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            optimizer_params,
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

        # Calculate training steps
        num_update_steps_per_epoch = len(self.train_dataloader) // args.gradient_accumulation_steps
        max_steps = args.max_steps if args.max_steps > 0 else int(args.num_train_epochs * num_update_steps_per_epoch)

        # Create scheduler
        self.lr_scheduler = get_scheduler(
            args.lr_scheduler_type,
            optimizer=self.optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=max_steps,
        )

        # Prepare everything with accelerator
        self.model.base_model.set_adapter(["reference", "policy"])
        self.accelerator.print(f"Active adapters: {self.model.active_adapters}")
        for name, param in self.model.named_parameters():
            if "layers.0.self_attn.q_proj" in name:
                print(f"{name} {param.shape} {param.device} {param.dtype} {param.requires_grad}")
        # N.B. the call below will hang unless peft.tuners.tuners_utils.py::BaseTunerLayer._move_adapter_to_device_of_base_layer
        # is overridden to remove the special meta device handling
        self.accelerator.print("Preparing everything with accelerator")
        self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.lr_scheduler
        )
        self.accelerator.print("Everything prepared with accelerator")

        self.global_step = 0
        self.max_steps = max_steps

def create_dataset(tokenizer, data_args):
    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))
        assert isinstance(dataset, Dataset)
        dataset = dataset.select(range(8))
        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    train_data = raw_datasets["train"]
    print(f"Size of the train set: {len(train_data)}")
    print(f"A sample of train dataset: {train_data[0]}")
    return train_data

def create_and_prepare_model(args):
    quant_storage_dtype = torch.bfloat16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )
    torch_dtype = quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32

    # Prepare model loading arguments
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch_dtype,
        "attn_implementation": "flash_attention_2",
        "quantization_config": bnb_config,
        "use_cache": False,
    }
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=args.lora_target_modules.split(",")
        if args.lora_target_modules != "all-linear"
        else args.lora_target_modules,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return model, peft_config, tokenizer

# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    max_seq_length: int | None = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )
    lora_alpha: int | None = field(default=16)
    lora_dropout: float | None = field(default=0.1)
    lora_r: int | None = field(default=64)
    lora_target_modules: str | None = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "Comma-separated list of target modules to apply LoRA layers to."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: str | None = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: bool | None = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: bool | None = field(
        default=False,
        metadata={"help": "If True, the tokenizer adds special tokens to each sample being packed."},
    )
    splits: str | None = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )

def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args)
    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # datasets
    train_dataset = create_dataset(
        tokenizer,
        data_args,
    )

    # trainer
    trainer = MinimalSFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
Expected behavior
The call to accelerator.prepare should not hang. I would also expect device 1 to show all tensors on the meta device, but in fact the second adapter is redundantly left on cpu:
base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight torch.Size([4194304, 1]) meta torch.bfloat16 False
base_model.model.model.layers.0.self_attn.q_proj.lora_A.reference.weight torch.Size([64, 4096]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_A.policy.weight torch.Size([64, 4096]) cpu torch.float32 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.reference.weight torch.Size([4096, 64]) meta torch.bfloat16 True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.policy.weight torch.Size([4096, 64]) cpu torch.float32 True
I was able to fix this by commenting out these two lines: https://github.com/huggingface/peft/blob/ec5a1b2ce62b14925239f62582eb3db5d9174af2/src/peft/tuners/tuners_utils.py#L1458, but I'm not sure whether that would break other use cases.
I didn't have time yet for a deep dive into the FSDP issue you reported in #2833; hopefully I'll have time tomorrow. Once that is fixed, I plan to take a closer look at this issue, as there may be a relation between the two.
As for your suggested PR, it would break the low_cpu_mem_usage=True option in PEFT (to wit: pytest tests/test_custom_models.py -k low_cpu), so we'll have to find something else.
@ojh31 Given the upcoming transformers v5 release and the work associated with that, I didn't have time this week to investigate your issue (which can potentially take quite some time, because git bisect is not an option with this many coupled libraries). I hope I can find a solution next week.
Thanks for the update!
I finally had time to look at this issue. I couldn't quite get your script to run, but I took the existing sft/train.py script and dropped this into the main function; I think it should reflect what you tried to do:
def main(model_args, data_args, training_args):
    ...
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # v added this v
    from peft import PeftModel

    model = PeftModel.from_pretrained(
        model,
        "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference",
        autocast_adapter_dtype=False,
        adapter_name="reference",
    )
    model.load_adapter(
        "AlignmentResearch/Llama-3.1-8B-Instruct-gsm8k-lora-reference",
        adapter_name="policy",
        autocast_adapter_dtype=False,
    )
    model.base_model.set_adapter(["reference", "policy"])
    peft_config = None
After doing so, I do indeed get the hang you reported. The fix you proposed (which, as I mentioned, is problematic) removes the hang, but then during training I get:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I'm not exactly sure why, but it seems to work fine for you.
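For anyone hitting the same error, here is an illustrative check, not part of the original scripts and assuming the prepared PEFT model is in scope as model: the error means nothing reachable from the loss requires grad, so listing the trainable parameters right after accelerator.prepare usually narrows it down.

# Illustrative sanity check, not from the original scripts: after accelerator.prepare,
# confirm that some parameters still require grad. An empty list explains the
# RuntimeError above, e.g. when gradient checkpointing is enabled without calling
# model.enable_input_require_grads().
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"{len(trainable)} trainable parameters, e.g. {trainable[:3]}")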
Now, to fix the original issue without the problematic code change, could you please try setting fsdp_cpu_ram_efficient_loading: false in your FSDP config? This resolved the hanging issue for me without any further changes to PEFT. I still get the aforementioned RuntimeError, but that seems to be a different issue.
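For reference, the same flag can also be toggled without editing the YAML. The sketch below rests on the assumption that accelerate launch exports the config value as the FSDP_CPU_RAM_EFFICIENT_LOADING environment variable and that transformers consults it during from_pretrained to decide whether non-zero ranks load onto the meta device; please verify against your installed accelerate/transformers versions.

# Sketch (assumption, not from the thread): emulate fsdp_cpu_ram_efficient_loading: false
# by clearing the environment flag before the model is loaded with from_pretrained,
# which is where the setting is consulted.
import os

os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "false"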
Yes, I can confirm that this avoids the hanging issue, but it is a shame to have to lose the CPU-efficient loading.
Thanks for confirming. I agree it would be nice if this worked with FSDP, but I would think that in the grand scheme of things it's not a significant factor. LMK if it impacts your use case significantly, but finding a clean solution for this will probably not be easy.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.