[Usage] How can I change the language model into Qwen-7B?

Open · jyC23333 opened this issue 1 year ago · 1 comment

Describe the issue

I want to change the LLM to Qwen, and I wrote a model file based on llava_llama.py:

#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from .qwen.configuration_qwen import QWenConfig


from transformers.modeling_outputs import CausalLMOutputWithPast

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaQwenConfig(QWenConfig):
    model_type = "llava_qwen"


class LlavaQwenModel(LlavaMetaModel, AutoModelForCausalLM):
    config_class = LlavaQwenConfig

    def __init__(self, config: QWenConfig):
        super(LlavaQwenModel, self).__init__(config)


class LlavaQwenForCausalLM(AutoModelForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaQwenConfig

    def __init__(self, config):
        super(AutoModelForCausalLM, self).__init__(config)
        self.model = LlavaQwenModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs

AutoConfig.register("llava_qwen", LlavaQwenConfig)
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)

But I got the following error:

ValueError: Unrecognized configuration class <class 'transformers_modules.Qwen.Qwen-7B.ef3c5c9c57b252f3149c1408daf4d649ec8b6c85.configuration_qwen.QWenConfig'> for this kind of AutoModel: LlavaQwenForCausalLM.
Model type should be one of BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, CodeGenConfig, CpmAntConfig, CTRLConfig, Data2VecTextConfig, ElectraConfig, ErnieConfig, FalconConfig, GitConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, LlamaConfig, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MptConfig, MusicgenConfig, MvpConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, Speech2Text2Config, TransfoXLConfig, TrOCRConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, LlavaConfig, LlavaMPTConfig, LlavaQwenConfig.

jyC23333 · Jan 29 '24

Hi, I have the same issue. Have you solved it?

20191864218 · Feb 22 '24

@20191864218 Hi, yes, I've already adapted the Qwen model to LLaVA. There are many details to watch out for.

I suggest following this repo to adapt Qwen to LLaVA: https://github.com/Ucas-HaoranWei/Vary. Vary is built on LLaVA-v1, an older version of LLaVA, so if you don't mind the version, you can refer to that repo.

By the way, the Qwen team just published the new Qwen1.5 series; I don't know whether it is easier to adapt. I hope so!
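
For reference, here is a minimal sketch of how the snippet above could be restructured. The root cause of the ValueError is that AutoModelForCausalLM is a factory, not a real model class, so it cannot be used as a base class; llava_llama.py instead subclasses the concrete LlamaModel and LlamaForCausalLM. This sketch assumes Qwen-7B's remote code (modeling_qwen.py, which defines QWenModel and QWenLMHeadModel) is vendored next to configuration_qwen.py in the local qwen package:

# A sketch only: QWenModel / QWenLMHeadModel come from Qwen-7B's remote
# modeling_qwen.py, assumed vendored into the local `qwen` package.
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM

from .qwen.configuration_qwen import QWenConfig
from .qwen.modeling_qwen import QWenModel, QWenLMHeadModel

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaQwenConfig(QWenConfig):
    model_type = "llava_qwen"


class LlavaQwenModel(LlavaMetaModel, QWenModel):
    # Subclass the concrete backbone, mirroring
    # LlavaLlamaModel(LlavaMetaModel, LlamaModel) in llava_llama.py.
    config_class = LlavaQwenConfig

    def __init__(self, config: QWenConfig):
        super(LlavaQwenModel, self).__init__(config)


class LlavaQwenForCausalLM(QWenLMHeadModel, LlavaMetaForCausalLM):
    config_class = LlavaQwenConfig

    def __init__(self, config):
        # Skip QWenLMHeadModel.__init__ (it would build a plain QWenModel)
        # and initialize from its parent, as llava_llama.py does with
        # super(LlamaForCausalLM, self).__init__(config).
        super(QWenLMHeadModel, self).__init__(config)
        # Qwen's head model stores its backbone as `self.transformer`,
        # not `self.model` as Llama does, and QWenConfig has no
        # `pretraining_tp`, so those lines change as well.
        self.transformer = LlavaQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.transformer

    # forward() and prepare_inputs_for_generation() can stay as in the
    # snippet above.


AutoConfig.register("llava_qwen", LlavaQwenConfig)
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)

One more detail the traceback points at: AutoConfig resolved the stock QWenConfig from the hub checkpoint's remote code, and LlavaQwenForCausalLM is only registered for LlavaQwenConfig. So the converted checkpoint's config.json also needs "model_type": "llava_qwen" (or the model must be loaded via LlavaQwenForCausalLM.from_pretrained with a LlavaQwenConfig) for the registration to take effect.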

jyC23333 · Feb 23 '24