litgpt

Support LoRA with multiple devices

Open rasbt opened this issue 1 year ago • 6 comments

When switching from DeepSpeed stage 2 to DeepSpeed stage 3, loading the model in the LoRA finetuning script currently fails:

...
 18176]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.norm_1.weight: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.norm_1.bias: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.attn.attn.weight: copying a param with shape torch.Size([4672, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.attn.proj.weight: copying a param with shape torch.Size([4544, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.mlp.fc.weight: copying a param with shape torch.Size([18176, 4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.h.31.mlp.proj.weight: copying a param with shape torch.Size([4544, 18176]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
        size mismatch for transformer.ln_f.bias: copying a param with shape torch.Size([4544]) from checkpoint, the shape in current model is torch.Size([0]).
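The `torch.Size([0])` on the model side is consistent with sharded parameters: with DeepSpeed ZeRO stage 3, each rank keeps only a partition of every parameter, and the module's own tensor is an empty placeholder until the parameter is gathered. The minimal sketch below reproduces the same kind of message without DeepSpeed, assuming we can stand in for a partitioned parameter by swapping a `Linear` weight for an empty tensor. It is only an illustration; DeepSpeed's real partitioning is more involved.

```python
import torch
import torch.nn as nn

# A complete checkpoint for a small layer (stands in for the pretrained weights).
full = nn.Linear(4, 4, bias=False)
checkpoint = full.state_dict()

# Assumption: an empty Parameter mimics what a ZeRO-3 partitioned module
# exposes locally when its weights have not been gathered.
sharded = nn.Linear(4, 4, bias=False)
sharded.weight = nn.Parameter(torch.empty(0))

message = ""
try:
    # strict loading compares shapes and raises on any mismatch
    sharded.load_state_dict(checkpoint)
except RuntimeError as err:
    message = str(err)
    print(message)
```

The printed error contains a `size mismatch for weight` line of the same form as the log above, with `torch.Size([0])` as the current shape.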

But I see that an FSDP update is in the works via #118, so I am not sure if that's worth addressing right now.

rasbt, Jun 16 '23 18:06

FSDP also fails with a similar error.

This is because we access the LoRA parameters in .train() instead of in .forward(), while with sharded strategies the parameters are only gathered when they are needed for computation: https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/lora.py#L270-L273
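To make the problem concrete, here is a minimal sketch of the merge-on-mode-switch pattern used in Microsoft's loralib (the linked lit-llama code follows a similar design; the class name `MergingLoRALinear` and its attributes are hypothetical, not the exact lit-llama code). Calling `.eval()` adds the low-rank delta `B @ A * scaling` into the frozen weight and calling `.train()` subtracts it again, so both calls read and write the full weight tensor outside of `forward()`.

```python
import math
import torch
import torch.nn as nn


class MergingLoRALinear(nn.Linear):
    """Hypothetical LoRA layer that merges/unmerges its delta inside train()."""

    def __init__(self, in_features, out_features, r=4, alpha=8):
        super().__init__(in_features, out_features, bias=False)
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.weight.requires_grad = False  # the pretrained weight stays frozen
        self.merged = False

    def delta(self):
        return (self.lora_B @ self.lora_A) * self.scaling

    def train(self, mode=True):
        super().train(mode)
        # These branches touch self.weight outside forward(). Under FSDP or
        # DeepSpeed stage 3 the full weight is not materialized at this point.
        with torch.no_grad():
            if mode and self.merged:
                self.weight.sub_(self.delta())  # unmerge for training
                self.merged = False
            elif not mode and not self.merged:
                self.weight.add_(self.delta())  # merge for inference
                self.merged = True
        return self

    def forward(self, x):
        out = super().forward(x)
        if not self.merged:
            # unmerged path: add the low-rank update explicitly
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out
```

On a single device the two paths give the same output; the mode switch itself is what becomes unsafe once the weights are sharded.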

@awaelchli suggested removing these calls from the fine-tuning scripts:

It's only needed if you want to merge the LoRA weights into the regular weights (for a faster forward or if you don't need the delta anymore). But we don't care about that in finetuning, and we save the LoRA weights separately anyway. So you have two options: refactor the code out of the train()/eval() methods into a separate merge_weights() method, or just remove the train()/eval() calls from the finetuning scripts.
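The first option could look roughly like the sketch below. It assumes a plain PyTorch LoRA layer; `LoRALinear`, `merge_weights()` and `lora_only_state_dict()` are hypothetical names for illustration, not existing lit-llama APIs. The layer no longer overrides `train()`, so switching modes never touches the weights, and merging becomes an explicit call.

```python
import math
import torch
import torch.nn as nn


class LoRALinear(nn.Linear):
    """Hypothetical LoRA layer where merging is an explicit, opt-in step."""

    def __init__(self, in_features, out_features, r=4, alpha=8):
        super().__init__(in_features, out_features, bias=False)
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.weight.requires_grad = False  # only the LoRA factors are trained
        self.merged = False

    # No train()/eval() override: mode switches leave the weights alone.

    def forward(self, x):
        out = super().forward(x)
        if not self.merged:
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out

    @torch.no_grad()
    def merge_weights(self):
        """Fold the LoRA delta into the base weight, e.g. after finetuning."""
        if not self.merged:
            self.weight.add_((self.lora_B @ self.lora_A) * self.scaling)
            self.merged = True


def lora_only_state_dict(model):
    """Hypothetical helper: keep only the LoRA parameters for saving."""
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}
```

During finetuning, `model.train()` and `model.eval()` then only toggle flags such as dropout, and `merge_weights()` would be called once, at a point where the full weights are available.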

carmocca, Jun 16 '23 18:06

That makes sense, thanks for clarifying! I could give it a try when I am back from CVPR next week (currently working on code for the talk), but if anyone reading this wants to give it a try, please go ahead :)

rasbt, Jun 16 '23 20:06

Tried that; the same error appears. Both Adapter and LoRA fail to load the checkpoint.

Do you have another solution?

alexeiga, Jun 18 '23 12:06

The issue described above should only affect LoRA. Is the error exactly the same for Adapter?

carmocca, Jun 19 '23 18:06

This doesn't work for devices=1 with LoRA.

griff4692, Jun 20 '23 17:06

The above issue should be just for LoRA. Is the error exactly the same for Adapter?

Yes, the same issue occurs when loading weights with DeepSpeed stage 3.

alexeiga, Jun 21 '23 10:06

Any update on this issue?

alexeiga, Jul 17 '23 07:07