TensorRT-LLM

[model support] please support gemma2

Open · lullabies777 opened this issue · 8 comments

I found that the Gemma scripts do not support Gemma 2. Is there any plan to add Gemma 2 support?

lullabies777, Jul 18 '24

+1

@QiJune @byshiue Can't we convert the Hugging Face Gemma 2 7B model to a TensorRT-LLM checkpoint?

jayakommuru, Jul 23 '24

@QiJune @byshiue @AdamzNV Can you please confirm whether Gemma 2 is currently supported? If not, when will support be available?

Thanks

jayakommuru, Jul 24 '24

It's not supported yet, but we are actively looking into it.

AdamzNV, Jul 24 '24

@AdamzNV Can't we use Gemma's convert_checkpoint code to convert Gemma 2 to the TensorRT-LLM format? Llama's convert_checkpoint works fine for Llama 2 as well.

jayakommuru, Jul 25 '24

@jayakommuru Unfortunately, there's no way around it. Gemma 2 has new checkpoint keys, and the Gemma 1 logic can't handle them.

AdamzNV, Jul 25 '24

When I convert Gemma 2, I get this error:

Don't know how to rename transformer.model.layers.0.pre_feedforward_layernorm.weight

Alireza3242, Jul 31 '24

I built Gemma 2, but my sliding-window handling differs slightly from the Gemma 2 architecture in transformers. You can build Gemma 2 from the Gemma 1 files with the following changes.

1- In convert_checkpoint.py, change architecture="GemmaForCausalLM" to architecture="Gemma2ForCausalLM".

Add these lines to main() after trt_llm_config is created:

trt_llm_config.attn_logit_softcapping = ckpt_config["attn_logit_softcapping"]
trt_llm_config.final_logit_softcapping = ckpt_config["final_logit_softcapping"]
trt_llm_config.sliding_window = ckpt_config["sliding_window"]
trt_llm_config.sliding_window_size = ckpt_config["sliding_window_size"]
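The sliding_window value limits how far back each query position may attend. As a rough illustration only, here is a pure-Python causal sliding-window mask, assuming the common convention that query i may see key j when j <= i and i - j < window (the exact off-by-one convention in TensorRT-LLM or transformers may differ). sliding_window_causal_mask is a hypothetical helper, not a TensorRT-LLM API.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Return mask[i][j] == True when query position i may attend to key j.

    Assumed convention: causal (j <= i) and restricted to the last
    `window` positions (i - j < window).
    """
    if window <= 0:
        raise ValueError("window must be positive")
    return [[(j <= i) and (i - j < window) for j in range(seq_len)]
            for i in range(seq_len)]


if __name__ == "__main__":
    # 'x' = may attend, '.' = masked out
    for row in sliding_window_causal_mask(seq_len=5, window=3):
        print("".join("x" if allowed else "." for allowed in row))
```

Note that Gemma 2 interleaves such local sliding-window layers with global-attention layers, so a full implementation has to pick the mask per layer.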

Add these lines to the get_config function:

config_new["attn_logit_softcapping"] = hf_config["attn_logit_softcapping"]
config_new["final_logit_softcapping"] = hf_config["final_logit_softcapping"]
config_new["sliding_window"] = hf_config["sliding_window"]
config_new["sliding_window_size"] = hf_config["sliding_window_size"]
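The get_config change just copies four Gemma 2-specific fields from the Hugging Face config.json into the converter's config dict. A standalone sketch of that step, assuming config.json has already been parsed into a dict; copy_gemma2_fields and GEMMA2_KEYS are hypothetical names, and the sample values are illustrative rather than read from a real checkpoint.

```python
import json

# Gemma 2-specific keys copied by the get_config edit above.
GEMMA2_KEYS = (
    "attn_logit_softcapping",
    "final_logit_softcapping",
    "sliding_window",
    "sliding_window_size",
)


def copy_gemma2_fields(hf_config: dict, config_new: dict) -> dict:
    """Copy the Gemma 2 fields, failing with a clear message if any is missing."""
    missing = [key for key in GEMMA2_KEYS if key not in hf_config]
    if missing:
        raise KeyError(f"config.json is missing Gemma 2 keys: {missing}")
    for key in GEMMA2_KEYS:
        config_new[key] = hf_config[key]
    return config_new


if __name__ == "__main__":
    # Illustrative config.json excerpt; real values depend on the checkpoint.
    hf_config = json.loads(
        '{"attn_logit_softcapping": 50.0, "final_logit_softcapping": 30.0,'
        ' "sliding_window": 4096, "sliding_window_size": 4096}'
    )
    print(copy_gemma2_fields(hf_config, {}))
```

Checking all keys up front gives a readable error instead of a bare KeyError if a particular config.json lacks one of them.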

Add these lines to the rename_to_trt_llm function:

(r"model.layers.(\d+).pre_feedforward_layernorm.weight",
r"layers.\1.pre_feedforward_layernorm.weight"),
(r"model.layers.(\d+).post_feedforward_layernorm.weight",
r"layers.\1.post_feedforward_layernorm.weight"),
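These two rules are what fix the "Don't know how to rename ...pre_feedforward_layernorm.weight" error reported earlier. A minimal sketch of how such (pattern, replacement) pairs map a Hugging Face parameter name, assuming the names arrive without any extra prefix; rename is a hypothetical stand-in for rename_to_trt_llm that only knows the two new rules.

```python
import re

# (pattern, replacement) pairs mirroring the two rules added above, with the
# dots escaped so they only match a literal ".".
RENAME_RULES = [
    (r"model\.layers\.(\d+)\.pre_feedforward_layernorm\.weight",
     r"layers.\1.pre_feedforward_layernorm.weight"),
    (r"model\.layers\.(\d+)\.post_feedforward_layernorm\.weight",
     r"layers.\1.post_feedforward_layernorm.weight"),
]


def rename(name: str) -> str:
    """Hypothetical stand-in for rename_to_trt_llm covering only the new rules."""
    for pattern, replacement in RENAME_RULES:
        if re.fullmatch(pattern, name):
            return re.sub(pattern, replacement, name)
    raise ValueError(f"Don't know how to rename {name}")


if __name__ == "__main__":
    print(rename("model.layers.0.pre_feedforward_layernorm.weight"))
    print(rename("model.layers.27.post_feedforward_layernorm.weight"))
```

The unescaped "." in the original patterns matches any character; that is harmless for these names, but escaping it makes the rules stricter.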

In the convert_from_checkpoint function, add these two entries:

"pre_feedforward_layernorm",
"post_feedforward_layernorm"

after:

elif any(keyword in name for keyword in (
        "pre_attention_norm.scale",
        "pre_ffw_norm.scale",
        "final_norm.scale",
        "pre_attention_norm/vars/0",
        "pre_ffw_norm/vars/0",
        "rms_normalization/vars/0",
        "input_layernorm",
        "post_attention_layernorm",
        "model.norm.weight",

2- In /usr/local/lib/python3.10/dist-packages/tensorrt_llm/functional.py, add this function:

def soft_capping(input: Tensor, beta: float) -> Tensor:
    # Gemma 2 soft-capping: beta * tanh(input / beta)
    divided = input / beta
    tanh_layer = default_trtnet().add_activation(divided.trt_tensor,
                                                 trt.ActivationType.TANH)
    out1_tensor = _create_tensor(tanh_layer.get_output(0), tanh_layer)
    out2_tensor = out1_tensor * beta
    return out2_tensor
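The new function computes beta * tanh(input / beta): values much smaller than beta pass through almost unchanged, while large ones are squashed so that their magnitude never exceeds beta. A pure-Python reference like this can be handy for spot-checking a few outputs against the TensorRT result; soft_capping_ref is a hypothetical helper, not part of tensorrt_llm.

```python
import math


def soft_capping_ref(values: list[float], cap: float) -> list[float]:
    """Reference soft-capping: cap * tanh(x / cap), applied elementwise."""
    if cap <= 0:
        raise ValueError("cap must be positive")
    return [cap * math.tanh(x / cap) for x in values]


if __name__ == "__main__":
    logits = [0.5, 10.0, 100.0, -1000.0]
    for x, y in zip(logits, soft_capping_ref(logits, cap=30.0)):
        print(f"{x:>8.1f} -> {y:8.3f}")
```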

3- In /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/__init__.py, add this:

from .gemma2.model import Gemma2ForCausalLM

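Also add Gemma2ForCausalLM to __all__ and MODEL_MAP.

4- Copy the gemma package to a new gemma2 package, together with a private copy of modeling_utils.py: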

cp -r /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/gemma /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/gemma2
cp /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/modeling_utils.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/gemma2/my_modeling_utils.py

5- Change the import paths in gemma2/* and my_modeling_utils.py so that they use these copied files.

6- Change GemmaForCausalLM to Gemma2ForCausalLM in gemma2/model.py.

7- In /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/gemma2/model.py, change the __init__ and forward methods of Gemma2DecoderLayer. Add these two statements to __init__:

self.pre_feedforward_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                         eps=config.norm_epsilon,
                                         dtype=config.dtype)
self.post_feedforward_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                          eps=config.norm_epsilon,
                                          dtype=config.dtype)

In the forward function, replace this:

hidden_states = residual + attention_output
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states,
                         lora_layer_params=lora_layer_params)

with:

hidden_states = self.post_layernorm(attention_output)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params)
hidden_states = self.post_feedforward_layernorm(hidden_states)
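After this change, each sub-block is wrapped by a norm on both sides: input_layernorm and post_layernorm around attention, pre_feedforward_layernorm and post_feedforward_layernorm around the MLP. The framework-free sketch below traces that order with plain Python lists. The callables are hypothetical stand-ins for the real modules, and the final residual add is assumed to come from the unchanged lines that follow the replaced block in forward.

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise residual addition."""
    return [x + y for x, y in zip(a, b)]


def gemma2_decoder_layer(x: Vector, input_layernorm: Layer, attention: Layer,
                         post_layernorm: Layer, pre_feedforward_layernorm: Layer,
                         mlp: Layer, post_feedforward_layernorm: Layer) -> Vector:
    """Order of operations after the edit (norms before and after each sub-block)."""
    residual = x
    attention_output = attention(input_layernorm(x))
    # Edited part: normalize the attention output before the residual add.
    hidden = add(residual, post_layernorm(attention_output))
    residual = hidden
    hidden = post_feedforward_layernorm(mlp(pre_feedforward_layernorm(hidden)))
    # Assumed to come from the unchanged remainder of forward().
    return add(residual, hidden)


if __name__ == "__main__":
    calls: List[str] = []

    def traced(name: str) -> Layer:
        """Identity function that records when it is called."""
        def fn(v: Vector) -> Vector:
            calls.append(name)
            return v
        return fn

    out = gemma2_decoder_layer(
        [1.0, 2.0],
        traced("input_layernorm"), traced("attention"), traced("post_layernorm"),
        traced("pre_feedforward_layernorm"), traced("mlp"),
        traced("post_feedforward_layernorm"),
    )
    print(calls)
    print(out)
```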

8- In weight.py, in the load_from_hf_gemma function, after:

            if 'input_layernorm.weight' in k:
                weights['transformer.layers.{}.input_layernorm.weight'.format(
                    idx)] = torch_to_numpy(numpy_to_torch(v) + 1.0)
            elif 'post_attention_layernorm.weight' in k:
                weights['transformer.layers.{}.post_layernorm.weight'.format(
                    idx)] = torch_to_numpy(numpy_to_torch(v) + 1.0)

add:

            elif 'pre_feedforward_layernorm.weight' in k:
                weights['transformer.layers.{}.pre_feedforward_layernorm.weight'.format(
                    idx)] = torch_to_numpy(numpy_to_torch(v) + 1.0)
            elif 'post_feedforward_layernorm.weight' in k:
                weights['transformer.layers.{}.post_feedforward_layernorm.weight'.format(
                    idx)] = torch_to_numpy(numpy_to_torch(v) + 1.0)
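The + 1.0 is needed because Gemma's RMSNorm in transformers scales the normalized activations by (1 + weight), while the TensorRT-LLM RmsNorm used in step 7 is assumed to scale by weight alone; folding the 1.0 into the stored weights at conversion time makes the two equivalent. A pure-Python check of that identity, using made-up weights and hypothetical helper names:

```python
import math


def rms_normalize(x: list[float], eps: float = 1e-6) -> list[float]:
    """Divide by the root mean square of x (no learned scale yet)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def gemma_rmsnorm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Gemma-style RMSNorm: scale by (1 + weight)."""
    return [n * (1.0 + w) for n, w in zip(rms_normalize(x, eps), weight)]


def standard_rmsnorm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Standard RMSNorm: scale by weight directly."""
    return [n * w for n, w in zip(rms_normalize(x, eps), weight)]


if __name__ == "__main__":
    x = [0.5, -1.5, 2.0, 3.0]
    hf_weight = [0.1, -0.2, 0.0, 0.3]        # as stored in the HF checkpoint
    converted = [w + 1.0 for w in hf_weight]  # what the converter writes
    same = all(math.isclose(p, q) for p, q in
               zip(gemma_rmsnorm(x, hf_weight), standard_rmsnorm(x, converted)))
    print(same)
```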

9- In gemma2/my_modeling_utils.py (the copy made in step 4), add:

from ...functional import soft_capping

Also, after

lm_logits = self.lm_head(hidden_states)

add:

lm_logits = soft_capping(lm_logits, self.config.final_logit_softcapping)
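Because tanh is strictly increasing, this final soft-capping keeps the ranking of the logits (apart from ties caused by floating-point saturation at very large values), so greedy decoding picks the same token; it can, however, noticeably change the softmax probabilities used for sampling. For comparison, the Gemma 2 implementation in transformers skips this step when final_logit_softcapping is None, whereas the line above applies it unconditionally. A small pure-Python check; cap_logits, softmax and argmax are hypothetical helpers.

```python
import math
from typing import List, Optional


def cap_logits(logits: List[float], cap: Optional[float]) -> List[float]:
    """Final-logit soft-capping; a cap of None leaves the logits unchanged."""
    if cap is None:
        return list(logits)
    return [cap * math.tanh(x / cap) for x in logits]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    return max(range(len(values)), key=values.__getitem__)


if __name__ == "__main__":
    logits = [12.0, 45.0, 38.0, -5.0]
    capped = cap_logits(logits, cap=30.0)
    print(argmax(logits), argmax(capped))
    print([round(p, 3) for p in softmax(logits)])
    print([round(p, 3) for p in softmax(capped)])
```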

Alireza3242, Aug 03 '24

Gemma 2 2B is revolutionary: more than 100 tokens/s on llama.cpp and 10 tokens/s on CPU. I'd love to see what throughput you can get with TensorRT-LLM.

phly95, Aug 11 '24

Could we also add AWQ quantization support for the Gemma 2 model, using https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/quantization/quantize_by_modelopt.py?

raghavgarg97, Sep 10 '24

As more and more new models enter the market, we have prepared comprehensive instructions for TRT-LLM developers on adapting to new models of interest. We encourage our community developers to expand the range of supported models, fostering an open ecosystem with rapid iterations.

Please try following these instructions and let us know if you encounter any issues during the adaptation process. We greatly appreciate your dedication.

AdamzNV, Oct 31 '24