Support LoRA with multiple devices
When switching from DeepSpeed stage 2 to DeepSpeed stage 3, the LoRA finetuning script currently fails to load the model:
...
18176]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.norm_1.weight: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.norm_1.bias: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.attn.attn.weight: copying a param with shape torch.Size([4672, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.attn.proj.weight: copying a param with shape torch.Size([4544, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.mlp.fc.weight: copying a param with shape torch.Size([18176, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.h.31.mlp.proj.weight: copying a param with shape torch.Size([4544, 18176]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
size mismatch for transformer.ln_f.bias: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
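For context, a rough sketch of the strategy switch that triggers this (the actual finetune script wires things up differently; treat the exact arguments as assumptions):

```python
# Rough sketch, not the actual finetune/lora.py setup: switching the Fabric
# strategy from DeepSpeed stage 2 to stage 3. Stage 2 only shards optimizer
# state and gradients; stage 3 also shards the parameters themselves, so the
# local copies stay empty until they are gathered for forward(), and a plain
# load_state_dict() sees torch.Size([0]) shapes.
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(devices=4, strategy=DeepSpeedStrategy(stage=3), precision="bf16-mixed")
fabric.launch()
```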
But I see that an FSDP update is in the works via #118, so I am not sure if that's worth addressing right now.
FSDP also fails with a similar error.
This is because we are accessing the LoRA parameters in .train() instead of .forward(): https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/lora.py#L270-L273
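Roughly, the pattern at those lines looks like the sketch below (an illustrative re-implementation, not the actual lit_llama code): the layer merges or un-merges the low-rank delta into the base weight inside train()/eval(), which means touching the full parameters outside of forward(). With DeepSpeed stage 3 or FSDP the parameters are sharded/flattened at that point, which is why they show up with shape torch.Size([0]).

```python
# Minimal sketch of the problematic pattern (not the actual lit_llama code).
import torch
import torch.nn as nn


class MergingLoRALinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 8, scaling: float = 1.0):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = scaling
        self.merged = False

    def train(self, mode: bool = True):
        super().train(mode)
        # parameter access outside forward(): fails when weights are sharded
        if mode and self.merged:
            # un-merge so lora_A/lora_B are trained separately again
            self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
        elif not mode and not self.merged:
            # merge the delta into the base weight for a faster forward
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = super().forward(x)
        if not self.merged:
            result = result + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return result
```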
@awaelchli suggested removing these calls from the fine-tuning scripts:
It's only needed if you want to merge the LoRA weights into the regular weights (for a faster forward, or if you don't need the delta anymore). But we don't care about that in finetuning, and we save the LoRA weights separately anyway. So you have two options: refactor the code out of the train()/eval() methods into a separate merge_weights() method, or just remove the train()/eval() calls from the finetuning scripts.
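For reference, a minimal sketch of the first option (the method name merge_lora_weights and the attribute names are placeholders, not the actual lit_llama API): train()/eval() keep the default nn.Module behaviour, and the merge becomes an explicit call that the finetuning scripts simply never make.

```python
# Sketch only: explicit merge method instead of merging inside train()/eval().
import torch
import torch.nn as nn


class LoRALinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 8, scaling: float = 1.0):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = scaling
        self.merged = False

    # no train()/eval() override: default nn.Module behaviour, so DeepSpeed
    # stage 3 / FSDP never see a parameter access outside forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = super().forward(x)
        if not self.merged:
            result = result + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return result

    @torch.no_grad()
    def merge_lora_weights(self) -> None:
        # only call this when the full weight is materialized,
        # e.g. for inference/export after finetuning
        if not self.merged:
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True
```

With that, the finetuning scripts can call model.train() and model.eval() freely under any strategy, and the merge only happens in the inference/export path where the full weights are materialized.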
That makes sense, thanks for clarifying! I could give it a try when I am back from CVPR next week (currently working on code for the talk), but if anyone reading this wants to give it a try, please go ahead :)
I tried that, but the same error appears; both Adapter and LoRA fail to load the checkpoint.
Do you have another solution?
The above issue should be just for LoRA. Is the error exactly the same for Adapter?
This doesn't work for devices=1 with LoRA.
The above issue should be just for LoRA. Is the error exactly the same for Adapter?
Yes, same issue loading the weights with DeepSpeed stage 3.
Any update regarding this issue?