TensorRT-LLM

How to obtain the classification label of a BERT model?

zhangjiawei5911 opened this issue 1 year ago · 1 comment

System Info

NVIDIA V100 nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3

Who can help?

No response

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

python3 build.py --dtype=float16 --log_level=verbose
python3 run.py

Expected behavior

I am using TensorRT-LLM to accelerate inference of a BERT model for multi-class classification. I did this by modifying the build.py and run.py files in the /example/bert/ directory. However, in run.py the final output is a tensor of shape [batch_size, max_input_length, label_num], whereas I expect a tensor of shape [batch_size, label_num]. What should I do?
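A minimal sketch of the usual reduction, assuming the modified build.py applies the classification head to every token, so that position 0 holds the logits for the [CLS] token that BERT tokenizers prepend. Keeping only that position turns [batch_size, max_input_length, label_num] into [batch_size, label_num], and an argmax over the last dimension gives the label id. Plain Python lists stand in for the tensor so the example runs anywhere; with the torch tensor from run.py the equivalent would be `logits[:, 0, :].argmax(dim=-1)`, where `logits` is a hypothetical name for the 3D output. Note that Hugging Face's `BertForSequenceClassification` passes the [CLS] hidden state through a pooler (dense + tanh) before the classifier, so this slice matches its logits only if the engine's per-token head includes that pooler.

```python
from typing import List

Matrix = List[List[float]]


def cls_logits(token_logits: List[Matrix]) -> Matrix:
    """Reduce [batch_size, max_input_length, label_num] to [batch_size, label_num]
    by keeping only position 0 (assumed to be the [CLS] token) of each sequence."""
    return [sequence[0] for sequence in token_logits]


def argmax(row: List[float]) -> int:
    """Index of the largest value in a row, i.e. the predicted label id."""
    return max(range(len(row)), key=lambda i: row[i])


# Toy output with made-up numbers: batch_size=2, max_input_length=3, label_num=4.
token_logits = [
    [[0.1, 2.0, -1.0, 0.3], [0.0, 0.0, 0.0, 0.0], [9.0, 0.0, 0.0, 0.0]],
    [[1.5, 0.2, 0.1, 3.1], [0.0, 5.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
]
per_sequence = cls_logits(token_logits)        # shape [2, 4]
labels = [argmax(row) for row in per_sequence]  # [1, 3]
print(labels)
```

The large values at positions 1 and 2 are ignored on purpose: only the [CLS] position is used for sequence-level classification.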

Actual behavior

python3 build.py --dtype=float16 --log_level=verbose
python3 run.py

Additional notes

output_info = session.infer_shapes([
    TensorInfo('input_ids', trt.DataType.INT32,
               (input_ids_tmp.shape[0], input_ids_tmp.shape[1])),
    TensorInfo('input_lengths', trt.DataType.INT32,
               (input_ids_tmp.shape[0], )),
    TensorInfo('token_type_ids', trt.DataType.INT32,
               token_type_ids_tmp.shape),
])

outputs = {
    t.name: torch.empty(tuple(t.shape),
                        dtype=trt_dtype_to_torch(t.dtype),
                        device='cuda')
    for t in output_info
}
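Once the engine has run and filled `outputs`, getting the classification label from a [batch_size, num_labels] logits tensor comes down to an argmax over the last dimension, with a softmax if a confidence score is also wanted. The sketch below uses plain Python lists so it runs without a GPU; with the torch tensors in `outputs`, the equivalent would be `torch.softmax(logits, dim=-1)` followed by `.argmax(dim=-1)`. The key of the logits entry in `outputs` depends on how the engine was built, so none is assumed here, and the mapping from label id to label name (for example a Hugging Face config's `id2label`) is model-specific.

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the row max before exponentiating."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: List[List[float]]) -> List[Tuple[int, float]]:
    """Map [batch_size, num_labels] logits to (label_id, probability) per row."""
    results = []
    for row in logits:
        probs = softmax(row)
        label = max(range(len(probs)), key=lambda i: probs[i])
        results.append((label, probs[label]))
    return results


# Hypothetical logits for batch_size=2, num_labels=3.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
for label, prob in predict(logits):
    print(label, round(prob, 3))
```

Softmax does not change which index is largest, so it is only needed when probabilities are wanted; the argmax alone already gives the label.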

zhangjiawei5911 · Jan 22 '24

@symphonylyh any updates on this?

poweiw · May 16 '25

@zhangjiawei5911, apologies for the very delayed response. Is this ticket still relevant?

By the way, if you were seeing 3D output tensors, you were likely using BertModel, which returns hidden_states with shape [batch_size, max_seq_len, hidden_dim]. Instead, you might want to use a model like BertForSequenceClassification, which outputs 'logits' with shape [batch_size, num_labels].
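For reference, a sequence-classification BERT such as Hugging Face's `BertForSequenceClassification` collapses the sequence dimension inside the model: its pooler takes the hidden state at position 0, applies a dense layer and tanh, and a linear classifier maps the result to `num_labels`. The plain-Python sketch below only illustrates that shape flow with made-up toy weights; it is not TensorRT-LLM code, and real weights would come from the fine-tuned checkpoint.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: weight[out_features][in_features]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """y = W x + b, with W stored as [out_features][in_features]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def classification_head(hidden_states: List[List[Vector]],
                        pool_w: Matrix, pool_b: Vector,
                        cls_w: Matrix, cls_b: Vector) -> Matrix:
    """[batch_size, seq_len, hidden_dim] -> [batch_size, num_labels].

    Takes the [CLS] vector (position 0), applies dense + tanh (the pooler),
    then a linear classifier. Dropout is omitted because it is a no-op at
    inference time.
    """
    logits = []
    for sequence in hidden_states:
        cls_vec = sequence[0]
        pooled = [math.tanh(v) for v in linear(cls_vec, pool_w, pool_b)]
        logits.append(linear(pooled, cls_w, cls_b))
    return logits


# Toy sizes: batch_size=1, seq_len=2, hidden_dim=2, num_labels=3 (made-up weights).
hidden = [[[1.0, -1.0], [0.5, 0.5]]]
pool_w = [[1.0, 0.0], [0.0, 1.0]]  # identity, so pooled = tanh(cls_vec)
pool_b = [0.0, 0.0]
cls_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cls_b = [0.0, 0.0, 0.0]
out = classification_head(hidden, pool_w, pool_b, cls_w, cls_b)
print(len(out), len(out[0]))
```

Because the pooler reads only position 0, the output has no sequence dimension, which is why this head yields [batch_size, num_labels] directly.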

karljang · Oct 21 '25

Issue has not received an update in over 14 days. Adding stale label.

github-actions[bot] · Nov 05 '25

Closing issue as stale, please feel free to open new one if the problem persists.

karljang · Nov 14 '25