Gradients in the GPT module of the litgpt/finetune/lora.py script are always zero
Hello,
I've noticed that the GPT module of the litgpt/finetune/lora.py script is initialized inside the fabric.init_module() context manager. Initializing the model this way means that the reset_parameters() method of the LoRALinear and LoRAQKVLinear classes is called under this context manager, so the lora_A and lora_B matrices are never actually initialized (they both stay zero). I think this leads to zero gradients throughout the entire training.
Maybe there should be a way to call reset_parameters() on a LoRALinear module outside of the __init__() method.
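To illustrate the zero-gradient claim, here is a minimal standalone sketch (not the litgpt code): when both LoRA matrices are zero, the gradient of each one is multiplied by the other, so both gradients come out zero.

```python
import torch

x = torch.randn(4, 16)
lora_A = torch.zeros(16, 8, requires_grad=True)   # never initialized -> all zeros
lora_B = torch.zeros(8, 16, requires_grad=True)   # zeros by design

# y = x @ A @ B, so dL/dA involves B (zero) and dL/dB involves A (zero)
loss = (x @ lora_A @ lora_B).sum()
loss.backward()
print(lora_A.grad.abs().max(), lora_B.grad.abs().max())  # tensor(0.) tensor(0.)
```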
Thanks for bringing that up! If I understand correctly though, reset_parameters() will not make the weights 0 but rather reinitialize them. So I think this should be okay, but I may be overlooking something. (Please correct me if I'm wrong or missing the point.)
So, only LoRA matrix B is zero, but LoRA matrix A should be initialized with small random weights, i.e.,
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
However, I am just seeing that we use Kaiming He initialization with a=sqrt(5) for the linear layer:
https://github.com/Lightning-AI/litgpt/blob/449eb29b111324090fa7066e0b26e9166806b02e/litgpt/lora.py#L135
Maybe we should investigate at some point whether using the original initialization scaling based on the rank is actually better, @awaelchli @carmocca?
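For comparison, a minimal sketch of the two initialization schemes being discussed (shapes and variable names are illustrative, not the litgpt implementation):

```python
import math
import torch
from torch import nn

in_dim, out_dim, rank = 128, 128, 8

# Current scheme (following loralib / nn.Linear): Kaiming-uniform A with a=sqrt(5), zero B
lora_A = nn.Parameter(torch.empty(rank, in_dim))
lora_B = nn.Parameter(torch.zeros(out_dim, rank))
nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))

# Rank-based scaling mentioned above: normal weights scaled by 1/sqrt(rank), B still zero
std_dev = 1 / math.sqrt(rank)
lora_A_alt = nn.Parameter(torch.randn(rank, in_dim) * std_dev)
```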
Hi @rasbt! I think you are correct: reset_parameters() does not make the weights 0; it initializes the lora_A matrix with Kaiming He initialization and the lora_B matrix with zeros. What I'm saying is that reset_parameters() is called inside the __init__() method, and when the GPT module is constructed under the fabric.init_module() context manager, reset_parameters() does not actually serve its purpose (both matrices end up initialized with zeros). In order to correctly initialize the matrices, I had to call reset_parameters() after calling __init__() (outside the context manager), as sketched below.
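For completeness, a minimal sketch of that workaround (assuming the fabric, GPT, config, and LoRALinear objects from the finetune script / litgpt.lora; this is an illustration, not an official fix):

```python
with fabric.init_module():
    model = GPT(config)

# Re-run initialization outside the context manager so that lora_A is actually
# initialized (Kaiming uniform) and lora_B is set to zeros, as intended.
for module in model.modules():
    if isinstance(module, LoRALinear):  # LoRAQKVLinear subclasses LoRALinear
        module.reset_parameters()
```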
@LautaroEst Which Fabric strategy are you using?
@awaelchli I'm using just one gpu, so I'm initializing the fabric object with
fabric = L.Fabric(accelerator="gpu", strategy="auto", devices=1, num_nodes=1, precision="bf16-true")
@rasbt We follow the same initialization as Microsoft's: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L266-L271 which itself matches what you propose: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py#L106-L109
> I'm using just one gpu, so I'm initializing the fabric object with
In this case, empty_init=False is used: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py#L170 so initialization should be happening normally.
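For reference, a minimal sketch of what the linked line amounts to (GPT and config as constructed in the finetune script; an illustration, not the exact code):

```python
# With empty_init=False, Fabric materializes the parameters normally, so the
# reset_parameters() calls inside the modules' __init__() take effect.
# empty_init=True (useful for sharded multi-GPU setups) would leave the weights
# uninitialized until they are loaded from a checkpoint or explicitly reset later.
with fabric.init_module(empty_init=False):
    model = GPT(config)
```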