
[Question] Usage with Multimodal LLM

Open Hiusam opened this issue 1 year ago • 12 comments

Dear Authors,

Thank you so much for your wonderful work. I want to ask: if I am running LLaVA (https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava.py), a multimodal LLM built on LLaMA by adding an image encoder, what is the most convenient way to incorporate vLLM?

I think I can follow the instructions in "https://vllm.readthedocs.io/en/latest/models/adding_model.html". Are there any more convenient ways?

Hiusam avatar Jun 29 '23 08:06 Hiusam

I don't think vLLM supports image/audio embeddings so far. Methods and abstractions for multimodal embeddings would need to be supported first.

yhyu13 avatar Jun 29 '23 13:06 yhyu13

Thanks for bringing this up! Theoretically, our current PagedAttention kernel and memory manager should support llava without any kernel modifications. The potential issue I can think of is how to schedule the image encoder along with the iteration-level continuous batching of LLMs to make sure the two networks run efficiently. I think following the "adding models" guide is a good first step. After that, we can focus on how to schedule the image encoder efficiently.

zhuohan123 avatar Jun 29 '23 14:06 zhuohan123

I have some questions. If I want to add a LLaVA model, what should I do for the load_weights function? I don't have any idea. I read the model files in vllm/model_executor/models, but I don't know why the code is written the way it is.

# llama.py LlamaForCausalLM.load_weights: why is the code written like this?
def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            if "embed_tokens" in name or "lm_head" in name:
                param = state_dict[name]
                # Consider padding in the vocab size.
                padded_vocab_size = (param.shape[0] *
                                     tensor_model_parallel_world_size)
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows,
                                         loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            is_attention_weight = False
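            # q_proj, k_proj and v_proj are fused into a single qkv_proj parameter in vLLM,
            # so each HF weight is copied into its slice of the fused tensor, and each
            # tensor-parallel rank keeps only its own shard of the rows.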
            for stride_id, att_weight_name in enumerate(
                ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
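            # gate_proj and up_proj are fused into gate_up_proj the same way and sharded per rank.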
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
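            # Remaining weights are sharded by the generic helper according to whether
            # they are column- or row-parallel for this rank.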
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)

Second question: the LLaVA forward interface looks like this; how should I change the code so that vLLM can support it?

def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    )

Or, how long until the LLaVA model will be officially supported? Could you give me some guidance? @zhuohan123 I look forward to your reply, thanks very much.
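
For reference, what LLaVA's forward does with the images is essentially to splice projected vision features into the token embeddings before the decoder runs. A rough standalone PyTorch sketch of that merging step (the placeholder token id, the shapes, and the function itself are illustrative assumptions, not vLLM's or LLaVA's actual API):

import torch

def merge_image_features(
    token_embeds: torch.Tensor,    # [seq_len, hidden_size] text token embeddings
    image_features: torch.Tensor,  # [num_patches, hidden_size] projected vision features
    input_ids: torch.Tensor,       # [seq_len] prompt token ids
    image_token_id: int,           # id of the placeholder tokens reserved for the image
) -> torch.Tensor:
    # Find the positions in the prompt reserved for image patches.
    image_positions = (input_ids == image_token_id).nonzero(as_tuple=True)[0]
    assert image_positions.numel() == image_features.shape[0], (
        "the prompt must reserve one placeholder token per image patch")
    # Overwrite the placeholder embeddings with the projected image features.
    merged = token_embeds.clone()
    merged[image_positions] = image_features.to(merged.dtype)
    return merged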

akxxsb avatar Jul 17 '23 12:07 akxxsb

@akxxsb Did you ever figure this out?

For our multimodal use case we are just using the native LLaVA serving files; it's inefficient/slow, but we are looking into solutions to make this better.

aldrinc avatar Jul 25 '23 20:07 aldrinc

Did you find any solution?

TalhaUusuf avatar Aug 03 '23 00:08 TalhaUusuf

> @akxxsb Did you ever figure this out?
>
> For our multimodal use case we are just using the native LLaVA serving files; it's inefficient/slow, but we are looking into solutions to make this better.

No, I am a research and development engineer, not a machine learning specialist. I am not very familiar with machine learning and deep learning, so I have no idea how to implement it. I am still waiting for an official solution.

akxxsb avatar Aug 04 '23 03:08 akxxsb

Any news on this one? It would be really nice to be able to serve a more performance-optimized LLaVA instance.

tachyean avatar Sep 05 '23 09:09 tachyean

Is this feature (#775) still in development? Are there any plans to support LLaVA-1.5?

nitky avatar Oct 09 '23 15:10 nitky

> The potential issue I can think of is how to schedule the image encoder along with the iteration-level continuous batching of LLMs to make sure the two networks run efficiently

@zhuohan123 in my experience, sampling from the LLM dominates the runtime of multimodal applications. With that, I am personally happy to run with https://github.com/vllm-project/vllm/pull/1265 and embed/encode images (or any other modalities) as a separate step outside vLLM.

In other words, I believe https://github.com/vllm-project/vllm/pull/1265 solves multimodality for vLLM, at least for now, because:

  • image embedding isn't yet as standardized as LLM sampling is;
  • embedding an image requires fixed, predictable compute and is easy to batch and run separately.

What do you think?
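
For what it's worth, the encoding step is easy to batch entirely outside vLLM with a stock encoder. A minimal sketch using Hugging Face's CLIP vision tower (the checkpoint choice is illustrative, and the trained multimodal projector that maps patches into the LLM's embedding space is omitted):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

MODEL_ID = "openai/clip-vit-large-patch14-336"  # the vision tower LLaVA-1.5 uses
processor = CLIPImageProcessor.from_pretrained(MODEL_ID)
vision_tower = CLIPVisionModel.from_pretrained(MODEL_ID).eval()

@torch.no_grad()
def encode_images(paths: list[str]) -> torch.Tensor:
    # Fixed, predictable compute: easy to batch and run separately from the LLM engine.
    images = [Image.open(p).convert("RGB") for p in paths]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    hidden = vision_tower(pixel_values).last_hidden_state
    # Drop the CLS token and keep per-patch features; a trained projector (not
    # shown here) would then map these into the LLM's hidden size.
    return hidden[:, 1:, :]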

dimitry12 avatar Oct 20 '23 18:10 dimitry12

I agree with @dimitry12

Among other things, I believe that changing vLLM to accept embeddings as input would be the smallest first step toward supporting multimodality.

Multimodal models that pair a vision encoder with an LLM, such as LLaVA-1.5 and mPLUG-Owl, could then be supported with little extra work.

In my case, I would batch-process the vision encoding in a separate framework and use vLLM to accelerate the language model. (As we know, the language model is the biggest compute bottleneck in multimodal inference.)
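
To make the split concrete, a rough sketch of the two-stage pipeline is below. `encode_and_project` stands in for the external vision-encoding stage, and the `prompt_embeds` keyword is only a placeholder for whatever embeds-as-input interface eventually lands (e.g. from #1265), not an existing vLLM argument:

from vllm import LLM, SamplingParams

# Stage 1 (separate framework): vision encoder + projector produce prompt embeddings.
prompt_embeds = encode_and_project(image_batch)  # hypothetical external helper

# Stage 2: vLLM accelerates only the language model.
llm = LLM(model="lmsys/vicuna-7b-v1.5")  # the LLM backbone of LLaVA-1.5
params = SamplingParams(temperature=0.2, max_tokens=256)

# NOTE: generate() does not accept embeddings today; "prompt_embeds" is a
# placeholder for the interface proposed in #1265.
outputs = llm.generate(prompt_embeds=prompt_embeds, sampling_params=params)
for out in outputs:
    print(out.outputs[0].text)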

pfldy2850 avatar Dec 05 '23 10:12 pfldy2850

@zhuohan123 is this on the roadmap?

hmellor avatar Mar 25 '24 10:03 hmellor

this would be a great addition!

cassanof avatar Mar 29 '24 21:03 cassanof

Closing this since LLaVA-1.5 (and a general vision-language framework) has already been added in https://github.com/vllm-project/vllm/pull/3042. We will continue working on supporting other models on a best-effort basis, but any community contribution is very much welcome!

ywang96 avatar Apr 03 '24 17:04 ywang96