
[bug] --use_paged_context_fmha enable broken

Status: Open • akhoroshev opened this issue 5 months ago • 1 comment

My model config is:

```json
{
    "mlp_bias": false,
    "attn_bias": false,
    "rotary_base": 300000,
    "rotary_scaling": null,
    "residual_mlp": false,
    "disable_weight_only_quant_plugin": false,
    "moe": {
        "num_experts": 0,
        "top_k": 0,
        "normalization_mode": null,
        "sparse_mixer_epsilon": 0.01,
        "tp_mode": 0
    },
    "remove_duplicated_kv_heads": false,
    "architecture": "LlamaForCausalLM",
    "dtype": "float16",
    "vocab_size": 42016,
    "hidden_size": 6656,
    "num_hidden_layers": 60,
    "num_attention_heads": 64,
    "hidden_act": "silu",
    "logits_dtype": "float32",
    "norm_epsilon": 1e-06,
    "position_embedding_type": "rope_gpt_neox",
    "max_position_embeddings": 32768,
    "num_key_value_heads": 16,
    "intermediate_size": 17920,
    "mapping": {
        "world_size": 1,
        "gpus_per_node": 8,
        "cp_size": 1,
        "tp_size": 1,
        "pp_size": 1,
        "moe_tp_size": 1,
        "moe_ep_size": 1
    },
    "quantization": {
        "quant_algo": null,
        "kv_cache_quant_algo": null,
        "group_size": 128,
        "smoothquant_val": 0.5,
        "clamp_val": null,
        "has_zero_point": false,
        "pre_quant_scale": false,
        "exclude_modules": null
    },
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "head_size": 128,
    "qk_layernorm": false
}
```
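The attention geometry implied by this config can be worked out directly from its fields. The sketch below copies the relevant values from the config; the 2-byte element size (fp16 KV cache, since `kv_cache_quant_algo` is null) and the factor of 2 for storing both K and V are standard assumptions rather than something stated in this issue. It shows that the query projection width `num_attention_heads * head_size` (8192) differs from `hidden_size` (6656), and estimates the KV-cache footprint per token.

```python
# Fields copied from the checkpoint config above.
config = {
    "hidden_size": 6656,
    "num_hidden_layers": 60,
    "num_attention_heads": 64,
    "num_key_value_heads": 16,
    "head_size": 128,
    "dtype": "float16",
}

# Assumption: KV cache stored in the model dtype (no KV-cache quantization).
BYTES_PER_ELEMENT = {"float16": 2}


def attention_geometry(cfg: dict) -> dict:
    """Derive attention dimensions and an uncompressed KV-cache size estimate."""
    heads = cfg["num_attention_heads"]
    kv_heads = cfg["num_key_value_heads"]
    head_size = cfg["head_size"]
    q_width = heads * head_size  # output width of the query projection
    kv_bytes_per_token = (
        2  # one K entry and one V entry
        * cfg["num_hidden_layers"]
        * kv_heads
        * head_size
        * BYTES_PER_ELEMENT[cfg["dtype"]]
    )
    return {
        "gqa_group_size": heads // kv_heads,  # query heads sharing one KV head
        "q_width": q_width,
        "q_width_matches_hidden": q_width == cfg["hidden_size"],
        "kv_bytes_per_token": kv_bytes_per_token,
    }


geom = attention_geometry(config)
print(geom)
print(f"KV cache for 32768 tokens: {geom['kv_bytes_per_token'] * 32768 / 2**30:.1f} GiB")
```

With these numbers, every KV head is shared by 4 query heads, and `head_size` is set explicitly to 128 rather than derived as `hidden_size / num_attention_heads` (which would give 104). A full 32768-token sequence needs about 15.0 GiB of KV cache under the fp16 assumption.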

When I build the engine with this command:

```shell
trtllm-build \
--checkpoint_dir 1-gpu-tmp \
--output_dir 1-gpu \
--max_input_len 31744 \
--max_seq_len 32768 \
--max_num_tokens 4096 \
--max_batch_size 256 \
--gemm_plugin float16 \
--use_paged_context_fmha enable
```

everything works fine.

But when I build the engine with this command:

```shell
trtllm-build \
--checkpoint_dir 1-gpu-tmp \
--output_dir 1-gpu \
--max_input_len 31744 \
--max_seq_len 32768 \
--max_num_tokens 32768 \
--max_batch_size 256 \
--gemm_plugin float16 \
--use_paged_context_fmha enable
```

NaNs appear in the logits.
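To pin down which token positions are affected, one option is to dump the returned logits and scan them. A minimal NumPy sketch, assuming the logits have already been copied out of the runtime as a float32 array (matching `logits_dtype` in the config) of shape `(num_tokens, vocab_size)`. How they are obtained from the Executor is not shown, and `find_nan_rows` is a hypothetical helper name.

```python
import numpy as np


def find_nan_rows(logits: np.ndarray) -> list[int]:
    """Return the indices of token positions whose logits contain any NaN."""
    if logits.ndim != 2:
        raise ValueError("expected logits of shape (num_tokens, vocab_size)")
    # isnan -> boolean mask; any(axis=1) collapses the vocab dimension per token.
    return np.flatnonzero(np.isnan(logits).any(axis=1)).tolist()


# Usage: a toy 3-token slice with vocab_size 42016 (from the config); row 1 is poisoned.
logits = np.zeros((3, 42016), dtype=np.float32)
logits[1, 7] = np.nan
print(find_nan_rows(logits))  # [1]
```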

The only difference between the two commands is the `--max_num_tokens` value (4096 vs. 32768).
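`--max_num_tokens` caps how many input tokens the engine processes in a single forward pass. Assuming that, with chunked context enabled, a prompt longer than this limit is split into chunks of at most `max_num_tokens` tokens (the chunk size the runtime actually picks may be smaller), the two builds handle a maximum-length 31744-token prompt quite differently: the first needs several context passes, while the second can take the whole prompt in one. The arithmetic, as a rough sketch for a single request:

```python
import math

MAX_INPUT_LEN = 31744  # --max_input_len from both build commands


def context_passes(prompt_len: int, max_num_tokens: int) -> int:
    """Context-phase passes needed if the prompt is split into chunks of at most max_num_tokens.

    Simplifying assumption: a single request in the batch, chunks filled to max_num_tokens.
    """
    return math.ceil(prompt_len / max_num_tokens)


for max_num_tokens in (4096, 32768):
    print(max_num_tokens, context_passes(MAX_INPUT_LEN, max_num_tokens))
# 4096 8
# 32768 1
```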

The behavior reproduces with every combination of the Executor API runtime options `enableBlockReuse` (true/false) and `enableChunkedContext` (true/false), at commit 31ac30e928a.
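The runtime combinations can be enumerated mechanically so that each one is run against both engines. A sketch using `itertools.product`; `run_case` is a hypothetical callback standing in for whatever loads the engine, configures the Executor with the two flags, and reports whether any logits were NaN. It is not a TensorRT-LLM API.

```python
from itertools import product
from typing import Callable


def run_matrix(run_case: Callable[[int, bool, bool], bool]) -> dict:
    """Run every (max_num_tokens, enableBlockReuse, enableChunkedContext) combination.

    run_case is a hypothetical callback returning True if NaNs appeared in the logits.
    """
    results = {}
    for max_num_tokens, block_reuse, chunked_context in product(
        (4096, 32768), (True, False), (True, False)
    ):
        results[(max_num_tokens, block_reuse, chunked_context)] = run_case(
            max_num_tokens, block_reuse, chunked_context
        )
    return results


# Usage with a stub that encodes the behavior reported in this issue.
reported = run_matrix(lambda max_num_tokens, reuse, chunked: max_num_tokens == 32768)
for case, has_nan in sorted(reported.items()):
    print(case, "NaN" if has_nan else "ok")
```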

If I set `--use_paged_context_fmha disable`, everything works fine with `--max_num_tokens 32768`.

akhoroshev, Sep 25 '24 06:09