TensorRT-LLM
TensorRT-LLM copied to clipboard
BERT Model is Inaccurate
System Info
- CPU: i9 9900k
- GPU: RTX 4090
- TensorRT-LLM Version: 0.9.0.dev2024022000
- Cuda Version: Cuda 12.3
- Driver Version: 545.29.06
- OS: Arch Linux, kernel version 6.7.5
Who can help?
@byshiue
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the
examplesfolder (such as GLUE/SQuAD, ...) - [X] My own task or dataset (give details below)
Reproduction
- Create a build script to build the TRT-LLM engine (build.py):
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from collections import OrderedDict
import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
from tensorrt_llm.builder import Builder
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from transformers import AutoModel, BertConfig, BertPreTrainedModel
def parse_arguments():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--world_size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument("--rank", type=int, default=0)
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "float32"]
)
parser.add_argument("--timing_cache", type=str, default="model.cache")
parser.add_argument(
"--profiling_verbosity",
type=str,
default="layer_names_only",
choices=["layer_names_only", "detailed", "none"],
help="The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.",
)
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument("--max_batch_size", type=int, default=256)
parser.add_argument("--max_input_len", type=int, default=512)
parser.add_argument("--gpus_per_node", type=int, default=1)
parser.add_argument("--output_dir", type=str)
parser.add_argument(
"--use_bert_attention_plugin",
nargs="?",
const="float16",
type=str,
default=False,
choices=["float16", "float32"],
)
parser.add_argument(
"--use_gemm_plugin",
nargs="?",
const="float16",
type=str,
default=False,
choices=["float16", "float32"],
)
parser.add_argument("--enable_qk_half_accum", default=False, action="store_true")
parser.add_argument("--enable_context_fmha", default=False, action="store_true")
parser.add_argument(
"--enable_context_fmha_fp32_acc", default=False, action="store_true"
)
parser.add_argument("--model", type=str, help="Model id")
return parser.parse_args()
def get_engine_name(model, dtype, tp_size, rank):
return "{}_{}_tp{}_rank{}.engine".format(
model.replace("/", "--"), dtype, tp_size, rank
)
def extract_layer_idx(name):
ss = name.split(".")
for s in ss:
if s.isdigit():
return s
return None
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return np.ascontiguousarray(np.split(v, tp_size)[idx].copy())
elif len(v.shape) == 2:
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx].copy())
return None
def load_from_hf_model(
tensorrt_llm_model: tensorrt_llm.models.BertModel,
hf_model: BertPreTrainedModel,
hf_model_config: BertConfig,
rank=0,
tensor_parallel=1,
fp16=False,
):
qkv_weight = [[None, None, None] for _ in range(hf_model_config.num_hidden_layers)]
qkv_bias = [[None, None, None] for _ in range(hf_model_config.num_hidden_layers)]
torch_dtype = torch.float16 if fp16 else torch.float32
for k, v in hf_model.state_dict().items():
v = v.to(torch_dtype).cpu().numpy()
if "embeddings.word_embeddings.weight" in k:
tensorrt_llm_model.embedding.vocab_embedding.weight.value = v
elif "embeddings.position_embeddings.weight" in k:
tensorrt_llm_model.embedding.position_embedding.weight.value = v
elif "embeddings.token_type_embeddings.weight" in k:
tensorrt_llm_model.embedding.token_embedding.weight.value = v
elif "embeddings.LayerNorm.weight" in k:
tensorrt_llm_model.embedding.embedding_ln.weight.value = v
elif "embeddings.LayerNorm.bias" in k:
tensorrt_llm_model.embedding.embedding_ln.bias.value = v
else:
layer_idx = extract_layer_idx(k)
if layer_idx is None:
continue
idx = int(layer_idx)
if "attention.output.dense.weight" in k:
tensorrt_llm_model.layers[idx].attention.dense.weight.value = split(
v, tensor_parallel, rank, dim=1
)
elif "attention.output.dense.bias" in k:
tensorrt_llm_model.layers[idx].attention.dense.bias.value = v
elif "attention.output.LayerNorm.weight" in k:
tensorrt_llm_model.layers[idx].input_layernorm.weight.value = v
elif "attention.output.LayerNorm.bias" in k:
tensorrt_llm_model.layers[idx].input_layernorm.bias.value = v
elif "intermediate.dense.weight" in k:
tensorrt_llm_model.layers[idx].mlp.fc.weight.value = split(
v, tensor_parallel, rank
)
elif "intermediate.dense.bias" in k:
tensorrt_llm_model.layers[idx].mlp.fc.bias.value = split(
v, tensor_parallel, rank
)
elif "output.dense.weight" in k:
tensorrt_llm_model.layers[idx].mlp.proj.weight.value = split(
v, tensor_parallel, rank, dim=1
)
elif "output.dense.bias" in k:
tensorrt_llm_model.layers[idx].mlp.proj.bias.value = v
elif "output.LayerNorm.weight" in k:
tensorrt_llm_model.layers[idx].post_layernorm.weight.value = v
elif "output.LayerNorm.bias" in k:
tensorrt_llm_model.layers[idx].post_layernorm.bias.value = v
elif "attention.self.query.weight" in k:
qkv_weight[idx][0] = v
elif "attention.self.query.bias" in k:
qkv_bias[idx][0] = v
elif "attention.self.key.weight" in k:
qkv_weight[idx][1] = v
elif "attention.self.key.bias" in k:
qkv_bias[idx][1] = v
elif "attention.self.value.weight" in k:
qkv_weight[idx][2] = v
elif "attention.self.value.bias" in k:
qkv_bias[idx][2] = v
for i in range(hf_model_config.num_hidden_layers):
tensorrt_llm_model.layers[i].attention.qkv.weight.value = split(
np.concatenate(qkv_weight[i]), tensor_parallel, rank
)
tensorrt_llm_model.layers[i].attention.qkv.bias.value = split(
np.concatenate(qkv_bias[i]), tensor_parallel, rank
)
def main():
args = parse_arguments()
tensorrt_llm.logger.set_level(args.log_level)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
torch_dtype = torch.float16 if args.dtype == "float16" else torch.float32
trt_dtype = trt.float16 if args.dtype == "float16" else trt.float32
builder = Builder()
builder_config = builder.create_builder_config(
name=args.model,
precision=args.dtype,
timing_cache=args.timing_cache,
profiling_verbosity=args.profiling_verbosity,
tensor_parallel=args.world_size, # TP only
max_batch_size=args.max_batch_size,
max_input_len=args.max_input_len,
)
hf_model = (
AutoModel.from_pretrained(args.model, torch_dtype=torch_dtype)
.cuda()
.to(torch_dtype)
.eval()
)
output_name = "hidden_states"
tensorrt_llm_bert = tensorrt_llm.models.BertModel(
num_layers=hf_model.config.num_hidden_layers,
num_heads=hf_model.config.num_attention_heads,
hidden_size=hf_model.config.hidden_size,
vocab_size=hf_model.config.vocab_size,
hidden_act=hf_model.config.hidden_act,
max_position_embeddings=hf_model.config.max_position_embeddings,
type_vocab_size=hf_model.config.type_vocab_size,
pad_token_id=hf_model.config.pad_token_id,
is_roberta=False,
mapping=Mapping(
world_size=args.world_size, rank=args.rank, tp_size=args.world_size
), # TP only
dtype=trt_dtype,
)
load_from_hf_model(
tensorrt_llm_bert,
hf_model,
hf_model.config,
rank=args.rank,
tensor_parallel=args.world_size,
fp16=(args.dtype == "float16"),
)
network = builder.create_network()
network.plugin_config.to_legacy_setting()
if args.use_bert_attention_plugin:
network.plugin_config.set_bert_attention_plugin(
dtype=args.use_bert_attention_plugin
)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.enable_qk_half_accum:
network.plugin_config.enable_qk_half_accum()
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled_with_fp32_acc)
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype)
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_bert.named_parameters())
# Forward
input_ids = tensorrt_llm.Tensor(
name="input_ids",
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict(
[("batch_size", [bs_range]), ("input_len", [inlen_range])]
),
)
# also called segment_ids
token_type_ids = tensorrt_llm.Tensor(
name="token_type_ids",
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict(
[("batch_size", [bs_range]), ("input_len", [inlen_range])]
),
)
input_lengths = tensorrt_llm.Tensor(
name="input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size", [bs_range])]),
)
output = tensorrt_llm_bert(
input_ids=input_ids,
input_lengths=input_lengths,
token_type_ids=token_type_ids,
)
output_dtype = trt.float16 if args.dtype == "float16" else trt.float32
output.mark_output(output_name, output_dtype)
# Network -> Engine
engine = builder.build_engine(network, builder_config)
assert engine is not None, "Failed to build engine."
engine_file = os.path.join(
args.output_dir,
get_engine_name(args.model, args.dtype, args.world_size, args.rank),
)
with open(engine_file, "wb") as f:
f.write(engine)
builder.save_config(builder_config, os.path.join(args.output_dir, "config.json"))
if __name__ == "__main__":
main()
- Run the build script with this command:
python3 build.py \
--dtype float16 \
--max_batch_size 512 \
--max_input_len 512 \
--gpus_per_node 1 \
--output_dir trtllm_bge \
--model BAAI/bge-base-en-v1.5 \
--use_bert_attention_plugin float16 \
--use_gemm_plugin float16 \
--enable_context_fmha \
- Create a run script to compare the TRT-LLM engine with the Huggingface model (run.py):
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
import tensorrt as trt
import tensorrt_llm
import torch
from build import get_engine_name
from datasets import concatenate_datasets, load_dataset
from tensorrt_llm import logger
from tensorrt_llm.runtime import Session, TensorInfo
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
def trt_dtype_to_torch(dtype):
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
else:
raise TypeError("%s is not supported" % dtype)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument("--engine_dir", type=str, required=True)
parser.add_argument("--dataset", type=str, default="shreyasharma/sentences_truth")
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--run_hf", action="store_true")
parser.add_argument("--run_trtllm", action="store_true")
parser.add_argument("--remove_columns", default="labels", type=str)
parser.add_argument("--target_column", default="sentences", type=str)
return parser.parse_args()
def main():
args = parse_arguments()
tensorrt_llm.logger.set_level(args.log_level)
config_path = os.path.join(args.engine_dir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
dtype = config["builder_config"]["precision"]
world_size = config["builder_config"]["tensor_parallel"]
assert (
world_size == tensorrt_llm.mpi_world_size()
), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
model_name = config["builder_config"]["name"]
runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0
runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank, tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
serialize_path = get_engine_name(model_name, dtype, world_size, runtime_rank)
serialize_path = os.path.join(args.engine_dir, serialize_path)
stream = torch.cuda.current_stream().cuda_stream
logger.info(f"Loading engine from {serialize_path}")
with open(serialize_path, "rb") as f:
engine_buffer = f.read()
logger.info(f"Creating session from engine")
session = Session.from_serialized_engine(engine_buffer)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if args.run_hf:
hf_model = (
AutoModel.from_pretrained(
model_name,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
.eval()
.cuda()
)
dataset = load_dataset(args.dataset)
dataset = concatenate_datasets(list(dataset.values())).remove_columns(
args.remove_columns
)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
total_time_trtllm = 0
total_trtllm = 0
total_time_hf = 0
total_hf = 0
for batch in tqdm(dataloader, unit_scale=args.batch_size, unit=" samples"):
encoded_input = tokenizer(
batch[args.target_column],
padding="longest",
truncation=True,
return_tensors="pt",
)
input_ids = encoded_input["input_ids"].cuda()
input_lengths = torch.tensor(
[seq.shape[-1] for seq in encoded_input["input_ids"]]
).cuda()
token_type_ids = encoded_input["token_type_ids"].cuda()
inputs = {
"input_ids": input_ids,
"input_lengths": input_lengths,
"token_type_ids": token_type_ids,
}
output_info = session.infer_shapes(
[
TensorInfo("input_ids", trt.DataType.INT32, input_ids.shape),
TensorInfo("input_lengths", trt.DataType.INT32, input_lengths.shape),
TensorInfo("token_type_ids", trt.DataType.INT32, token_type_ids.shape),
]
)
outputs = {
t.name: torch.empty(
tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda"
)
for t in output_info
}
output_name = "hidden_states"
assert (
output_name in outputs
), f"{output_name} not found in outputs, check if build.py set the name correctly"
if args.run_trtllm:
start = time.time()
ok = session.run(inputs, outputs, stream)
total_time_trtllm += time.time() - start
total_trtllm += 1
assert ok, "Runtime execution failed"
torch.cuda.synchronize()
res = outputs[output_name]
trtllm_embeddings = res[:, 0, :] # perform cls pooling
if args.run_hf:
with torch.inference_mode():
start = time.time()
model_output = hf_model(
input_ids=input_ids, token_type_ids=token_type_ids
)
total_time_hf += time.time() - start
total_hf += 1
hf_embeddings = model_output[0][:, 0, :] # perform cls pooling
if args.run_hf and args.run_trtllm:
print(
f"Average Cosine Distance: {(1 - cosine_similarity(trtllm_embeddings, hf_embeddings).mean().item()):.2f}"
)
if args.run_hf:
print(total_time_hf / total_hf)
if args.run_trtllm:
print(total_time_trtllm / total_trtllm)
if __name__ == "__main__":
main()
- Run this command to launch the run script:
python3 run.py \
--engine_dir trtllm_bge/ \
--batch_size 512 \
--dataset lighteval/med_paragraph_simplification \
--remove_columns answer \
--target_column query \
--run_hf \
--run_trtllm \
- Observe the measured cosine distance which is not close to 0.
Expected behavior
It is expected when running the huggingface model and the TRT-LLM model that the cosine distance between the results should be relatively close to 0 as both models were run on the same data.
actual behavior
The measured cosine distance between the huggingface model and TRT-LLM model outputs is not very close to 0, being approximately 0.5. While with random high dimensional vectors, a cosine distance of 0.5 means that the vectors are very similar (see distribution below), it is puzzling why this cosine distance for the models' embeddings is not closer to 0.
additional notes
I believe this could be some kind of numerical error in the Bert model. I've tried the model in float32, float16, and with the various plugins disabled or enabled. I've also tried an uninitialized TRT-LLM model compared to an initialized huggingface model and I observe that the cosine distance is higher.
Currently am having same issue.
Facing the same issue, for finetuned Bert model for sequence classification. The logits are wrong
I am encountering the same issue when using BertForSequenceClassification. It seems to be a decimal precision problem when I debug the pooler output. Does anyone have any suggestions for further debugging? @byshiue @symphonylyh
huggingface(pytorch) pooler output:
pooled_output=tensor([[ 8.7109e-01, 8.5156e-01, -6.5918e-01, -9.4678e-01, -9.9756e-01,
6.0010e-01, -8.6914e-01, -9.8438e-01, -9.9951e-01, 2.9297e-01,
-5.8301e-01, -9.4971e-01, 9.9414e-01, -8.8818e-01, -9.9268e-01,
-8.5791e-01, -8.3594e-01, 9.1602e-01, 9.9170e-01, 9.7998e-01,
-2.5708e-01, 9.9805e-01, 9.7070e-01, 9.7754e-01, -9.3945e-01,
5.1660e-01, -9.9902e-01, -9.9854e-01, 7.8125e-01, -3.3081e-01,
-7.6953e-01, -9.6631e-01, -5.7373e-01, -9.6729e-01, 8.8135e-01,
9.8926e-01, -9.8193e-01, 9.1748e-01, -8.1982e-01, 8.4778e-02,
-3.1104e-01, 9.6289e-01, 6.6602e-01, 9.9756e-01, -6.9629e-01,
9.5508e-01, 9.6826e-01, 9.9707e-01, -9.3555e-01, -9.9561e-01,
9.5654e-01, 9.5117e-01, -9.9414e-01, 7.2607e-01, -9.9365e-01,
6.7139e-01, -7.5146e-01, -9.6094e-01, -9.9023e-01, 7.2510e-01,
3.6987e-02, 9.8389e-01, 4.5386e-01, 4.1797e-01, -5.1318e-01,
-6.9336e-01, -9.0430e-01, -9.8096e-01, 8.4668e-01, -9.3896e-01,
6.7480e-01, 9.4092e-01, -9.0527e-01, 4.3506e-01, -7.7637e-01,
4.8682e-01, 4.2651e-01, 8.8086e-01, -2.6587e-01, -9.9170e-01,
-9.9316e-01, -9.9609e-01, 9.9365e-01, 1.4819e-01, 9.8926e-01,
8.3057e-01, -9.9609e-01, 5.6250e-01, -9.6533e-01, 8.0225e-01,
9.9365e-01, 9.7900e-01, 9.7998e-01, -3.0225e-01, 7.6221e-01,
9.8779e-01, -9.9023e-01, -4.3384e-01, 9.9805e-01, 9.0381e-01,
7.4463e-01, -1.9946e-01, 1.3721e-01, 8.5791e-01, 9.1455e-01,
-9.7705e-01, 9.9707e-01, 4.0039e-01, -3.3447e-02, -9.8779e-01,
8.7891e-01, -6.2061e-01, 1.0000e+00, -9.7949e-01, -9.9512e-01,
-3.2544e-01, -4.5972e-01, 9.6143e-01, -4.7778e-01, 1.0000e+00,
5.9912e-01, 4.6704e-01, -2.3779e-01, -9.9902e-01, 9.5996e-01,
-9.7607e-01, -8.7793e-01, -3.2886e-01, -9.9951e-01, -9.9902e-01,
5.2979e-01, -7.2998e-01, -7.7051e-01, 9.9316e-01, -9.5312e-01,
-9.9023e-01, -2.7759e-01, 7.6416e-01, -9.6289e-01, -9.9902e-01,
2.6440e-01, 9.9854e-01, 9.8193e-01, 2.1033e-01, -9.6973e-01,
9.6191e-01, -9.5020e-01, -5.9082e-01, 7.3193e-01, 9.8926e-01,
3.9276e-02, -6.1523e-01, 4.3213e-01, 1.9250e-01, -4.0454e-01,
-9.8340e-01, -8.7744e-01, 4.1870e-01, -9.1943e-01, 9.9707e-01,
9.1309e-01, -7.5098e-01, -6.7627e-01, 6.0303e-01, 9.6533e-01,
-8.7695e-01, -9.8389e-01, -9.9512e-01, -9.7705e-01, -9.4629e-01,
-1.6797e-01, -9.7412e-01, 8.2031e-01, 8.8916e-01, 7.0898e-01,
9.0332e-01, 5.8594e-01, 3.3521e-01, 9.9561e-01, -8.2471e-01,
6.8604e-01, -4.4189e-01, 6.0059e-01, 9.7314e-01, 3.5205e-01,
8.7695e-01, 8.6377e-01, 3.6438e-02, -3.2300e-01, -9.0137e-01,
2.9077e-01, -3.6987e-01, -6.9238e-01, 8.0225e-01, -3.7793e-01,
-9.5850e-01, -6.2256e-01, 9.6436e-01, 9.7461e-01, 9.7363e-01,
8.8184e-01, -2.3877e-01, 9.8193e-01, 8.6523e-01, -8.2617e-01,
3.3179e-01, -1.0000e+00, 9.7949e-01, 9.9805e-01, 9.2285e-01,
9.6777e-01, 8.4570e-01, 9.8828e-01, -8.3447e-01, 9.9902e-01,
9.8438e-01, 7.5293e-01, -9.9707e-01, -9.9902e-01, 7.4658e-01,
-9.9658e-01, 9.4824e-01, -9.9951e-01, 5.3857e-01, 9.8828e-01,
-9.9658e-01, -4.8120e-01, -6.7627e-01, -9.9854e-01, 6.0400e-01,
8.8525e-01, 1.7242e-02, 6.3965e-01, 9.1748e-01, -8.8965e-01,
-7.4219e-01, 9.3750e-01, -9.9121e-01, 9.9561e-01, 3.5767e-02,
9.3555e-01, -9.8535e-01, -9.2920e-01, -4.6338e-01, 9.2334e-01,
7.6562e-01, -9.2871e-01, 9.9268e-01, -5.9814e-01, -8.9551e-01,
9.0576e-01, -9.9365e-01, -9.9707e-01, -9.7559e-01, -9.3115e-01,
-9.7510e-01, -9.9902e-01, -8.5596e-01, -3.9233e-01, 5.8154e-01,
-3.6377e-01, 3.8916e-01, 9.5752e-01, -7.6465e-01, 7.7734e-01,
-6.8066e-01, -6.9775e-01, -4.7729e-01, -9.4189e-01, -9.7852e-01,
9.7412e-01, -2.7271e-01, 1.0754e-01, 9.1650e-01, -7.3877e-01,
-9.0430e-01, -5.8887e-01, 8.2422e-01, 9.9951e-01, 4.7388e-01,
8.1689e-01, -6.4819e-02, 7.4316e-01, 9.9512e-01, 6.9580e-01,
5.5713e-01, -9.6289e-01, 9.6387e-01, -7.8662e-01, 9.9805e-01,
-1.4307e-01, 9.8926e-01, 8.7988e-01, 8.6914e-01, -7.8857e-01,
4.1431e-01, 9.9365e-01, 6.7773e-01, -6.3379e-01, -9.6729e-01,
9.9902e-01, -9.6826e-01, 6.7285e-01, -9.5996e-01, 9.1504e-01,
7.5635e-01, -9.6826e-01, -3.7817e-01, -8.3789e-01, 8.2324e-01,
-9.9951e-01, 6.2451e-01, -9.2383e-01, -9.7656e-01, -9.4873e-01,
9.7705e-01, -8.4814e-01, -9.9951e-01, -9.5459e-01, -8.9062e-01,
9.9805e-01, -9.9414e-01, -7.2083e-02, 9.9023e-01, -8.7646e-01,
7.7197e-01, -1.0931e-01, -9.0527e-01, -9.9902e-01, -7.4414e-01,
-8.0371e-01, 9.5654e-01, -9.9951e-01, 3.3472e-01, -9.9023e-01,
-8.7061e-01, -6.3184e-01, 4.5972e-01, -1.0000e+00, -6.6797e-01,
9.2041e-01, -9.9854e-01, -9.9902e-01, -9.2334e-01, -9.3604e-01,
9.9609e-01, 9.3701e-01, 9.9854e-01, -1.5649e-01, -9.5215e-01,
9.9902e-01, 8.3984e-01, -9.9951e-01, 9.9951e-01, -9.9365e-01,
6.0150e-02, -9.6387e-01, 9.6533e-01, -9.9609e-01, -9.8975e-01,
-9.2188e-01, -8.6572e-01, -2.1912e-01, -1.4404e-01, -7.1582e-01,
1.0000e+00, -3.2959e-01, 9.4385e-01, 9.8779e-01, 9.8340e-01,
7.3242e-01, 9.8926e-01, -7.3682e-01, -9.9951e-01, -9.8730e-01,
8.5010e-01, -9.9316e-01, 9.3799e-01, -7.1729e-01, -9.6240e-01,
-3.8623e-01, 8.9941e-01, 9.9805e-01, 9.9414e-01, 9.9121e-01,
-9.8535e-01, 6.0596e-01, -9.9707e-01, 4.4434e-01, -2.9663e-01,
-9.0820e-01, -2.1484e-01, -9.8535e-01, -9.8535e-01, 5.3009e-02,
-8.5010e-01, 9.2627e-01, -9.9414e-01, -8.7354e-01, 9.7473e-02,
9.5557e-01, -9.9609e-01, 1.0000e+00, 4.2261e-01, 9.2871e-01,
9.2383e-01, 7.5439e-01, -6.9385e-01, -9.4385e-01, 4.3018e-01,
-9.1553e-01, -9.3408e-01, 6.5332e-01, -9.4629e-01, -9.2822e-01,
8.1006e-01, 8.6328e-01, -3.2080e-01, 6.9971e-01, -6.7285e-01,
4.7803e-01, -1.3293e-01, -9.6924e-01, -1.2842e-01, 9.6240e-01,
-9.1846e-01, -9.8730e-01, -9.8877e-01, -9.9121e-01, 9.9902e-01,
-6.2207e-01, -7.5195e-01, -8.9258e-01, 9.1553e-01, 9.8828e-01,
-9.6680e-01, 9.9902e-01, 9.4043e-01, 8.6670e-01, -9.9902e-01,
-4.9634e-01, 7.4902e-01, -6.8750e-01, -6.4392e-02, -9.5215e-01,
-9.9805e-01, 1.3708e-01, -4.8267e-01, 8.5205e-01, 9.1357e-01,
-6.0449e-01, 9.0137e-01, 2.6465e-01, 5.1221e-01, -9.9268e-01,
4.5044e-01, -9.4482e-01, 9.5654e-01, -9.1162e-01, 9.5312e-01,
-9.4043e-01, 9.7949e-01, 8.3691e-01, 9.6338e-01, -5.0146e-01,
-6.4502e-01, -2.7832e-01, 9.7803e-01, -9.3750e-01, 9.7559e-01,
-1.9638e-02, 9.5557e-01, -9.8340e-01, 9.9463e-01, -9.3018e-01,
-3.9380e-01, 2.6514e-01, -8.4400e-04, -7.6221e-01, 5.2100e-01,
2.7612e-01, -8.8184e-01, -4.9292e-01, -2.0923e-01, -9.7607e-01,
9.9756e-01, 9.7461e-01, 4.5801e-01, 9.4336e-01, 9.9219e-01,
2.1497e-01, 9.4116e-02, -3.9551e-01, -1.3989e-01, -8.9844e-01,
-5.2197e-01, 4.0259e-01, -8.3496e-01, -8.0859e-01, -8.6572e-01,
-8.1934e-01, -4.1895e-01, -3.8892e-01, -3.3838e-01, -9.2627e-01,
-2.1216e-01, -9.6973e-01, 2.0593e-01, 9.3750e-01, 3.2788e-01,
-2.3346e-02, -1.0000e+00, -7.3096e-01, -5.5542e-02, 9.8145e-01,
9.0088e-01, -8.8623e-01, 9.8340e-01, 9.6729e-01, -7.9590e-01,
-4.6753e-01, 9.9902e-01, -5.0928e-01, 8.5400e-01, -9.5361e-01,
7.2705e-01, 9.4043e-01, -9.5117e-01, 7.0215e-01, -9.1895e-01,
-9.9658e-01, 9.7754e-01, 9.9707e-01, -9.9170e-01, 4.2505e-01,
-9.8926e-01, -3.0664e-01, 5.9473e-01, 7.7197e-01, -9.6484e-01,
3.9551e-01, -9.8340e-01, -5.8740e-01, 9.9902e-01, 9.5020e-01,
9.9170e-01, 9.7168e-01, 9.7852e-01, 8.5742e-01, 6.1963e-01,
6.0742e-01, 1.8396e-01, -8.8818e-01, 9.9902e-01, -8.9453e-01,
-9.9365e-01, 9.4629e-01, 7.2852e-01, -9.5996e-01, 2.6587e-01,
9.6045e-01, -8.3057e-01, 9.4922e-01, 9.9658e-01, -1.0000e+00,
1.0000e+00, 6.3086e-01, 6.3770e-01, -5.7526e-02, 8.8037e-01,
2.4841e-01, 9.8145e-01, 8.2373e-01, 9.4336e-01, -9.9805e-01,
9.6826e-01, -8.6963e-01, 9.1211e-01, 8.5547e-01, 9.9756e-01,
9.7705e-01, -4.4751e-01, 8.0859e-01, 9.3652e-01, 9.0625e-01,
-8.2031e-01, -5.5786e-02, -8.3057e-01, -9.9902e-01, 9.8828e-01,
-7.9736e-01, -5.0098e-01, -1.0000e+00, 3.8208e-01, 9.8877e-01,
-8.3838e-01, -9.9561e-01, -9.9951e-01, -7.6660e-01, -9.6094e-01,
7.8223e-01, 9.0430e-01, 7.6758e-01, -9.4336e-01, 9.8779e-01,
9.7754e-01, 5.6885e-01, 5.7471e-01, 9.6533e-01, 9.9512e-01,
-7.6660e-01, -9.8926e-01, -9.4580e-01, -9.9951e-01, -6.5186e-01,
-1.4417e-01, -9.2090e-01, 9.4775e-01, -5.8203e-01, 9.9365e-01,
8.3740e-01, 7.1094e-01, 6.0791e-01, 9.0088e-01, -9.9170e-01,
-4.5874e-01, -9.7852e-01, -2.2327e-01, 9.9268e-01, -2.5488e-01,
9.7559e-01, 2.5415e-01, 9.1016e-01, 9.0088e-01, -9.9756e-01,
-3.6914e-01, 8.7500e-01, -9.9170e-01, -9.7998e-01, -9.8975e-01,
-8.3301e-01, 9.2432e-01, -4.0381e-01, 1.0000e+00, 6.9775e-01,
-6.8115e-01, 9.8096e-01, -9.3652e-01, 9.2139e-01, 1.6858e-01,
-9.9365e-01, -9.9072e-01, -6.9580e-01, -9.9854e-01, 7.6807e-01,
9.9512e-01, -9.5557e-01, -9.9951e-01, 8.7012e-01, -9.1553e-01,
8.5889e-01, 9.1455e-01, -9.9561e-01, -3.9673e-01, 9.5947e-02,
9.8389e-01, -8.8281e-01, 9.8535e-01, 8.7842e-01, -7.6709e-01,
1.0000e+00, -8.0859e-01, -9.9707e-01, 7.5830e-01, -9.7461e-01,
-9.9854e-01, 7.5073e-02, -7.7881e-01, 9.9756e-01, -9.9414e-01,
-8.2373e-01, -6.9189e-01, -9.6924e-01, -9.9512e-01, -5.2539e-01,
5.4138e-02, -9.9658e-01, 9.6924e-01, -9.8926e-01, 4.6899e-01,
-4.9121e-01, -3.7476e-01, -9.9707e-01, -9.3506e-01, -9.9951e-01,
-1.8005e-01, 8.3887e-01, -6.4844e-01, 1.2917e-02, 8.7500e-01,
7.5830e-01, -8.8135e-01, -8.1445e-01, -8.8330e-01, 9.1699e-01,
-2.6929e-01, 9.6338e-01, 5.8057e-01, 6.8604e-01, -9.9756e-01,
-9.3359e-01, -5.0439e-01, -5.9668e-01, -1.0000e+00, -7.4854e-01,
-9.1846e-01, 9.0186e-01, 9.2871e-01, -8.5938e-01, -9.5166e-01,
-9.8926e-01, -9.7754e-01, 9.9951e-01, -9.9951e-01, -9.9023e-01,
-8.5083e-02, -9.9707e-01, -9.9707e-01, -2.3206e-01, 6.0547e-01,
9.9756e-01, -9.8291e-01, -1.7609e-02, -9.5410e-01, 9.6484e-01,
-9.8975e-01, -8.0811e-01, -8.7695e-01, 3.1226e-01, 8.6230e-01,
-9.7754e-01, 8.4766e-01, -9.9902e-01, -9.3408e-01, -9.0820e-01,
-9.5068e-01, -4.0088e-01, 9.9951e-01, -9.9854e-01, -7.4219e-01,
-9.3945e-01, -8.6621e-01, -1.2549e-01, -9.9707e-01, -9.7803e-01,
-1.0000e+00, -9.8535e-01, 6.7041e-01, -9.9805e-01, -9.6973e-01,
8.7061e-01, -1.0000e+00, 1.9214e-01, -8.9404e-01, -9.1748e-01,
8.4863e-01, -9.8730e-01, -8.6279e-01]], dtype=torch.float16)
trt pooler output:
tensor([[ 0.8721, 0.8516, -0.6592, -0.9468, -0.9980, 0.5962, -0.8687, -0.9844,
-1.0000, 0.2915, -0.5830, -0.9507, 0.9941, -0.8882, -0.9922, -0.8574,
-0.8335, 0.9155, 0.9922, 0.9800, -0.2546, 0.9980, 0.9707, 0.9771,
-0.9395, 0.5225, -0.9995, -0.9985, 0.7822, -0.3308, -0.7681, -0.9658,
-0.5767, -0.9673, 0.8809, 0.9888, -0.9819, 0.9175, -0.8188, 0.0870,
-0.3081, 0.9629, 0.6665, 0.9976, -0.6968, 0.9546, 0.9673, 0.9966,
-0.9355, -0.9956, 0.9575, 0.9517, -0.9941, 0.7256, -0.9941, 0.6709,
-0.7534, -0.9619, -0.9902, 0.7241, 0.0383, 0.9839, 0.4536, 0.4172,
-0.5083, -0.6934, -0.9033, -0.9805, 0.8481, -0.9380, 0.6753, 0.9414,
-0.9058, 0.4341, -0.7769, 0.4851, 0.4202, 0.8809, -0.2686, -0.9922,
-0.9927, -0.9961, 0.9937, 0.1519, 0.9888, 0.8315, -0.9961, 0.5635,
-0.9653, 0.8027, 0.9937, 0.9785, 0.9800, -0.3008, 0.7607, 0.9883,
-0.9902, -0.4282, 0.9980, 0.9048, 0.7446, -0.2002, 0.1418, 0.8584,
0.9141, -0.9771, 0.9966, 0.3979, -0.0364, -0.9878, 0.8794, -0.6177,
1.0000, -0.9790, -0.9946, -0.3279, -0.4536, 0.9619, -0.4797, 1.0000,
0.5986, 0.4734, -0.2367, -0.9995, 0.9600, -0.9766, -0.8789, -0.3315,
-0.9995, -0.9995, 0.5293, -0.7285, -0.7715, 0.9937, -0.9536, -0.9902,
-0.2771, 0.7627, -0.9629, -0.9985, 0.2561, 0.9985, 0.9819, 0.2083,
-0.9692, 0.9619, -0.9507, -0.5913, 0.7295, 0.9888, 0.0452, -0.6172,
0.4304, 0.1924, -0.4055, -0.9829, -0.8770, 0.4163, -0.9209, 0.9966,
0.9121, -0.7505, -0.6733, 0.6035, 0.9653, -0.8770, -0.9839, -0.9946,
-0.9771, -0.9468, -0.1676, -0.9751, 0.8188, 0.8882, 0.7095, 0.9033,
0.5874, 0.3372, 0.9956, -0.8252, 0.6841, -0.4419, 0.6001, 0.9731,
0.3528, 0.8770, 0.8633, 0.0338, -0.3247, -0.9014, 0.2927, -0.3694,
-0.6938, 0.8013, -0.3708, -0.9585, -0.6216, 0.9648, 0.9751, 0.9736,
0.8823, -0.2338, 0.9819, 0.8662, -0.8271, 0.3333, -1.0000, 0.9790,
0.9980, 0.9224, 0.9678, 0.8452, 0.9883, -0.8350, 0.9995, 0.9844,
0.7563, -0.9966, -0.9985, 0.7476, -0.9961, 0.9482, -0.9995, 0.5366,
0.9883, -0.9966, -0.4829, -0.6758, -0.9985, 0.6064, 0.8843, 0.0119,
0.6382, 0.9170, -0.8911, -0.7402, 0.9370, -0.9907, 0.9956, 0.0325,
0.9355, -0.9849, -0.9287, -0.4646, 0.9229, 0.7656, -0.9282, 0.9922,
-0.6016, -0.8965, 0.9053, -0.9941, -0.9976, -0.9756, -0.9307, -0.9751,
-0.9995, -0.8569, -0.3904, 0.5825, -0.3635, 0.3857, 0.9575, -0.7642,
0.7793, -0.6816, -0.6924, -0.4766, -0.9419, -0.9785, 0.9751, -0.2751,
0.1073, 0.9160, -0.7388, -0.9033, -0.5898, 0.8232, 1.0000, 0.4717,
0.8188, -0.0664, 0.7432, 0.9946, 0.6997, 0.5562, -0.9629, 0.9634,
-0.7881, 0.9980, -0.1428, 0.9897, 0.8804, 0.8687, -0.7876, 0.4143,
0.9941, 0.6748, -0.6333, -0.9673, 0.9995, -0.9673, 0.6753, -0.9600,
0.9160, 0.7568, -0.9678, -0.3818, -0.8398, 0.8223, -0.9995, 0.6279,
-0.9233, -0.9771, -0.9487, 0.9771, -0.8481, -0.9995, -0.9541, -0.8911,
0.9980, -0.9941, -0.0793, 0.9902, -0.8760, 0.7744, -0.1143, -0.9053,
-0.9985, -0.7446, -0.8032, 0.9575, -1.0000, 0.3369, -0.9902, -0.8706,
-0.6304, 0.4546, -1.0000, -0.6650, 0.9194, -0.9985, -0.9995, -0.9233,
-0.9360, 0.9961, 0.9370, 0.9980, -0.1613, -0.9521, 0.9985, 0.8398,
-0.9995, 0.9995, -0.9937, 0.0573, -0.9634, 0.9653, -0.9961, -0.9897,
-0.9219, -0.8667, -0.2177, -0.1460, -0.7173, 1.0000, -0.3320, 0.9429,
0.9878, 0.9839, 0.7310, 0.9888, -0.7354, -1.0000, -0.9878, 0.8521,
-0.9937, 0.9375, -0.7183, -0.9629, -0.3796, 0.8994, 0.9980, 0.9941,
0.9917, -0.9849, 0.6064, -0.9966, 0.4434, -0.2942, -0.9087, -0.2111,
-0.9849, -0.9849, 0.0546, -0.8501, 0.9263, -0.9941, -0.8735, 0.0949,
0.9561, -0.9961, 1.0000, 0.4231, 0.9282, 0.9248, 0.7534, -0.6909,
-0.9434, 0.4331, -0.9146, -0.9341, 0.6519, -0.9468, -0.9282, 0.8105,
0.8633, -0.3210, 0.6997, -0.6699, 0.4756, -0.1262, -0.9692, -0.1296,
0.9624, -0.9194, -0.9868, -0.9883, -0.9907, 0.9985, -0.6230, -0.7505,
-0.8921, 0.9160, 0.9883, -0.9658, 0.9995, 0.9409, 0.8672, -0.9985,
-0.4983, 0.7495, -0.6870, -0.0656, -0.9521, -0.9980, 0.1368, -0.4819,
0.8516, 0.9126, -0.6035, 0.9009, 0.2666, 0.5112, -0.9927, 0.4546,
-0.9448, 0.9565, -0.9121, 0.9521, -0.9409, 0.9785, 0.8384, 0.9634,
-0.4983, -0.6465, -0.2756, 0.9785, -0.9360, 0.9756, -0.0225, 0.9556,
-0.9829, 0.9941, -0.9302, -0.3904, 0.2634, -0.0047, -0.7627, 0.5195,
0.2766, -0.8809, -0.4961, -0.2092, -0.9756, 0.9980, 0.9751, 0.4597,
0.9424, 0.9922, 0.2122, 0.0908, -0.3999, -0.1343, -0.8984, -0.5225,
0.3979, -0.8335, -0.8091, -0.8652, -0.8203, -0.4172, -0.3828, -0.3403,
-0.9268, -0.2136, -0.9697, 0.2062, 0.9360, 0.3284, -0.0222, -1.0000,
-0.7310, -0.0500, 0.9810, 0.9019, -0.8857, 0.9829, 0.9673, -0.7954,
-0.4648, 0.9995, -0.5083, 0.8545, -0.9536, 0.7280, 0.9409, -0.9512,
0.7026, -0.9175, -0.9961, 0.9775, 0.9976, -0.9917, 0.4272, -0.9888,
-0.3137, 0.5938, 0.7715, -0.9648, 0.3923, -0.9829, -0.5874, 0.9995,
0.9502, 0.9917, 0.9712, 0.9785, 0.8574, 0.6187, 0.6089, 0.1808,
-0.8877, 0.9985, -0.8950, -0.9937, 0.9468, 0.7271, -0.9600, 0.2649,
0.9604, -0.8301, 0.9487, 0.9966, -1.0000, 1.0000, 0.6318, 0.6372,
-0.0603, 0.8809, 0.2457, 0.9810, 0.8237, 0.9434, -0.9980, 0.9688,
-0.8687, 0.9121, 0.8564, 0.9980, 0.9771, -0.4453, 0.8096, 0.9355,
0.9067, -0.8203, -0.0590, -0.8315, -0.9995, 0.9883, -0.7979, -0.5039,
-1.0000, 0.3772, 0.9888, -0.8379, -0.9961, -1.0000, -0.7642, -0.9614,
0.7822, 0.9048, 0.7686, -0.9429, 0.9883, 0.9775, 0.5645, 0.5757,
0.9653, 0.9956, -0.7642, -0.9897, -0.9463, -0.9995, -0.6543, -0.1444,
-0.9209, 0.9482, -0.5830, 0.9937, 0.8364, 0.7104, 0.6074, 0.9004,
-0.9917, -0.4556, -0.9785, -0.2238, 0.9927, -0.2542, 0.9756, 0.2576,
0.9106, 0.9004, -0.9980, -0.3662, 0.8735, -0.9922, -0.9800, -0.9902,
-0.8330, 0.9229, -0.4023, 1.0000, 0.7012, -0.6826, 0.9805, -0.9355,
0.9214, 0.1700, -0.9937, -0.9902, -0.6968, -0.9980, 0.7681, 0.9956,
-0.9556, -1.0000, 0.8701, -0.9160, 0.8584, 0.9146, -0.9961, -0.3940,
0.0917, 0.9839, -0.8843, 0.9849, 0.8784, -0.7676, 1.0000, -0.8091,
-0.9976, 0.7627, -0.9751, -0.9980, 0.0776, -0.7764, 0.9980, -0.9941,
-0.8237, -0.6914, -0.9692, -0.9956, -0.5239, 0.0559, -0.9966, 0.9692,
-0.9888, 0.4661, -0.4932, -0.3745, -0.9966, -0.9355, -0.9995, -0.1835,
0.8389, -0.6475, 0.0099, 0.8755, 0.7563, -0.8828, -0.8140, -0.8843,
0.9175, -0.2693, 0.9634, 0.5767, 0.6890, -0.9976, -0.9341, -0.5059,
-0.5952, -1.0000, -0.7495, -0.9194, 0.9014, 0.9282, -0.8584, -0.9517,
-0.9897, -0.9771, 1.0000, -1.0000, -0.9902, -0.0882, -0.9966, -0.9966,
-0.2330, 0.6064, 0.9976, -0.9824, -0.0184, -0.9536, 0.9653, -0.9902,
-0.8076, -0.8760, 0.3142, 0.8623, -0.9775, 0.8481, -0.9995, -0.9341,
-0.9072, -0.9507, -0.3970, 1.0000, -0.9980, -0.7412, -0.9395, -0.8662,
-0.1250, -0.9976, -0.9780, -1.0000, -0.9849, 0.6733, -0.9980, -0.9707,
0.8701, -1.0000, 0.1927, -0.8945, -0.9180, 0.8481, -0.9868, -0.8633]],
device='cuda:0', dtype=torch.float16)
Following up on this again.
how to solve it? hf_bge_embedding is completely different from trtllm_bge_embedding.
how to solve it? hf_bge_embedding is completely different from trtllm_bge_embedding.
Oops, the type of trtllm inputs is trt.INT32, but the output type of AutoTokenizer is torch.LONG. Inconsistent input data types lead to inconsistent results.
same issue, any update?
as @longshuicui mentioned above, torch.LONG and torch.INT does make a difference, see https://github.com/NVIDIA/TensorRT-LLM/blob/f6821ee393be6ec92234f9bb47a4b09f6738050b/examples/enc_dec/run.py#L209
I saw the above run.py may also suffer from this. Does this help?
@symphonylyh I cast input_ids to int32 by input_ids = input_ids.to(torch.int32), got same result.
I use torch.randint() to generate input_ids instead of tokenizer, for the same input_ids, the result of trtllm model and hf model still different.
import argparse
import json
import os
import time
# isort: off
import torch
import tensorrt as trt
# isort: on
import string
import random
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.runtime import Session, TensorInfo
from build import get_engine_name # isort:skip
from transformers import AutoTokenizer
def trt_dtype_to_torch(dtype):
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
else:
raise TypeError("%s is not supported" % dtype)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--engine_dir', type=str, default='bert_outputs')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
tensorrt_llm.logger.set_level(args.log_level)
config_path = os.path.join(args.engine_dir, 'config.json')
with open(config_path, 'r') as f:
config = json.load(f)
assert config["plugin_config"]["remove_input_padding"] == False, \
"Please refer to run_remove_input_padding.py for running BERT models with remove_input_padding enabled"
dtype = config['builder_config']['precision']
world_size = config['builder_config']['tensor_parallel']
assert world_size == tensorrt_llm.mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
model_name = config['builder_config']['name']
runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0
runtime_mapping = tensorrt_llm.Mapping(world_size,
runtime_rank,
tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
serialize_path = get_engine_name(model_name, dtype, world_size,
runtime_rank)
serialize_path = os.path.join(args.engine_dir, serialize_path)
stream = torch.cuda.current_stream().cuda_stream
logger.info(f'Loading engine from {serialize_path}')
with open(serialize_path, 'rb') as f:
engine_buffer = f.read()
logger.info(f'Creating session from engine')
session = Session.from_serialized_engine(engine_buffer)
tokenizer=AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
for i in range(1):
batch_size =1
seq_len = 20
input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_len)).int().cuda()
input_lengths = seq_len * torch.ones(
(batch_size, ), dtype=torch.int32, device='cuda')
token_type_ids = torch.zeros((batch_size, seq_len)).int().cuda()
inputs = {
'input_ids': input_ids,
'input_lengths': input_lengths,
'token_type_ids': token_type_ids
}
output_info = session.infer_shapes([
TensorInfo('input_ids', trt.DataType.INT32, input_ids.shape),
TensorInfo('input_lengths', trt.DataType.INT32,
input_lengths.shape),
TensorInfo('token_type_ids', trt.DataType.INT32,
token_type_ids.shape),
])
outputs = {
t.name: torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in output_info
}
if (model_name == 'BertModel' or model_name == 'RobertaModel'):
output_name = 'hidden_states'
elif (model_name == 'BertForQuestionAnswering'
or model_name == 'RobertaForQuestionAnswering'):
output_name = 'logits'
elif (model_name == 'BertForSequenceClassification'
or model_name == 'RobertaForSequenceClassification'):
output_name = 'logits'
else:
assert False, f"Unknown BERT model {model_name}"
assert output_name in outputs, f'{output_name} not found in outputs, check if build.py set the name correctly'
ok = session.run(inputs, outputs, stream)
assert ok, "Runtime execution failed"
torch.cuda.synchronize()
res = outputs[output_name]
trtllm_embeddings = res[:, 0, :]
from transformers import AutoModel
model = AutoModel.from_pretrained("/data1/nfs15/nfs/bigdata/zhanglei/ml/inference/model-demo/hf/BAAI/bge-large-zh-v1.5")
model.cuda().eval()
model_output = model(input_ids=input_ids,token_type_ids=token_type_ids)
hf_embeddings = model_output[0][:, 0, :]
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
print(f"fp Sentence embeddings cosine: {cos(hf_embeddings, trtllm_embeddings)}")
@Lzhang-hub Do you still have the question? If not, we will close it soon.