Need help with supporting "core42/jais-13b-chat" model

Open sam-iink opened this issue 1 year ago • 6 comments

Hello Team,

I am attempting to add support for the "core42/jais-13b-chat" model to vLLM. I have completed most of the required changes except for the ALiBi embeddings.

This is what the model looks like when loaded with HF:

```
JAISLMHeadModel(
  (transformer): JAISModel(
    (wte): Embedding(84992, 5120)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-39): 40 x JAISBlock(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): JAISAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): JAISMLP(
          (c_fc): Conv1D()
          (c_fc2): Conv1D()
          (c_proj): Conv1D()
          (act): SwiGLUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
    (relative_pe): AlibiPositionEmbeddingLayer()
  )
  (lm_head): Linear(in_features=5120, out_features=84992, bias=False)
)
```

And this is what it looks like after I made the necessary changes for vLLM 0.2.1-post1:

```
JAISLMHeadModel(
  (transformer): JAISModel(
    (wte): VocabParallelEmbedding()
    (h): ModuleList(
      (0-39): 40 x JAISBlock(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): JAISAttention(
          (c_attn): ColumnParallelLinear()
          (c_proj): RowParallelLinear()
          (attn): PagedAttentionWithALiBi()
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): JAISMLP(
          (c_fc): ColumnParallelLinear()
          (c_fc2): ColumnParallelLinear()
          (c_proj): RowParallelLinear()
          (act): SiluAndMul()
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  )
  (sampler): Sampler()
)
```

However, checkpoint loading fails because of the `(relative_pe): AlibiPositionEmbeddingLayer()` module, which has no counterpart in my vLLM version of the model. I don't know how to make the corresponding changes in the class definition.

Can someone please help me with this?

Jais class definition: https://huggingface.co/core42/jais-13b-chat/blob/main/modeling_jais.py
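One detail worth double-checking in the mapping above is `(act): SwiGLUActivation()` becoming `(act): SiluAndMul()`. As far as I can tell, JAIS computes `c_fc(x) * silu(c_fc2(x))` (please verify against modeling_jais.py), while vLLM's `SiluAndMul` takes a single tensor and returns `silu(first_half) * second_half`. If that reading is right, the two projections must be concatenated with the `c_fc2` output first. The pure-Python sketch checks that equivalence; `jais_swiglu` and `silu_and_mul` are illustrative re-implementations of the math, not the library code.

```python
import math


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def jais_swiglu(fc_out: list[float], fc2_out: list[float]) -> list[float]:
    # Assumed JAIS semantics: c_fc(x) * silu(c_fc2(x)).
    return [a * silu(b) for a, b in zip(fc_out, fc2_out)]


def silu_and_mul(x: list[float]) -> list[float]:
    # Same math as vLLM's SiluAndMul: silu(x[:d]) * x[d:] with d = len(x) // 2.
    d = len(x) // 2
    return [silu(gate) * up for gate, up in zip(x[:d], x[d:])]


fc_out = [0.3, -1.2, 2.0]   # stand-in for c_fc(x)
fc2_out = [1.5, 0.1, -0.7]  # stand-in for c_fc2(x)

# Put the gate (c_fc2 output) first so SiluAndMul reproduces the JAIS activation.
merged = fc2_out + fc_out
print(silu_and_mul(merged))
print(jais_swiglu(fc_out, fc2_out))
```

In practice this means either loading `c_fc2` and `c_fc` into one merged column-parallel layer in that order, or keeping them separate (as in the printout) and concatenating their outputs before `act`. With tensor parallelism the concatenation is done per rank on the local shards, which is still correct because both shards cover the same intermediate indices.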

sam-iink, Nov 28 '23

Hi @samujjwal-sam, thanks for your effort in adding JAIS model support. For now, could we remove the layer and just use PagedAttentionWithALiBi instead of the normal attention? You can refer to our implementations of ALiBi models such as BLOOM and MPT.

If you cannot find a workaround, please feel free to submit a PR so that we can look through the code together.
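Removing the `relative_pe` layer also means the weight loader has to skip any tensors the checkpoint stores under that module (if there are any; a non-persistent buffer would not be saved at all). A second adjustment GPT-2-style checkpoints need: HF `Conv1D` stores its weight as `[in_features, out_features]`, whereas vLLM's linear layers follow `nn.Linear`'s `[out_features, in_features]`, which is why vLLM's GPT-2 loader transposes those weights. A minimal, framework-free sketch of that name filtering, assuming the checkpoint is a dict of name to nested lists standing in for tensors; `remap_jais_weights` is a hypothetical helper, and names like `transformer.relative_pe.slopes` are guesses to be checked against the real checkpoint.

```python
# JAIS modules implemented with HF Conv1D (c_fc2 is the JAIS-specific extra one).
CONV1D_LAYERS = ("c_attn", "c_proj", "c_fc", "c_fc2")


def transpose(w: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*w)]


def remap_jais_weights(checkpoint: dict[str, list]) -> dict[str, list]:
    """Hypothetical pre-processing of HF JAIS weights before loading into vLLM."""
    remapped = {}
    for name, weight in checkpoint.items():
        if ".relative_pe." in name:
            # ALiBi slopes are rebuilt inside each attention layer; nothing to load.
            continue
        is_conv1d = any(f".{layer}." in name for layer in CONV1D_LAYERS)
        if is_conv1d and name.endswith(".weight"):
            # HF Conv1D stores [in, out]; nn.Linear-style layers expect [out, in].
            weight = transpose(weight)
        remapped[name] = weight
    return remapped


checkpoint = {
    "transformer.h.0.mlp.c_fc2.weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    "transformer.h.0.mlp.c_fc2.bias": [0.0, 0.0, 0.0],
    "transformer.relative_pe.slopes": [0.5, 0.25],
    "transformer.ln_f.weight": [1.0, 1.0],
}
remapped = remap_jais_weights(checkpoint)
```

Biases and LayerNorm parameters pass through unchanged; only 2-D `Conv1D` weights are transposed.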

WoosukKwon, Nov 29 '23

Thanks @WoosukKwon for the suggestion. I think I have it working now. I was confused by ALiBi because the HF implementation has a single ALiBi module, `(relative_pe): AlibiPositionEmbeddingLayer()`, shared by all layers, whereas PagedAttentionWithALiBi applies ALiBi inside each attention layer. I now initialize every PagedAttentionWithALiBi instance with the same ALiBi slopes, and this seems to work.
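The "same value for every layer" approach can be sketched as follows: compute the per-head slopes once, using the geometric schedule from the ALiBi paper (my assumption is that JAIS uses the standard schedule; worth confirming in modeling_jais.py), slice them by tensor-parallel rank because each rank only owns a subset of heads (as I understand the BLOOM/MPT models in vLLM do), and hand the same slice to every layer's attention. `get_alibi_slopes` and `slopes_for_rank` are illustrative names, and the 40-head, 40-layer figures are my reading of the jais-13b config (5120 hidden / 128 head dim).

```python
import math


def get_alibi_slopes(num_heads: int) -> list[float]:
    """Per-head ALiBi slopes, using the geometric schedule from the ALiBi paper."""

    def power_of_2_slopes(n: int) -> list[float]:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_2_slopes(num_heads)
    # Otherwise: use the closest lower power of 2, then fill the remaining heads
    # with every other slope from the schedule for twice that many heads.
    closest = 2 ** math.floor(math.log2(num_heads))
    extra = get_alibi_slopes(2 * closest)[0::2][: num_heads - closest]
    return power_of_2_slopes(closest) + extra


def slopes_for_rank(num_heads: int, tp_rank: int, tp_size: int) -> list[float]:
    """Slopes of the heads owned by one tensor-parallel rank (heads split evenly)."""
    if num_heads % tp_size != 0:
        raise ValueError("num_heads must be divisible by tp_size")
    per_rank = num_heads // tp_size
    all_slopes = get_alibi_slopes(num_heads)
    return all_slopes[tp_rank * per_rank : (tp_rank + 1) * per_rank]


# Assumed jais-13b setup: 40 heads, 40 layers, tensor parallel size 2, rank 0.
rank0_slopes = slopes_for_rank(num_heads=40, tp_rank=0, tp_size=2)
per_layer_slopes = [rank0_slopes for _ in range(40)]  # every layer reuses the same list
```

This mirrors the HF model's single `relative_pe` module: the slopes depend only on the head index, never on the layer, so computing them once and sharing them is equivalent.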

By the way, could you take a quick look at whether the layer replacements are correct, i.e. whether I am using ColumnParallelLinear and RowParallelLinear in the right places?
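For reference, the usual Megatron-style rule, which is how I understand vLLM's GPT-2 and BLOOM models to be laid out, is: layers that expand the hidden size (`c_attn`, `c_fc`, `c_fc2`) are `ColumnParallelLinear`, splitting their output features across ranks, and layers that project back (`c_proj` in attention and MLP) are `RowParallelLinear`, splitting their input features and all-reducing the partial sums. The printout follows that pattern. The framework-free sketch simulates it for the MLP with `tp_size` ranks in one process and compares it to the unsharded computation; the helper names are made up, and the SwiGLU form `c_fc(x) * silu(c_fc2(x))` is an assumption to check against modeling_jais.py.

```python
import math
import random


def matvec(x: list[float], w: list[list[float]]) -> list[float]:
    # w uses HF Conv1D layout: w[i][j] maps input feature i to output feature j.
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def mlp(x, w_fc, w_fc2, w_proj):
    # Assumed JAIS SwiGLU MLP: c_proj(c_fc(x) * silu(c_fc2(x))).
    h = [a * silu(b) for a, b in zip(matvec(x, w_fc), matvec(x, w_fc2))]
    return matvec(h, w_proj)


def mlp_tensor_parallel(x, w_fc, w_fc2, w_proj, tp_size: int):
    inter = len(w_fc[0])
    if inter % tp_size != 0:
        raise ValueError("intermediate size must be divisible by tp_size")
    shard = inter // tp_size
    out = [0.0] * len(w_proj[0])
    for rank in range(tp_size):
        start, end = rank * shard, (rank + 1) * shard
        # Column-parallel: this rank owns output columns [start, end) of c_fc / c_fc2.
        fc = [row[start:end] for row in w_fc]
        fc2 = [row[start:end] for row in w_fc2]
        # Row-parallel: it owns the matching input rows [start, end) of c_proj.
        proj = w_proj[start:end]
        partial = mlp(x, fc, fc2, proj)
        # Summing the partial outputs plays the role of the all-reduce.
        out = [o + p for o, p in zip(out, partial)]
    return out


def rand_matrix(rows: int, cols: int) -> list[list[float]]:
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
hidden, inter = 4, 8
x = [random.uniform(-1.0, 1.0) for _ in range(hidden)]
w_fc, w_fc2 = rand_matrix(hidden, inter), rand_matrix(hidden, inter)
w_proj = rand_matrix(inter, hidden)

full = mlp(x, w_fc, w_fc2, w_proj)
sharded = mlp_tensor_parallel(x, w_fc, w_fc2, w_proj, tp_size=2)
print(max(abs(a - b) for a, b in zip(full, sharded)))  # only rounding-level differences
```

The elementwise activation is what makes this split work: each rank applies it to its own slice of the intermediate features without any communication, so the only collective needed is the one after the row-parallel `c_proj`. The attention block follows the same pattern, with `c_attn` split by heads.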

sam-iink, Nov 29 '23

@samujjwal-sam, I was wondering whether you've had the chance to implement JAIS in vLLM. Could you create a pull request for this model (https://huggingface.co/core42/jais-13b-chat/blob/main/modeling_jais.py)?

sernddev, Dec 06 '23

@samujjwal-sam, eagerly waiting to run Jais with vLLM. Hope it comes out soon!

Apoorv7092, Dec 08 '23

Hi @samujjwal-sam, I was wondering if you have managed to make JAIS work in vLLM. Looking forward to seeing JAIS supported by vLLM, as it is one of the best Arabic open-source LLMs so far.

7ossam81, Dec 10 '23

Hi @samujjwal-sam, @WoosukKwon, when can we expect JAIS support in vLLM?

JAADARI, Jan 04 '24

Hey guys, thanks for the enthusiasm. I am still working on it and am currently facing some issues with CUDA.

sam-iink, Jan 10 '24

I get the error below with the latest vLLM-0.2.7, although Jais works with vLLM-0.2.1-post1.

Traceback (most recent call last):
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 581, in capture
    hidden_states = self.model(
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/model_executor/models/jais.py", line 389, in forward
    hidden_states = self.transformer(input_ids, positions, kv_caches,
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/model_executor/models/jais.py", line 330, in forward
    print(hidden_states)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/_tensor.py", line 431, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/_tensor_str.py", line 595, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/_tensor_str.py", line 347, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/_tensor_str.py", line 137, in __init__
    nonzero_finite_vals = torch.masked_select(
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/path_to_vllm_code/vllm/vllm_generate_responses.py", line 248, in <module>
    main(args)
  File "/path_to_vllm_code/vllm/vllm_generate_responses.py", line 223, in main
    llm = load_model_vllm(model_path, dtype=args.model_dtype, tensor_parallel_size=int(args.gpus))
  File "/path_to_vllm_code/vllm/vllm_generate_responses.py", line 54, in load_model_vllm
    llm = LLM(model=model_path, trust_remote_code=True, 
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/entrypoints/llm.py", line 106, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 309, in from_engine_args
    engine = cls(*engine_configs,
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 114, in __init__
    self._init_cache()
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 298, in _init_cache
    self._run_workers("warm_up_model")
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 795, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/worker/worker.py", line 125, in warm_up_model
    self.model_runner.capture_model(self.gpu_cache)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 534, in capture_model
    graph_runner.capture(
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/vllm/worker/model_runner.py", line 581, in capture
    hidden_states = self.model(
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/cuda/graphs.py", line 197, in __exit__
    self.cuda_graph.capture_end()
  File "/path_to_vllm_env/vllm/lib/python3.9/site-packages/torch/cuda/graphs.py", line 88, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The issue above occurs specifically at this line of the model class (in vLLM-0.2.7, but not in vLLM-0.2.1-post1). The code is in the [Jais class @ HuggingFace](https://huggingface.co/core42/jais-30b-chat-v1/blob/main/modeling_jais.py#L869):

```python
hidden_states = torch.mul(hidden_states, torch.tensor(
    float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device
))
```
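Two observations that may help. First, the traceback itself ends in the debug `print(hidden_states)` at jais.py line 330, and printing a CUDA tensor cannot run during CUDA graph capture either, because formatting it calls `torch.masked_select`, which has to synchronize with the host. Second, my understanding is that `torch.tensor(..., device=hidden_states.device)` builds the scalar from Python data on the CPU and copies it to the GPU, and such a copy is not permitted while a stream is capturing; as far as I know, CUDA graph capture was only introduced in vLLM after 0.2.1, which would explain why only 0.2.7 fails. A sketch of a capture-friendly alternative, assuming the scale is a constant: multiply by a plain Python float so no new device tensor is created. `scale_embeddings` is a hypothetical helper name.

```python
def scale_embeddings(hidden_states, embeddings_scale):
    """Scale embeddings without building a new tensor from Python data.

    `hidden_states * float(scale)` passes the scale to the multiply kernel as a
    scalar, so nothing is copied host-to-device during graph capture. It works
    on torch tensors and, for illustration, on plain Python numbers.
    """
    return hidden_states * float(embeddings_scale)


# Hypothetical use inside JAISModel.forward, replacing the torch.mul(...) line:
#     hidden_states = scale_embeddings(hidden_states, self.embeddings_scale)
```

Alternatively, passing `enforce_eager=True` to `LLM(...)` skips CUDA graph capture altogether, which is a quick way to confirm the diagnosis, at the cost of slower decoding.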

@WoosukKwon, could you please take a look?

sam-iink, Jan 23 '24

Thank you @samujjwal-sam for the update! Could you please tell us what changes are required to make JAIS work with vLLM-0.2.1-post1?

7ossam81, Jan 23 '24

Hey @7ossam81, I have the code ready and will make it public by EOD today; you will be able to run it directly. Please bear with me a little longer.

sam-iink, Jan 23 '24

Thank you @samujjwal-sam! Appreciate it!

7ossam81, Jan 23 '24

Thank you @samujjwal-sam for the update! When can we expect JAIS support in vLLM?

ML-Mr-J, Jan 24 '24

Hi @sam-iink, can you share your work with us so everyone can test Jais with vLLM? Thank you for your collaboration!

JAADARI, Jan 24 '24

Hello everyone, apologies for the delay. Here's a fork of the original repo with Jais support: https://github.com/SamujjwalSam/vllm-jais/tree/jais

SamujjwalSam, Jan 25 '24

Thanks @SamujjwalSam!

7ossam81, Jan 25 '24

@SamujjwalSam, could you share an ETA for merging the code into main? Is that planned?

novaturient95, Feb 14 '24

Closing this, as https://github.com/vllm-project/vllm/pull/3183 added support. It will be released in v0.3.4.

esmeetu, Mar 25 '24