
Retrofit sentence bert into BertForSeqClassification

caronzh03 opened this issue 8 months ago · 3 comments

The primary goal of this PR is to validate my idea of retrofitting a BERT-based SentenceTransformer model into a BertForSequenceClassification model and using trtllm's LLM API for inference.

The SentenceTransformer model has 3 modules:

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 192, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
)
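
With pooling_mode_mean_tokens=True, module (1) computes a masked mean over the token embeddings. Roughly, the semantics are (a sketch of the pooling behavior, not the library code):

import torch

def mean_pool(token_embeddings, attention_mask):
    # Average token embeddings, ignoring padded positions.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts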

To map the weights above to a BertForSequenceClassification model, I'm planning to:

  1. Copy the Transformer layers' weights.
  2. Change the Pooler layer implementation in modeling_bert.py to compute the average of token embeddings (this PR contains the changes needed).
  3. Copy the Dense layer's weights (it is just a torch.nn.Linear) to the classifier layer in BertForSequenceClassification, as sketched below.
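
Concretely, the mapping I have in mind looks roughly like this (a sketch; I'm assuming the standard module attributes, and my_model_path is a placeholder):

import torch
from sentence_transformers import SentenceTransformer
from transformers import BertConfig, BertForSequenceClassification

st_model = SentenceTransformer(my_model_path)
config = BertConfig.from_pretrained(my_model_path, num_labels=192)
cls_model = BertForSequenceClassification(config)

# Step 1: copy the encoder weights; module (0) wraps a HF BertModel.
cls_model.bert.load_state_dict(st_model[0].auto_model.state_dict(), strict=False)

# Step 2 is the mean-pooling change to modeling_bert.py in this PR.

# Step 3: map the Dense module (a 768 -> 192 Linear with bias=False) onto the
# classifier head; zero the classifier bias that Dense does not have.
with torch.no_grad():
    cls_model.classifier.weight.copy_(st_model[2].linear.weight)
    cls_model.classifier.bias.zero_()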

I have not been able to verify the above method due to some Docker setup issues. @qixiang-99, if you could comment on whether the above approach is reasonable, that would greatly help.

Note - I wasn't planning on merging this PR, because this is a very specific experiment for my own use case. However, if people see value in generalizing this approach, we can certainly polish it and create a new PR against the main branch.

caronzh03 avatar Apr 28 '25 22:04 caronzh03

Requesting @qixiang-99's review.

btw, @caronzh03 can you re-target the PR to main instead of the 0.19 release branch? The rules don't allow merges into a release branch.

symphonylyh avatar Apr 28 '25 23:04 symphonylyh

Hi @caronzh03, your plan seems reasonable. One thing to watch is the naming convention: you might find the weight-name conversion function helpful, since it spells out the mapping for each weight individually. Let me know if I can assist further.
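
For illustration, the conversion is essentially a renaming of state-dict keys between the two layouts, along these lines (example keys only, not the exact mapping):

def rename_key(key: str) -> str:
    # A SentenceTransformer checkpoint prefixes encoder weights with
    # "0.auto_model.", while BertForSequenceClassification expects "bert.".
    return key.replace("0.auto_model.", "bert.")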

qixiang-99 avatar Apr 29 '25 16:04 qixiang-99

hi @qixiang-99, I finally got the Docker setup working and was able to test my changes. I have a BertForSequenceClassification model with 192 classes:

config = BertConfig.from_pretrained(
    my_model_path,
    num_labels=192, 
    ...
)
my_model = BertForSequenceClassification(config)
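
For context, my modified quickstart.py invokes the model through the usual LLM API pattern, roughly (a sketch; the path and sampling parameters are simplified):

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model=my_model_path)  # placeholder path; running the PyTorch backend
sampling_params = SamplingParams(
    max_tokens=1,                 # classification: no tokens need generating
    return_context_logits=True,   # surface the classifier logits
)
outputs = llm.generate(["example input text"], sampling_params)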

When I tried to load this model with trtllm's PyTorch API and run inference on it using the modified quickstart.py, I got this error:

tensorrt_llm.executor.utils.RequestError: LogitsStorage overflow. This storage can only hold 16 logits (0 already filled) but trying to append 192 more logits
Traceback (most recent call last):
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 1732, in _update_requests
    self.decoder.update_requests(decoder_state)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/decoder.py", line 78, in update_requests
    request.py_result.append_context_logits(logits)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/llm_request.py", line 138, in append_context_logits
    self._context_logits.append(context_logits)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/llm_request.py", line 79, in append
    raise ValueError(
ValueError: LogitsStorage overflow. This storage can only hold 33 logits (0 already filled) but trying to append 192 more logits
[05/16/2025-17:26:23] [TRT-LLM] [E] Encountered an error in decode: LogitsStorage overflow. This storage can only hold 33 logits (0 already filled) but trying to append 192 more logits

I'm wondering why there is a limit on the number of logits, and how I can fix this issue.

caronzh03 avatar May 17 '25 00:05 caronzh03

We do not accept any changes in the release branch. Please target main.

MartinMarciniszyn avatar May 19 '25 07:05 MartinMarciniszyn

closing in favor of another PR against main: https://github.com/NVIDIA/TensorRT-LLM/pull/4462

caronzh03 avatar May 20 '25 00:05 caronzh03