LoRA fine-tuning weights explosion in FSDP training
Dear authors,
I encountered a weight explosion problem while integrating LoRA into torchtitan. I am running with the train_configs/llama3_8b.toml config via run_llama_train.sh on 4 A10 24GB GPUs. The PyTorch version is the latest 2.5.0 nightly.
I have made the following changes so far:
- In train.py, I added two utility functions, get_parameters() and calculate_parameter_change(), to compute the change in the LoRA weights at each training step, plus a call to mark_only_lora_as_trainable(), which marks all LoRA matrices as trainable and freezes all other layers (implemented by the original LoRA authors in model.py). When creating model_cls from the model args, I changed the device from meta to cpu, since this lets me load pretrained weights directly (step 3).
- In model.py, I replaced the nn.Linear used for the wq, wk, wv matrices with the custom Linear layer implemented by the LoRA authors to incorporate LoRA adapter training (see the sketch after this list).
- In model.py, I replaced the init_weights function in the Transformer module with code that loads pretrained weights from the HF llama-3-8B-Instruct checkpoint. I checked the loaded weights and they appear to be correct.
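For reference, here is a rough sketch of what the LoRA-style Linear and the freezing helper look like (simplified from a loralib-style implementation; the code I actually use is the original LoRA authors', this is only to show where lora_A and lora_B come from):

import math
import torch
import torch.nn as nn

class LoRALinear(nn.Linear):
    """nn.Linear with a low-rank update: y = x W^T + (alpha/r) * x A^T B^T."""
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = lora_alpha / r
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))  # A: random init
        # lora_B stays zero, so the adapter initially contributes nothing

    def forward(self, x):
        return super().forward(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

def mark_only_lora_as_trainable(model: nn.Module) -> None:
    # Freeze everything except the lora_A / lora_B matrices.
    for name, p in model.named_parameters():
        p.requires_grad = "lora_" in name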
Since the LoRA implementation has been widely tested, I assume it is fine. One thing I noticed: when I don't set the device to cpu when creating model_cls from the model args (the default in the code is meta), the weight-copying operation in step 3 essentially copies all zeros into the tensors. In that case, LoRA-A behaves very similarly to when the weights are copied correctly, while LoRA-B, which is zero-initialized, stays at 0. In the current case, both LoRA-A and LoRA-B have exploding weights. Because LoRA-A behaves the same under either initialization, I suspect the bug lies somewhere in the FSDP setup. I have attached my code below and would appreciate any hints on where the bug might come from (uploading .txt files since .py is not allowed; they are .py files).
Below is sample output showing the progression of LoRA A/B weight changes from a sampled layer.
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 146.9622344970703, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 5.200856207920879e-07, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 147.65611267089844, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 6.672774333082998e-08, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 147.62704467773438, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 0.018832538276910782}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 18811.1484375, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 6.657096673734486e-05, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 18899.97265625, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 4.27057602792047e-06, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 18896.2734375, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 1.2052823305130005}
[rank0]:{'layers.29._checkpoint_wrapped_module.attention.wq.lora_A': 2407827.0, 'layers.29._checkpoint_wrapped_module.attention.wq.lora_B': 0.008521084673702717, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_A': 2419196.5, 'layers.29._checkpoint_wrapped_module.attention.wk.lora_B': 0.00027331686578691006, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_A': 2418723.0, 'layers.29._checkpoint_wrapped_module.attention.wv.lora_B': 77.13806915283203}
@weifengpy any thoughts?
Hi @MinghaoYan, thanks for filing the issue. I have worked on LoRA + FSDP in TorchTune, so I would love to understand if there are any FSDP bugs here.
Are you open to taking a look at the loss together with me, just to make sure LoRA is set up correctly?
- are you loading from a pretrained checkpoint, say Llama2 or Llama3?
- after loading the model, without applying any LoRA adapters, could I know the loss? For example, loss = 2.0. This just checks whether the model is loaded correctly
- after applying the LoRA adapters, before the 1st optim.step, could I know the loss? I would expect the loss to still be close to the 2.0 from step 2. This checks whether the LoRA adapters are added correctly (a sketch of both checks follows below)
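A sketch of the two checks above (not torchtitan code; it assumes the data loader yields (input_ids, labels) token batches the way torchtitan's data loader does):

import torch
import torch.nn.functional as F

@torch.no_grad()
def first_batch_loss(model, data_loader):
    # Evaluate the cross-entropy loss on a single batch without any optimizer step.
    model.eval()
    input_ids, labels = next(iter(data_loader))
    logits = model(input_ids.cuda())                                  # [batch, seq, vocab]
    loss = F.cross_entropy(logits.flatten(0, 1).float(), labels.cuda().flatten(0, 1))
    model.train()
    return loss.item()

# 1) after loading the checkpoint, before adding any LoRA adapters:
#    base_loss = first_batch_loss(model, data_loader)   # expect something like ~2.x
# 2) after adding LoRA adapters, before the 1st optim.step (lora_B starts at zero,
#    so the adapters should not change the output):
#    lora_loss = first_batch_loss(model, data_loader)   # expect ~base_loss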
If you just need a LoRA + FSDP recipe, or a reference implementation, here is one from TorchTune: https://github.com/pytorch/torchtune/blob/main/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
Thanks for your reply! I think I might not be loading the checkpoint correctly. I disabled the LoRA implementation entirely and reverted to the default llama-3-8b setup. Currently I am trying to manually copy the weights from a HuggingFace Llama-3-8B-Instruct checkpoint into an instantiated Transformer module in torchtitan, by replacing the Transformer init_weights function with the following:
def init_weights(self):
    with torch.device(self.freqs_cis.device):
        self.freqs_cis = self._precompute_freqs_cis()
    self._copy_weights()

def _copy_weights(self, pretrained_model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
    pretrained_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to('cpu')

    # Copy embedding weights
    self.tok_embeddings.weight.data = pretrained_model.model.embed_tokens.weight.data.clone()
    # print(self.tok_embeddings.weight.data)

    # Copy transformer layer weights
    for i, pretrained_layer in enumerate(pretrained_model.model.layers):
        # Attention weights
        assert self.layers[str(i)].attention.wq.weight.data.shape == pretrained_layer.self_attn.q_proj.weight.data.shape, f"Mismatch in shape for wq at layer {i}"
        assert self.layers[str(i)].attention.wk.weight.data.shape == pretrained_layer.self_attn.k_proj.weight.data.shape, f"Mismatch in shape for wk at layer {i}"
        assert self.layers[str(i)].attention.wv.weight.data.shape == pretrained_layer.self_attn.v_proj.weight.data.shape, f"Mismatch in shape for wv at layer {i}"
        assert self.layers[str(i)].attention.wo.weight.data.shape == pretrained_layer.self_attn.o_proj.weight.data.shape, f"Mismatch in shape for wo at layer {i}"
        self.layers[str(i)].attention.wq.weight.data = pretrained_layer.self_attn.q_proj.weight.data.clone()
        self.layers[str(i)].attention.wk.weight.data = pretrained_layer.self_attn.k_proj.weight.data.clone()
        self.layers[str(i)].attention.wv.weight.data = pretrained_layer.self_attn.v_proj.weight.data.clone()
        self.layers[str(i)].attention.wo.weight.data = pretrained_layer.self_attn.o_proj.weight.data.clone()

        # Feed-forward weights
        assert self.layers[str(i)].feed_forward.w1.weight.data.shape == pretrained_layer.mlp.gate_proj.weight.data.shape, f"Mismatch in shape for w1 at layer {i}"
        assert self.layers[str(i)].feed_forward.w2.weight.data.shape == pretrained_layer.mlp.down_proj.weight.data.shape, f"Mismatch in shape for w2 at layer {i}"
        assert self.layers[str(i)].feed_forward.w3.weight.data.shape == pretrained_layer.mlp.up_proj.weight.data.shape, f"Mismatch in shape for w3 at layer {i}"
        self.layers[str(i)].feed_forward.w1.weight.data = pretrained_layer.mlp.gate_proj.weight.data.clone()
        self.layers[str(i)].feed_forward.w2.weight.data = pretrained_layer.mlp.down_proj.weight.data.clone()
        self.layers[str(i)].feed_forward.w3.weight.data = pretrained_layer.mlp.up_proj.weight.data.clone()

        # LayerNorm weights
        assert self.layers[str(i)].attention_norm.weight.data.shape == pretrained_layer.input_layernorm.weight.data.shape, f"Mismatch in shape for attention_norm weight at layer {i}"
        assert self.layers[str(i)].ffn_norm.weight.data.shape == pretrained_layer.post_attention_layernorm.weight.data.shape, f"Mismatch in shape for ffn_norm weight at layer {i}"
        self.layers[str(i)].attention_norm.weight.data = pretrained_layer.input_layernorm.weight.data.clone()
        self.layers[str(i)].ffn_norm.weight.data = pretrained_layer.post_attention_layernorm.weight.data.clone()

        # Init LoRA weights
        # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wq.lora_A, a=math.sqrt(5))
        # nn.init.zeros_(self.layers[str(i)].attention.wq.lora_B)
        # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wk.lora_A, a=math.sqrt(5))
        # nn.init.zeros_(self.layers[str(i)].attention.wk.lora_B)
        # nn.init.kaiming_uniform_(self.layers[str(i)].attention.wv.lora_A, a=math.sqrt(5))
        # nn.init.zeros_(self.layers[str(i)].attention.wv.lora_B)

    # Copy final layer norm
    assert self.norm.weight.data.shape == pretrained_model.model.norm.weight.data.shape
    self.norm.weight.data = pretrained_model.model.norm.weight.data.clone()

    # Copy lm_head weights
    assert self.output.weight.data.shape == pretrained_model.lm_head.weight.data.shape
    self.output.weight.data = pretrained_model.lm_head.weight.data.clone()

    del pretrained_model
Since the full model weights don't fit on a single GPU, I loaded the pretrained model on the CPU and then copied the weights over. This causes problems later, since some tensors end up on cpu instead of cuda; however, if I simply call model.to("cuda") after init_weights(), I get an OOM error.
I went this route because I couldn't find a better way to load a HuggingFace checkpoint directly into the Transformer module in torchtitan. I would appreciate a pointer on how to load from an HF checkpoint in torchtitan.
I am hoping to build upon torchtitan since later on in my project, I might need to directly change the training function as well as model architecture, so a lightweight framework would help me down the line.
One more thing I discovered: the init_weights() function is called twice, once through the Transformer module's init function
with torch.device("meta"):
    model = model_cls.from_model_args(model_config)
and once more later on explicitly
if parallel_dims.pp_enabled:
    pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
else:
    # If PP is enabled, we can't rely on init_weights, because some layers are missing.
    # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
    # allocate sharded model on GPU and initialize weights via DTensor
    model.init_weights()
Not sure if this is the intended behavior, but just want to let you guys know.
how to load from a HF checkpoint in torchtitan
Here is my reference implementation for loading an HF checkpoint into an FSDP model: https://github.com/pytorch/torchtune/blob/main/recipes/dev/lora_finetune_fsdp2.py#L343.
full_sd below refers to the HF checkpoint.
def load_from_full_model_state_dict(
    model: "FSDPModule",
    full_sd: Dict[str, Any],
    device: torch.device,
):
    """
    Convert a full state dict into a sharded state dict
    and load it into the FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor`, where each rank has a shard of the plain tensor
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=False, assign=True)
One more thing I discovered: the init_weights() function is called twice, once through the Transformer module's init function
Good question! We intentionally call init_weights() twice. The 1st time is during meta init, where we init parameters/tensors on the meta device. The 2nd time is with device='cuda', where we actually allocate tensor storage with real values.
For the LoRA/fine-tuning case, meta init is preferred since we are loading from checkpoints eventually. So we just init the params on the meta device and copy tensors from the checkpoint into the model (load_from_full_model_state_dict). More about meta init: https://pytorch.org/tutorials/prototype/skip_param_init.html
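To make the flow concrete, here is a toy sketch of the meta-init pattern on a single layer (not torchtitan code):

import torch
import torch.nn as nn

# 1) meta init: only shapes/dtypes are recorded, no storage is allocated
with torch.device("meta"):
    layer = nn.Linear(4096, 4096, bias=False)

# 2) allocate real (uninitialized) storage on the target device
layer.to_empty(device="cuda")

# 3) assign real values from a "checkpoint" (here a stand-in tensor)
ckpt = {"weight": torch.randn(4096, 4096, device="cuda")}
layer.load_state_dict(ckpt, assign=True)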
Thank you very much for the pointer! After some more investigation, it does seem like the first-step loss is too high (without any LoRA or any training) after loading the weights directly from the HF checkpoint: I get a loss of 11.79 at step 1 with the train_configs/llama3_8b.toml config and run_llama_train.sh on 4 A10 24GB GPUs. A minimal reproducible example: instead of
if parallel_dims.pp_enabled:
    pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
else:
    # If PP is enabled, we can't rely on init_weights, because some layers are missing.
    # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
    # allocate sharded model on GPU and initialize weights via DTensor
    model.init_weights()
I loaded weights from the checkpoints instead:
pretrained_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
load_from_full_model_state_dict(model, pretrained_model, "cpu")
del pretrained_model
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)
if parallel_dims.pp_enabled:
pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
else:
# If PP is enabled, we can't rely on init_weights, because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
# allocate sharded model on GPU and initialize weights via DTensor
# model.init_weights()
pass
I made one change to your reference implementation: I mapped the HF param names to the names defined in the torchtitan Transformer module:
def load_from_full_model_state_dict(
    model: "FSDPModule",
    full_sd: Dict[str, Any],
    device: torch.device,
):
    """
    Convert a full state dict into a sharded state dict
    and load it into the FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor`, where each rank has a shard of the plain tensor
    """
    print(model)
    param_mapping = {
        'model.embed_tokens.weight': 'tok_embeddings.weight'
    }
    for i, _ in enumerate(full_sd.model.layers):
        param_mapping.update({
            f'model.layers.{i}.self_attn.q_proj.weight': f'layers.{i}.attention.wq.weight',
            f'model.layers.{i}.self_attn.k_proj.weight': f'layers.{i}.attention.wk.weight',
            f'model.layers.{i}.self_attn.v_proj.weight': f'layers.{i}.attention.wv.weight',
            f'model.layers.{i}.self_attn.o_proj.weight': f'layers.{i}.attention.wo.weight',
            f'model.layers.{i}.mlp.gate_proj.weight': f'layers.{i}.feed_forward.w1.weight',
            f'model.layers.{i}.mlp.down_proj.weight': f'layers.{i}.feed_forward.w2.weight',
            f'model.layers.{i}.mlp.up_proj.weight': f'layers.{i}.feed_forward.w3.weight',
            f'model.layers.{i}.input_layernorm.weight': f'layers.{i}.attention_norm.weight',
            f'model.layers.{i}.post_attention_layernorm.weight': f'layers.{i}.ffn_norm.weight'
        })
    param_mapping.update({
        'model.norm.weight': 'norm.weight',
        'lm_head.weight': 'output.weight'
    })

    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.named_parameters():
        sharded_meta_param = meta_sharded_sd.get(param_mapping[param_name])
        # print(sharded_meta_param)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        # print(param_name, full_tensor, sharded_meta_param)
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=False, assign=True)
I would really appreciate it if you could spot anything wrong in my code or could try to reproduce the loss I am seeing.
I made one change to your reference implementation, where I mapped HF param names to the names defined in torchtitan Transformers module
Good catch on remapping the parameter names.
if you can spot anything wrong in my code
I would call load_from_full_model_state_dict after model.to_empty(...), and use torch.device("cuda"):
- model.to_empty moves the FSDP model from meta to cuda, but the parameter values are left unassigned
- load_from_full_model_state_dict(..., torch.device("cuda")) then assigns the parameter values
- I do not expect GPU OOM, since load_from_full_model_state_dict casts each full tensor into a 1/N shard. But let me know if you hit it
If you still hit errors, I can try to reproduce. Feel free to open a PR with your local changes.
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)
if parallel_dims.pp_enabled:
pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
else:
# If PP is enabled, we can't rely on init_weights, because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
# allocate sharded model on GPU and initialize weights via DTensor
# model.init_weights()
pass
pretrained_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
load_from_full_model_state_dict(model, pretrained_model, torch.device("cuda"))
del pretrained_model
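For reference, the 1/N point can be seen directly with distribute_tensor; a standalone sketch (not torchtitan code), run under torchrun:

# Run with: torchrun --nproc_per_node=4 shard_demo.py
# Each rank ends up holding only its own shard of the full tensor.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

full = torch.randn(4096, 4096, device="cuda")        # full tensor, materialized on every rank
sharded = distribute_tensor(full, mesh, [Shard(0)])   # DTensor sharded on dim 0
print(dist.get_rank(), sharded.to_local().shape)      # e.g. torch.Size([1024, 4096]) on 4 GPUs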
Thank you for your reply!
I moved load_from_full_model_state_dict to after model.to_empty(...). If I keep the torch device as cpu, the behavior is the same. If I change the device to cuda, I run into a new problem where it complains that the tensor is not a leaf (this corresponds to the full_tensor variable in the load_from_full_model_state_dict function).
File "/home/ubuntu/.conda/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "/home/ubuntu/torchtitan/train.py", line 252, in main
load_from_full_model_state_dict(model, pretrained_model, torch.device("cuda"))
File "/home/ubuntu/torchtitan/train.py", line 540, in load_from_full_model_state_dict
sharded_tensor = distribute_tensor(
File "/home/ubuntu/.conda/envs/torchtitan/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 598, in distribute_tensor
raise RuntimeError(
RuntimeError: `distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!
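If it helps narrow this down, my guess (not confirmed) is that the non-leaf tensor comes from calling .to(...) on a parameter that still requires grad; a tiny standalone repro of that behavior:

import torch

p = torch.nn.Parameter(torch.randn(4, 4))   # like a tensor yielded by named_parameters()
t = p.to(torch.bfloat16)                    # differentiable op, so the result has a grad_fn
print(t.is_leaf)                            # False -> exactly what distribute_tensor rejects
t = p.detach().to(torch.bfloat16)           # detach first, then convert
print(t.is_leaf)                            # True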
I seem to be having some permission issues creating branches and PRs, I will look into it. Thanks again!
Thanks for your patience. Let me know if you cannot resolve it after investigating; I might draft some example code to load from an HF checkpoint. This is a common ask and we can improve here.
Thank you! I have created a PR here: #427
Thank you! I have created a PR here: #427
Thanks, I will give it a try. In the meantime, we have a script to convert HF checkpoints to DCP format. Are you interested in giving it a try? https://github.com/pytorch/torchtitan/issues/305
@MinghaoYan
Currently I am trying to manually copy the weights from a HuggingFace Llama-3-8B-Instruct checkpoint into an instantiated Transformer module in torchtitan, by replacing the Transformer init_weights function with the following
I wonder if you are aware of the model definition mismatch between Llama and HF's Transformers (#335). Basically, a permutation of some weights is needed to make the conversion work; see the sketch below for the kind of permutation involved.
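For reference, this is roughly the permutation the transformers Llama conversion script applies to the q/k projection weights (a sketch only; #335 has the exact details for torchtitan):

import torch

def permute_meta_to_hf(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # The permutation used by transformers' convert_llama_weights_to_hf.py (Meta -> HF layout).
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute_hf_to_meta(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Inverse permutation: undo the HF layout to recover the Meta/torchtitan layout for wq/wk.
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# sanity check on random data
w = torch.randn(4096, 4096)
assert torch.equal(permute_hf_to_meta(permute_meta_to_hf(w, 32, 4096, 4096), 32, 4096, 4096), w)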
I was not aware of this, thank you!