
LoRA MoE with k_proj, up_proj, down_proj

Open aksh555 opened this issue 9 months ago • 5 comments

Hi, thanks for the library! When we try to compose LoRA experts that have k_proj, up_proj, down_proj in the target_modules, we face a shape mismatch error. Everything works fine when the target modules are only q_proj and v_proj. Any suggestions on how to fix this?

aksh555 avatar May 12 '24 20:05 aksh555
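For illustration, a minimal sketch of the two adapter configurations being compared, using peft's LoraConfig (the rank and task type here are illustrative assumptions, not values from the report):

from peft import LoraConfig

# Composes fine: attention-only target modules
works = LoraConfig(r=8, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")

# Triggers the shape mismatch: MLP projections included as well
fails = LoraConfig(
    r=8,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)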

Hi,

Thanks for opening the issue. Can you please provide more details on your implementation, and experts used for merging?

alirezamshi avatar May 13 '24 11:05 alirezamshi

The experts are LoRA-finetuned Llama 7B models with target modules ['down_proj', 'v_proj', 'up_proj', 'q_proj', 'k_proj'], and this was the error encountered:

Traceback (most recent call last):
  ...
  File "/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (183x11008 and 4096x2)

aksh555 avatar May 13 '24 14:05 aksh555
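As an aside, the shapes in that error are suggestive: 11008 is the MLP intermediate size of Llama 7B, while 4096x2 matches a router gate sized for the model's hidden size and two experts. A minimal sketch that reproduces the same failure in isolation (this reading of the shapes is an assumption, not confirmed by the maintainers):

import torch
import torch.nn as nn

gate = nn.Linear(4096, 2, bias=False)  # router sized for hidden_size x num_experts
x = torch.randn(183, 11008)            # activations at down_proj's input (intermediate size)
gate(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (183x11008 and 4096x2)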

@aksh555, could you share the scripts for replicating the model composition and the forward pass?

gitsailor5 avatar May 13 '24 21:05 gitsailor5

I ran into the same bug. Looking forward to a fix :)

The code used to merge llama3-lora:

"""
Replaces ff layers using MOE. rest all will be averaged
"""
import os

import torch
from mergoo.compose_experts import ComposeExperts
from mergoo.models.modeling_llama import LlamaForCausalLM

model_id = "moe_model/llama3_lora_moe"
config = {
    "model_type": "llama",
    "num_experts_per_tok": 2,
    "base_model": "../Meta-Llama-3-8B-Instruct",
    "experts": [
        {"expert_name": "adapter_1", "model_id": "llama3_lora_1"},
        {"expert_name": "adapter_2", "model_id": "llama3_lora_2"},
    ],
}

# create the composed checkpoint if it does not exist yet
if not os.path.exists(model_id):
    expertcomposer = ComposeExperts(config)
    expertcomposer.compose()
    expertcomposer.save_checkpoint(model_id)


# load the composed checkpoint
model = LlamaForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)  # the 'gate' / router layers are untrained, so a load warning will appear for them
out = model(torch.tensor([[1, 2, 3, 33, 44]], device=model.device))
print("done")

The error message: RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x14336 and 4096x2)

The full traceback:

Traceback (most recent call last):
  File "/data/wentao/slz/mergoo/examples/compose_lora_mistral.py", line 35, in <module>
    out = model(torch.tensor([[1, 2, 3, 33, 44]], device=model.device))
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 1177, in forward
    outputs = self.model(
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 1020, in forward
    layer_outputs = decoder_layer(
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 756, in forward
    hidden_states = self.mlp(hidden_states)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 242, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/compose_layers.py", line 126, in forward
    gate_logits = self.gate(x)  # b,s,N
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x14336 and 4096x2)

Aurora-slz avatar May 30 '24 12:05 Aurora-slz
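Here too the shapes line up with a router sizing mismatch: 14336 is Llama-3-8B's MLP intermediate size (the input width of down_proj), while the 4096x2 gate matches hidden_size x num_experts. A hypothetical illustration of the sizing that would be consistent with the traceback (not mergoo's actual code):

import torch.nn as nn

hidden_size, intermediate_size, num_experts = 4096, 14336, 2

# A gate sized for the hidden state works for q/k/v_proj inputs ...
gate_for_attn = nn.Linear(hidden_size, num_experts, bias=False)

# ... but a gate wrapping down_proj sees intermediate-size activations,
# so it would need to match that layer's in_features instead.
gate_for_down = nn.Linear(intermediate_size, num_experts, bias=False)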

Hi @Aurora-slz,

Could you please share the adapter configurations for llama3_lora_1 and llama3_lora_2? If they are available on HuggingFace, providing the URLs would be very helpful. This will assist us in replicating the issue mentioned above.

gitsailor5 avatar Jun 06 '24 17:06 gitsailor5