Cannot adapt multiple sub-models of multimodal model with `AdapterModelInterface`
Environment info
- `adapters` version: HEAD (1.1.1 @ 0470c18)
- Platform: Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Information
Model I am using (Bert, XLNet ...): LLaVA-NeXT (with AdapterModelInterface)
Language I am using the model on (English, Chinese ...): English
Adapter setup I am using (if any): see below
The problem arises when using:
- [ ] the official example scripts: (give details below)
- [x] my own modified scripts: (give details below)
The task I am working on is:
- [ ] an official GLUE/SQUaD task: (give the name)
- [x] my own task or dataset: (give details below)
To reproduce
Consider the following snippet:
import adapters
import torch
from adapters import AdapterModelInterface
from transformers import LlavaNextForConditionalGeneration
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llama3-llava-next-8b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

lora_interface_lm = AdapterModelInterface(
    adapter_methods=["lora"],
    model_embeddings="language_model.model.embed_tokens",
    model_layers="language_model.model.layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,  # type: ignore
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)
adapters.init(model, interface=lora_interface_lm)
model.add_adapter("lora_lm", config="lora")
print(model.adapter_summary())

lora_interface_vision = AdapterModelInterface(
    adapter_methods=["lora"],
    model_embeddings="vision_tower.vision_model.embeddings.patch_embedding",
    model_layers="vision_tower.vision_model.encoder.layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,  # type: ignore
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="out_proj",
    layer_intermediate_proj="mlp.fc1",
    layer_output_proj="mlp.fc2",
)
adapters.init(model, interface=lora_interface_vision)
model.add_adapter("lora_vision", config="lora")
print(model.adapter_summary())
print(model)
When trying to adapt both the language model (LM) and the vision tower (VT) as shown above, the adapter summary reports exactly the same number of trainable parameters for both adapters. After adapting only the LM:
================================================================================
Name                Architecture            #Param    %Param  Active   Train
--------------------------------------------------------------------------------
lora_lm             lora                 3,407,872     0.041       0       1
--------------------------------------------------------------------------------
Full model                           8,355,276,800   100.000               1
================================================================================
After adapting LM and VT:
================================================================================
Name                Architecture            #Param    %Param  Active   Train
--------------------------------------------------------------------------------
lora_lm             lora                 3,407,872     0.041       0       1
lora_vision         lora                 3,407,872     0.041       0       1
--------------------------------------------------------------------------------
Full model                           8,355,276,800   100.000               1
================================================================================
However, when adapting only the VT, the numbers come out as expected:
================================================================================
Name                Architecture            #Param    %Param  Active   Train
--------------------------------------------------------------------------------
lora_vision         lora                   786,432     0.009       0       1
--------------------------------------------------------------------------------
Full model                           8,355,276,800   100.000               1
================================================================================
These differences show up not only in the parameter counts but also in the full model printouts: in the former case, no LoRA layers are added to the VT at all; instead, the lora_vision adapter ends up inside the LM's LoRA modules:
LlavaNextForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPFlashAttention2(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaNextMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELUActivation()
    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
  )
  (language_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128320, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): LoRALinearTorch(
              in_features=4096, out_features=4096, bias=False
              (shared_parameters): ModuleDict()
              (loras): ModuleDict(
                (lora_lm): LoRA()
                (lora_vision): LoRA()
              )
            )
            (k_proj): LoRALinearTorch(
              in_features=4096, out_features=1024, bias=False
              (shared_parameters): ModuleDict()
              (loras): ModuleDict()
            )
            (v_proj): LoRALinearTorch(
              in_features=4096, out_features=1024, bias=False
              (shared_parameters): ModuleDict()
              (loras): ModuleDict(
                (lora_lm): LoRA()
                (lora_vision): LoRA()
              )
            )
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): LoRALinearTorch(
              in_features=4096, out_features=14336, bias=False
              (shared_parameters): ModuleDict()
              (loras): ModuleDict()
            )
            (down_proj): LoRALinearTorch(
              in_features=14336, out_features=4096, bias=False
              (shared_parameters): ModuleDict()
              (loras): ModuleDict()
            )
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128320, bias=False)
  )
  (shared_parameters): ModuleDict()
)
In the latter case (where only the VT is adapted), however, the LoRA layers are injected into the VT as expected:
LlavaNextForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPFlashAttention2(
              (k_proj): LoRALinearTorch(
                in_features=1024, out_features=1024, bias=True
                (shared_parameters): ModuleDict()
                (loras): ModuleDict()
              )
              (v_proj): LoRALinearTorch(
                in_features=1024, out_features=1024, bias=True
                (shared_parameters): ModuleDict()
                (loras): ModuleDict(
                  (lora_vision): LoRA()
                )
              )
              (q_proj): LoRALinearTorch(
                in_features=1024, out_features=1024, bias=True
                (shared_parameters): ModuleDict()
                (loras): ModuleDict(
                  (lora_vision): LoRA()
                )
              )
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): LoRALinearTorch(
                in_features=1024, out_features=4096, bias=True
                (shared_parameters): ModuleDict()
                (loras): ModuleDict()
              )
              (fc2): LoRALinearTorch(
                in_features=4096, out_features=1024, bias=True
                (shared_parameters): ModuleDict()
                (loras): ModuleDict()
              )
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaNextMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELUActivation()
    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
  )
  (language_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128320, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128320, bias=False)
  )
  (shared_parameters): ModuleDict()
)
Expected behavior
I would expect the LoRA layers to be added everywhere I specify them and to be usable afterwards. However, adding adapters to the VT after adapters have already been added to the LM appears to have no effect. This shouldn't be the case.
Hey @kurzdev, thanks for bringing this up! You're fully correct that the vision and language adapters are not added as expected with the provided snippet. This is due to a limitation of the AdapterModelInterface: a model can only hold one interface at a time. I.e., since you're trying to add a second interface to a model that already holds an interface, the second one has no effect.
What I would suggest instead is to treat the vision and text towers as fully independent models and add the interfaces directly to the sub-models, i.e.:
import adapters
import torch
from adapters import AdapterModelInterface
from transformers import LlavaNextForConditionalGeneration
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llama3-llava-next-8b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)

lora_interface_lm = AdapterModelInterface(
    adapter_methods=["lora"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,  # type: ignore
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)
adapters.init(model.language_model, interface=lora_interface_lm)
model.language_model.add_adapter("lora_lm", config="lora")
print(model.language_model.adapter_summary())

lora_interface_vision = AdapterModelInterface(
    adapter_methods=["lora"],
    model_embeddings="vision_model.embeddings.patch_embedding",
    model_layers="vision_model.encoder.layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,  # type: ignore
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="out_proj",
    layer_intermediate_proj="mlp.fc1",
    layer_output_proj="mlp.fc2",
)
adapters.init(model.vision_tower, interface=lora_interface_vision)
model.vision_tower.add_adapter("lora_vision", config="lora")
print(model.vision_tower.adapter_summary())

print(model)
Note how here, the interfaces are directly initialized on model.language_model and model.vision_tower. This comes with the downside that you cannot call adapter methods on the parent model, but gives full control over adapters in both sub-models.
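For instance, activating and training the adapters afterwards also happens per sub-model. A minimal sketch, assuming the setup above (not part of the original reply; train_adapter and set_active_adapters are the standard adapters methods added by adapters.init):

# Sketch: set each adapter as active and trainable on its own sub-model,
# which also freezes the remaining weights of that sub-model.
model.language_model.train_adapter("lora_lm")
model.vision_tower.train_adapter("lora_vision")

# Inference-only activation works the same way, per sub-model:
model.language_model.set_active_adapters("lora_lm")
model.vision_tower.set_active_adapters("lora_vision")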
Hope this helps and let us know if you think there's any way we could improve this going forward!
Hey @calpt, thanks for looking into it and the helpful advice! That makes a lot of sense. 😊
Correct me if I'm wrong but doing so prevents jointly fine-tuning both submodels via AdapterTrainer because of this check first and foremost, doesn't it? (feel free to call me out on this because I haven't had the capacity to try yet... 🫣)
Generally, I wouldn't mind having to initialise the sub-models manually, as long as I can still jointly fine-tune the model as a whole in the end. Not sure how big of a workload implementing this feature would entail, though (assuming my deduction above about its nonexistence is right, lol).
Feel free to point me in a certain direction if this doesn't seem to be too big of a task/a good first issue, would be happy to assist! 🙂
Having the same issue
@kurzdev yes, you are correct: the AdapterTrainer needs modifications to run this kind of model setup. As a temporary patch, you could disable these checks & inits in the Trainer init and change the save/load logic to call the respective adapter methods on both sub-models.
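For example, the manual save/load part of that patch could look roughly like this (a sketch using the public save_adapter/load_adapter methods on each sub-model; the directory names are illustrative, and this replaces rather than reproduces the Trainer's built-in logic):

# Save each adapter from its own sub-model instead of relying on AdapterTrainer:
model.language_model.save_adapter("checkpoints/lora_lm", "lora_lm")
model.vision_tower.save_adapter("checkpoints/lora_vision", "lora_vision")

# Later, restore both adapters the same way:
model.language_model.load_adapter("checkpoints/lora_lm")
model.vision_tower.load_adapter("checkpoints/lora_vision")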
What I think we should attempt as a "clean fix" of this is a sort of proxy interface on the parent model that basically forwards all interface calls to the two sub-models. Generally, this should not be too large of a change code-wise, but might need some tweaking of ugly library internals :)
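Purely to illustrate the idea from the user's perspective (this is hypothetical helper code, not existing library internals), such a proxy could fan out the common adapter calls to both sub-models:

# Hypothetical proxy forwarding adapter calls to the language and vision sub-models.
class MultimodalAdapterProxy:
    def __init__(self, model):
        self.sub_models = [model.language_model, model.vision_tower]

    def train_adapter(self, adapter_names):
        # adapter_names: one adapter name per sub-model, e.g. ["lora_lm", "lora_vision"]
        for sub_model, name in zip(self.sub_models, adapter_names):
            sub_model.train_adapter(name)

    def save_all_adapters(self, save_directory):
        # Each adapters-initialized sub-model saves its adapters into
        # per-adapter subdirectories of save_directory.
        for sub_model in self.sub_models:
            sub_model.save_all_adapters(save_directory)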
This issue has been automatically marked as stale because it has been without activity for 90 days. This issue will be closed in 14 days unless you comment or remove the stale label.