prompt_vocab_size is ignored in executor API
System Info
It seems that the Executor API ignores the prompt_vocab_size argument and passes max_prompt_embedding_table_size to the TRT engine instead.
I observe this behaviour with both the 0.10.0 Python API (ModelRunnerCpp, to be precise) and 0.9.0/0.10.0 Triton requests, but not with the 0.9.0 Python API, so I assume the issue is in the Executor API.
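In other words (a hypothetical sketch of the two behaviours for illustration, not actual TensorRT-LLM source):

```python
import torch

def expected_prompt_vocab_size(embedding_table: torch.Tensor) -> int:
    # What the engine should receive: the number of virtual tokens
    # actually present in the request's prompt table.
    return embedding_table.shape[0]  # 7 in the repro below

def observed_prompt_vocab_size(max_prompt_embedding_table_size: int) -> int:
    # What the engine appears to receive via the Executor path:
    # the engine-build upper bound instead of the per-request size.
    return max_prompt_embedding_table_size  # 20 in the repro below
```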
Who can help?
No response
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [x] My own task or dataset (give details below)
Reproduction
- Define a fake model that exposes the provided prompt_vocab_size via its logits, together with a config for this model.

fake_model.py:
```python
from typing import Optional

import numpy as np

from tensorrt_llm.functional import (
    Tensor,
    cast,
    constant,
    unsqueeze,
)
from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel


class FakeTransformer(object):
    def __init__(self):
        self.vocab_embedding = None


class FakeModel(PretrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.transformer = FakeTransformer()

    def forward(
        self,
        input_ids: Tensor,
        position_ids=None,
        use_cache=False,
        last_token_ids=None,
        attention_mask=None,
        kv_cache_params=None,
        attention_params=None,
        hidden_states=None,
        prompt_embedding_table: Optional[Tensor] = None,
        prompt_tasks: Optional[Tensor] = None,
        prompt_vocab_size: Optional[Tensor] = None,
        lora_params=None,
        medusa_position_offsets=None,
        medusa_packed_mask=None,
    ):
        assert prompt_embedding_table is not None
        zero = constant(np.zeros((1, 1), dtype=self.config.dtype))
        # [1, vocab_size]
        zeros = constant(np.zeros((1, self.config.vocab_size), dtype=self.config.dtype))
        # repeat_interleave only supports int repeats, so we use addition + broadcasting instead
        # [len(input_ids), vocab_size]
        zeros_repeated = zeros + cast(unsqueeze(input_ids, 1), self.config.dtype) * zero
        # fake_logits used to expose prompt_vocab_size value
        fake_logits = zeros_repeated + cast(unsqueeze(unsqueeze(prompt_vocab_size, 0), 0), self.config.dtype)
        fake_logits.mark_output("logits")
        return fake_logits
```
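For reference, here is a plain NumPy analogue of the broadcasting trick in `forward` (illustration only, not TensorRT-LLM code): every context logit ends up equal to whatever scalar the engine receives as prompt_vocab_size.

```python
import numpy as np

input_ids = np.array([3, 4, 5], dtype=np.float16)   # 3 context tokens
zero = np.zeros((1, 1), dtype=np.float16)
zeros = np.zeros((1, 16), dtype=np.float16)          # [1, vocab_size]
zeros_repeated = zeros + input_ids[:, None] * zero   # [3, 16], still all zeros
prompt_vocab_size = 7.0                              # value received by the engine
fake_logits = zeros_repeated + prompt_vocab_size     # every entry == prompt_vocab_size
print(fake_logits.shape, fake_logits[0, 0])          # (3, 16) 7.0
```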
config.json:

```json
{
    "architecture": "FakeModel",
    "dtype": "float16",
    "logits_dtype": "float32",
    "vocab_size": 16,
    "max_position_embeddings": 128,
    "hidden_size": 16,
    "num_hidden_layers": 1,
    "num_attention_heads": 1,
    "num_key_value_heads": 1,
    "head_size": 16,
    "qk_layernorm": false,
    "hidden_act": "silu",
    "intermediate_size": 16,
    "norm_epsilon": 1e-06,
    "position_embedding_type": "rope_gpt_neox",
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "mapping": {
        "world_size": 1,
        "tp_size": 1,
        "pp_size": 1,
        "gpus_per_node": 1
    },
    "quantization": {
        "quant_algo": null,
        "kv_cache_quant_algo": null,
        "group_size": 128,
        "smoothquant_val": null,
        "has_zero_point": false,
        "pre_quant_scale": false,
        "exclude_modules": [
            "lm_head"
        ]
    },
    "kv_dtype": "float16"
}
```
- Build the engine:

```bash
trtllm-build \
    --model_config=config.json \
    --model_cls_file=fake_model.py \
    --model_cls_name=FakeModel \
    --gather_context_logits \
    --gather_generation_logits \
    --max_prompt_embedding_table_size=20
```
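max_prompt_embedding_table_size=20 is deliberately larger than the prompt table length (7) used below, so the two values are distinguishable in the logits. As an optional sanity check, the build flag can be read back from the generated engine config (a sketch; the exact key layout may vary between versions):

```python
import json

# Assumes the engine was written to ./engine_outputs as in the run step below.
with open("./engine_outputs/config.json") as f:
    engine_cfg = json.load(f)

# Expect 20 here if the flag was recorded during the build.
print(engine_cfg["build_config"]["max_prompt_embedding_table_size"])
```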
- Run the engine via the Python API:
```python
import os
from tempfile import TemporaryDirectory

import numpy as np
import torch

from tensorrt_llm.runtime import ModelRunnerCpp

with TemporaryDirectory() as tmpdir:
    ptable_path = os.path.join(tmpdir, "ptable.npy")
    np.save(ptable_path, np.ones((1, 7, 16)).astype(np.float16))
    runner = ModelRunnerCpp.from_dir(engine_dir="./engine_outputs")
    res = runner.generate(
        batch_input_ids=[torch.IntTensor([3, 4, 5])],
        max_new_tokens=1,
        end_id=15,
        pad_id=14,
        return_dict=True,
        output_sequence_lengths=True,
        prompt_table=ptable_path,
    )
    print(res['context_logits'])
```
Expected behavior
All context logits are 7.0 (the length of the provided prompt embedding table, i.e. the effective prompt_vocab_size).
actual behavior
All context logits are 20.0 (max_prompt_embedding_table_size).
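A compact way to express the check (a sketch; with gather_context_logits enabled, res['context_logits'] holds one logits tensor per batch entry in this setup):

```python
logits = res['context_logits'][0]  # [num_input_tokens, vocab_size]
value = float(logits[0, 0])
# 7.0  -> engine saw the embedding table length (expected)
# 20.0 -> engine saw max_prompt_embedding_table_size (observed bug)
print("prompt_vocab_size seen by the engine:", value)
```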
additional notes
Debugging confirms that the request passed to the executor via the enqueue_requests call contains the correct .prompt_tuning_config (an embedding_table of the expected shape).
If we replace prompt_vocab_size with shape(prompt_embedding_table, 1) in fake_model.py, the result is the same.
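For completeness, this is the variant mentioned above, mirroring the structure of the original forward body (a sketch; it also yields 20.0):

```python
from tensorrt_llm.functional import shape

# Derive the size from the table's shape inside the network instead of
# reading the prompt_vocab_size input; the observed logits are still 20.0.
fake_logits = zeros_repeated + cast(
    unsqueeze(unsqueeze(shape(prompt_embedding_table, 1), 0), 0),
    self.config.dtype,
)
```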