TensorRT-LLM
[bug] --use_paged_context_fmha enable produces NaN logits with large --max_num_tokens
My model config is:
{
    "mlp_bias": false,
    "attn_bias": false,
    "rotary_base": 300000,
    "rotary_scaling": null,
    "residual_mlp": false,
    "disable_weight_only_quant_plugin": false,
    "moe": {
        "num_experts": 0,
        "top_k": 0,
        "normalization_mode": null,
        "sparse_mixer_epsilon": 0.01,
        "tp_mode": 0
    },
    "remove_duplicated_kv_heads": false,
    "architecture": "LlamaForCausalLM",
    "dtype": "float16",
    "vocab_size": 42016,
    "hidden_size": 6656,
    "num_hidden_layers": 60,
    "num_attention_heads": 64,
    "hidden_act": "silu",
    "logits_dtype": "float32",
    "norm_epsilon": 1e-06,
    "position_embedding_type": "rope_gpt_neox",
    "max_position_embeddings": 32768,
    "num_key_value_heads": 16,
    "intermediate_size": 17920,
    "mapping": {
        "world_size": 1,
        "gpus_per_node": 8,
        "cp_size": 1,
        "tp_size": 1,
        "pp_size": 1,
        "moe_tp_size": 1,
        "moe_ep_size": 1
    },
    "quantization": {
        "quant_algo": null,
        "kv_cache_quant_algo": null,
        "group_size": 128,
        "smoothquant_val": 0.5,
        "clamp_val": null,
        "has_zero_point": false,
        "pre_quant_scale": false,
        "exclude_modules": null
    },
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "head_size": 128,
    "qk_layernorm": false
}
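For reference, the attention shapes implied by this config can be checked with a short Python sketch (the helper `attention_shapes` is hypothetical). It only does arithmetic on fields copied from the config above. One thing it surfaces is that `head_size` (128) is set explicitly and differs from `hidden_size / num_attention_heads` (104). That may be worth noting during triage, but nothing here shows it is related to the NaNs.

```python
import json

# Fields copied from the checkpoint config above (only the ones used here).
CONFIG_JSON = """
{
    "hidden_size": 6656,
    "num_attention_heads": 64,
    "num_key_value_heads": 16,
    "head_size": 128
}
"""

def attention_shapes(cfg: dict) -> dict:
    """Derive attention projection widths implied by the config (hypothetical helper)."""
    heads = cfg["num_attention_heads"]
    kv_heads = cfg["num_key_value_heads"]
    head_size = cfg["head_size"]
    return {
        "q_dim": heads * head_size,          # total width of the Q projection
        "kv_dim": kv_heads * head_size,      # width of each of the K and V projections
        "gqa_group": heads // kv_heads,      # query heads sharing one KV head
        # head size you would get if it were derived as hidden_size / heads
        "implied_head_size": cfg["hidden_size"] / heads,
    }

shapes = attention_shapes(json.loads(CONFIG_JSON))
print(shapes)
# {'q_dim': 8192, 'kv_dim': 2048, 'gqa_group': 4, 'implied_head_size': 104.0}
```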
When I build the engine with this command:
trtllm-build \
--checkpoint_dir 1-gpu-tmp \
--output_dir 1-gpu \
--max_input_len 31744 \
--max_seq_len 32768 \
--max_num_tokens 4096 \
--max_batch_size 256 \
--gemm_plugin float16 \
--use_paged_context_fmha enable
everything works fine.
But when I build the engine with this command:
trtllm-build \
--checkpoint_dir 1-gpu-tmp \
--output_dir 1-gpu \
--max_input_len 31744 \
--max_seq_len 32768 \
--max_num_tokens 32768 \
--max_batch_size 256 \
--gemm_plugin float16 \
--use_paged_context_fmha enable
NaN values appear in the logits.
The only difference between the two commands is the --max_num_tokens value (4096 vs. 32768).
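A rough way to see what the two values change is a back-of-the-envelope sketch. It assumes that at most `max_num_tokens` tokens are scheduled per engine step. Under that assumption, a maximum-length prompt of 31744 tokens cannot be processed in one step with 4096, but fits in a single step with 32768. The helper `context_steps` is hypothetical, not a TensorRT-LLM function, and the real scheduler may size or align chunks differently (for example, to KV cache block boundaries).

```python
import math

MAX_INPUT_LEN = 31744  # --max_input_len used in both builds

def context_steps(prompt_len: int, max_num_tokens: int) -> int:
    """Lower bound on engine steps needed to ingest a prompt when at most
    max_num_tokens tokens fit in one step (simplified, assumed model)."""
    return math.ceil(prompt_len / max_num_tokens)

for max_num_tokens in (4096, 32768):
    print(max_num_tokens, context_steps(MAX_INPUT_LEN, max_num_tokens))
# 4096 8
# 32768 1
```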
The behavior reproduces with every combination of the runtime options enableBlockReuse (true/false) and enableChunkedContext (true/false), using the Executor API at commit 31ac30e928a.
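When comparing these runtime combinations, it can help to record where the NaNs first appear, for example whether they only show up past a certain token position. The following minimal NumPy sketch assumes the logits for one request have already been copied out of the Executor result into a float array of shape [num_tokens, vocab_size]. `nan_report` is a hypothetical helper, not part of TensorRT-LLM.

```python
import numpy as np

def nan_report(logits: np.ndarray) -> dict:
    """Summarize NaNs in a logits array shaped [num_tokens, vocab_size]."""
    nan_mask = np.isnan(logits)
    bad_rows = np.flatnonzero(nan_mask.any(axis=1))  # token positions with any NaN
    return {
        "nan_count": int(nan_mask.sum()),
        "rows_with_nan": int(bad_rows.size),
        "first_nan_row": int(bad_rows[0]) if bad_rows.size else None,
    }

# Synthetic example: 3 positions, vocab of 5, one NaN at position 1.
demo = np.zeros((3, 5), dtype=np.float32)
demo[1, 2] = np.nan
print(nan_report(demo))
# {'nan_count': 1, 'rows_with_nan': 1, 'first_nan_row': 1}
```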
If I set --use_paged_context_fmha disable, everything works fine with --max_num_tokens 32768.
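The three builds described above can be scripted so the comparison is easy to rerun. This Python sketch reuses only the trtllm-build flags from the commands in this report. The per-configuration output directory names and the `build_cmd` helper are assumptions. By default the script only prints the commands; pass `--run` to actually invoke trtllm-build.

```python
import subprocess
import sys
from typing import List

# Flags shared by every build, copied from the commands above.
COMMON_ARGS = [
    "--checkpoint_dir", "1-gpu-tmp",
    "--max_input_len", "31744",
    "--max_seq_len", "32768",
    "--max_batch_size", "256",
    "--gemm_plugin", "float16",
]

# (max_num_tokens, use_paged_context_fmha) pairs from the report:
# the first and last reportedly work, the middle one yields NaN logits.
MATRIX = [(4096, "enable"), (32768, "enable"), (32768, "disable")]

def build_cmd(max_num_tokens: int, paged_fmha: str) -> List[str]:
    """Return the trtllm-build argv for one configuration (hypothetical helper)."""
    out_dir = f"1-gpu-mnt{max_num_tokens}-fmha-{paged_fmha}"  # assumed naming scheme
    return ["trtllm-build", *COMMON_ARGS,
            "--output_dir", out_dir,
            "--max_num_tokens", str(max_num_tokens),
            "--use_paged_context_fmha", paged_fmha]

if __name__ == "__main__":
    run = "--run" in sys.argv[1:]  # dry run unless --run is given
    for tokens, fmha in MATRIX:
        cmd = build_cmd(tokens, fmha)
        print(" ".join(cmd))
        if run:
            subprocess.run(cmd, check=True)  # requires trtllm-build on PATH
```

Keeping each engine in its own directory lets the same runtime test be pointed at all three engines without rebuilding.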