TensorRT-LLM icon indicating copy to clipboard operation
TensorRT-LLM copied to clipboard

BERT Model is Inaccurate

Open Broyojo opened this issue 1 year ago • 4 comments

System Info

  • CPU: i9 9900k
  • GPU: RTX 4090
  • TensorRT-LLM Version: 0.9.0.dev2024022000
  • Cuda Version: Cuda 12.3
  • Driver Version: 545.29.06
  • OS: Arch Linux, kernel version 6.7.5

Who can help?

@byshiue

Information

  • [X] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

  1. Create a build script to build the TRT-LLM engine (build.py):
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from collections import OrderedDict

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
from tensorrt_llm.builder import Builder
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from transformers import AutoModel, BertConfig, BertPreTrainedModel


def parse_arguments():
    """Parse command-line options for building a TRT-LLM BERT engine.

    ``--model`` (Hugging Face model id) and ``--output_dir`` are required:
    ``main`` uses both unconditionally (``os.path.exists(args.output_dir)`` and
    ``AutoModel.from_pretrained(args.model)``), so leaving either out would
    otherwise fail deep inside the build instead of with a usage message.

    ``--use_bert_attention_plugin`` / ``--use_gemm_plugin`` take an optional
    dtype: the bare flag means ``float16`` (``const``), and omitting the flag
    leaves the value ``False`` (plugin disabled).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--world_size", type=int, default=1, help="Tensor parallelism size"
    )
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument(
        "--dtype", type=str, default="float16", choices=["float16", "float32"]
    )
    parser.add_argument("--timing_cache", type=str, default="model.cache")
    parser.add_argument(
        "--profiling_verbosity",
        type=str,
        default="layer_names_only",
        choices=["layer_names_only", "detailed", "none"],
        help="The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.",
    )
    parser.add_argument("--log_level", type=str, default="info")
    parser.add_argument("--max_batch_size", type=int, default=256)
    parser.add_argument("--max_input_len", type=int, default=512)
    # Accepted for compatibility; main() in this script does not read it.
    parser.add_argument("--gpus_per_node", type=int, default=1)
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory receiving the engine file and config.json",
    )
    parser.add_argument(
        "--use_bert_attention_plugin",
        nargs="?",
        const="float16",
        type=str,
        default=False,
        choices=["float16", "float32"],
    )
    parser.add_argument(
        "--use_gemm_plugin",
        nargs="?",
        const="float16",
        type=str,
        default=False,
        choices=["float16", "float32"],
    )
    parser.add_argument("--enable_qk_half_accum", default=False, action="store_true")
    parser.add_argument("--enable_context_fmha", default=False, action="store_true")
    parser.add_argument(
        "--enable_context_fmha_fp32_acc", default=False, action="store_true"
    )
    parser.add_argument("--model", type=str, required=True, help="Model id")
    return parser.parse_args()


def get_engine_name(model, dtype, tp_size, rank):
    """Return the engine file name for one tensor-parallel rank.

    Every ``/`` in the model id (e.g. ``org/name``) becomes ``--`` so the
    result is a single path component.
    """
    safe_model = "--".join(model.split("/"))
    return f"{safe_model}_{dtype}_tp{tp_size}_rank{rank}.engine"


def extract_layer_idx(name):
    """Return the first all-digit component of a dotted parameter name.

    For ``"encoder.layer.3.output.dense.weight"`` the result is the string
    ``"3"``. Returns ``None`` when no component consists only of digits.
    """
    return next(
        (part for part in name.split(".") if part.isdigit()),
        None,
    )


def split(v, tp_size, idx, dim=0):
    """Return rank ``idx``'s shard of ``v`` for ``tp_size``-way tensor parallelism.

    Args:
        v: numpy array to shard.
        tp_size: number of equal shards.
        idx: 0-based index of the shard to return.
        dim: axis to split along for arrays with two or more dims. 1-D arrays
            are always split along axis 0 (``dim`` is ignored for them).

    Returns:
        ``v`` itself when ``tp_size == 1``; otherwise a fresh C-contiguous copy
        of the requested shard (works for any rank >= 1, where previously
        arrays with more than two dims silently produced ``None``).

    Raises:
        ValueError: if ``v`` cannot be split into ``tp_size`` equal parts along
            the chosen axis (raised by ``np.split``).
    """
    if tp_size == 1:
        return v
    # A 1-D tensor (e.g. a bias) only has one axis to split along.
    axis = 0 if v.ndim == 1 else dim
    # .copy() detaches the shard from v's buffer and makes it C-contiguous.
    return np.ascontiguousarray(np.split(v, tp_size, axis=axis)[idx].copy())


def load_from_hf_model(
    tensorrt_llm_model: tensorrt_llm.models.BertModel,
    hf_model: BertPreTrainedModel,
    hf_model_config: BertConfig,
    rank=0,
    tensor_parallel=1,
    fp16=False,
):
    """Copy a Hugging Face BERT checkpoint into a TRT-LLM ``BertModel`` in place.

    Every tensor of ``hf_model.state_dict()`` is cast to float16 (``fp16``) or
    float32, moved to CPU numpy and assigned to the matching ``.value`` of the
    TRT-LLM parameter. Attention QKV and MLP fc weights/biases are split along
    dim 0, attention-output dense and MLP proj weights along dim 1 (via
    ``split``); all other tensors are replicated on every rank.

    Args:
        tensorrt_llm_model: destination model; its parameters are overwritten.
        hf_model: source model.
        hf_model_config: config of ``hf_model``; ``num_hidden_layers`` is read.
        rank: tensor-parallel rank whose shard is loaded.
        tensor_parallel: tensor-parallel world size.
        fp16: load weights as float16 instead of float32.

    Raises:
        ValueError: if a layer is missing its query, key or value weight/bias.
    """
    num_layers = hf_model_config.num_hidden_layers
    # Per layer: [query, key, value] tensors, fused into one QKV after the scan.
    qkv_weight = [[None, None, None] for _ in range(num_layers)]
    qkv_bias = [[None, None, None] for _ in range(num_layers)]

    torch_dtype = torch.float16 if fp16 else torch.float32
    for k, v in hf_model.state_dict().items():
        v = v.to(torch_dtype).cpu().numpy()
        if "embeddings.word_embeddings.weight" in k:
            tensorrt_llm_model.embedding.vocab_embedding.weight.value = v
        elif "embeddings.position_embeddings.weight" in k:
            tensorrt_llm_model.embedding.position_embedding.weight.value = v
        elif "embeddings.token_type_embeddings.weight" in k:
            tensorrt_llm_model.embedding.token_embedding.weight.value = v
        elif "embeddings.LayerNorm.weight" in k:
            tensorrt_llm_model.embedding.embedding_ln.weight.value = v
        elif "embeddings.LayerNorm.bias" in k:
            tensorrt_llm_model.embedding.embedding_ln.bias.value = v
        else:
            layer_idx = extract_layer_idx(k)
            if layer_idx is None:
                # Names without a numeric layer component are not loaded.
                continue
            idx = int(layer_idx)
            # These are substring tests, so the "attention.output.*" branches
            # must stay ahead of the generic "output.*" ones.
            if "attention.output.dense.weight" in k:
                tensorrt_llm_model.layers[idx].attention.dense.weight.value = split(
                    v, tensor_parallel, rank, dim=1
                )
            elif "attention.output.dense.bias" in k:
                tensorrt_llm_model.layers[idx].attention.dense.bias.value = v
            elif "attention.output.LayerNorm.weight" in k:
                tensorrt_llm_model.layers[idx].input_layernorm.weight.value = v
            elif "attention.output.LayerNorm.bias" in k:
                tensorrt_llm_model.layers[idx].input_layernorm.bias.value = v
            elif "intermediate.dense.weight" in k:
                tensorrt_llm_model.layers[idx].mlp.fc.weight.value = split(
                    v, tensor_parallel, rank
                )
            elif "intermediate.dense.bias" in k:
                tensorrt_llm_model.layers[idx].mlp.fc.bias.value = split(
                    v, tensor_parallel, rank
                )
            elif "output.dense.weight" in k:
                tensorrt_llm_model.layers[idx].mlp.proj.weight.value = split(
                    v, tensor_parallel, rank, dim=1
                )
            elif "output.dense.bias" in k:
                tensorrt_llm_model.layers[idx].mlp.proj.bias.value = v
            elif "output.LayerNorm.weight" in k:
                tensorrt_llm_model.layers[idx].post_layernorm.weight.value = v
            elif "output.LayerNorm.bias" in k:
                tensorrt_llm_model.layers[idx].post_layernorm.bias.value = v
            elif "attention.self.query.weight" in k:
                qkv_weight[idx][0] = v
            elif "attention.self.query.bias" in k:
                qkv_bias[idx][0] = v
            elif "attention.self.key.weight" in k:
                qkv_weight[idx][1] = v
            elif "attention.self.key.bias" in k:
                qkv_bias[idx][1] = v
            elif "attention.self.value.weight" in k:
                qkv_weight[idx][2] = v
            elif "attention.self.value.bias" in k:
                qkv_bias[idx][2] = v

    for i in range(num_layers):
        if any(part is None for part in qkv_weight[i] + qkv_bias[i]):
            raise ValueError(
                f"layer {i}: query/key/value weight or bias missing from checkpoint"
            )
        # Shard q, k and v separately and fuse the shards, so each rank gets
        # [q_shard; k_shard; v_shard]. Splitting the already-fused [q; k; v]
        # matrix would give rank 0 all of q plus part of k when
        # tensor_parallel > 1. For tensor_parallel == 1 the result is the plain
        # concatenation, as before.
        # NOTE(review): assumes TRT-LLM divides each rank's fused QKV output
        # into three equal q/k/v chunks — confirm against BertAttention.
        attention = tensorrt_llm_model.layers[i].attention
        attention.qkv.weight.value = np.concatenate(
            [split(w, tensor_parallel, rank) for w in qkv_weight[i]]
        )
        attention.qkv.bias.value = np.concatenate(
            [split(b, tensor_parallel, rank) for b in qkv_bias[i]]
        )


def main():
    """Build a TRT-LLM BERT engine from a Hugging Face checkpoint.

    Parses options, loads ``args.model`` with transformers and copies its
    weights into a ``tensorrt_llm.models.BertModel``. It then traces the
    network with dynamic ``batch_size``/``input_len`` dimensions, builds the
    engine, and writes it plus ``config.json`` into ``args.output_dir``.

    Raises:
        ValueError: if both ``--enable_context_fmha`` and
            ``--enable_context_fmha_fp32_acc`` are given.
        RuntimeError: if the engine build fails.
    """
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)
    # Validate before the slow model load; an explicit raise (not assert) so
    # the check also runs under ``python -O``.
    if args.enable_context_fmha and args.enable_context_fmha_fp32_acc:
        raise ValueError(
            "--enable_context_fmha and --enable_context_fmha_fp32_acc are mutually exclusive"
        )
    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(args.output_dir, exist_ok=True)

    # Presumably [min, opt, max] of the optimization profile for each dynamic
    # dim — TODO confirm against tensorrt_llm.Tensor(dim_range=...).
    bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
    inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
    torch_dtype = torch.float16 if args.dtype == "float16" else torch.float32
    trt_dtype = trt.float16 if args.dtype == "float16" else trt.float32

    builder = Builder()
    builder_config = builder.create_builder_config(
        name=args.model,
        precision=args.dtype,
        timing_cache=args.timing_cache,
        profiling_verbosity=args.profiling_verbosity,
        tensor_parallel=args.world_size,  # TP only
        max_batch_size=args.max_batch_size,
        max_input_len=args.max_input_len,
    )

    # The HF model is only used as a weight source, and load_from_hf_model
    # casts every tensor and copies it to CPU numpy itself. Keeping it on the
    # CPU avoids holding a second copy of the model in GPU memory during the
    # engine build.
    hf_model = AutoModel.from_pretrained(args.model, torch_dtype=torch_dtype).eval()

    output_name = "hidden_states"

    tensorrt_llm_bert = tensorrt_llm.models.BertModel(
        num_layers=hf_model.config.num_hidden_layers,
        num_heads=hf_model.config.num_attention_heads,
        hidden_size=hf_model.config.hidden_size,
        vocab_size=hf_model.config.vocab_size,
        hidden_act=hf_model.config.hidden_act,
        max_position_embeddings=hf_model.config.max_position_embeddings,
        type_vocab_size=hf_model.config.type_vocab_size,
        pad_token_id=hf_model.config.pad_token_id,
        is_roberta=False,
        mapping=Mapping(
            world_size=args.world_size, rank=args.rank, tp_size=args.world_size
        ),  # TP only
        dtype=trt_dtype,
    )
    load_from_hf_model(
        tensorrt_llm_bert,
        hf_model,
        hf_model.config,
        rank=args.rank,
        tensor_parallel=args.world_size,
        fp16=(args.dtype == "float16"),
    )

    network = builder.create_network()
    network.plugin_config.to_legacy_setting()
    if args.use_bert_attention_plugin:
        network.plugin_config.set_bert_attention_plugin(
            dtype=args.use_bert_attention_plugin
        )
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    if args.enable_qk_half_accum:
        network.plugin_config.enable_qk_half_accum()
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled_with_fp32_acc)
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)
    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_bert.named_parameters())

        # Forward
        input_ids = tensorrt_llm.Tensor(
            name="input_ids",
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict(
                [("batch_size", [bs_range]), ("input_len", [inlen_range])]
            ),
        )

        # also called segment_ids
        token_type_ids = tensorrt_llm.Tensor(
            name="token_type_ids",
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict(
                [("batch_size", [bs_range]), ("input_len", [inlen_range])]
            ),
        )

        input_lengths = tensorrt_llm.Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [bs_range])]),
        )

        output = tensorrt_llm_bert(
            input_ids=input_ids,
            input_lengths=input_lengths,
            token_type_ids=token_type_ids,
        )

        # The output uses the same precision as the model weights.
        output.mark_output(output_name, trt_dtype)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if engine is None:
        raise RuntimeError("Failed to build engine.")
    engine_file = os.path.join(
        args.output_dir,
        get_engine_name(args.model, args.dtype, args.world_size, args.rank),
    )
    with open(engine_file, "wb") as f:
        f.write(engine)
    builder.save_config(builder_config, os.path.join(args.output_dir, "config.json"))


# Build the engine only when executed as a script; the run script imports
# get_engine_name from this module and must not trigger a build.
if __name__ == "__main__":
    main()
  2. Run the build script with this command:
# Build the fp16 engine for BAAI/bge-base-en-v1.5 into ./trtllm_bge
# (no trailing backslash on the last option, so the command ends there).
python3 build.py \
    --dtype float16 \
    --max_batch_size 512 \
    --max_input_len 512 \
    --gpus_per_node 1 \
    --output_dir trtllm_bge \
    --model BAAI/bge-base-en-v1.5 \
    --use_bert_attention_plugin float16 \
    --use_gemm_plugin float16 \
    --enable_context_fmha
  3. Create a run script to compare the TRT-LLM engine with the Hugging Face model (run.py):
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time

import tensorrt as trt
import tensorrt_llm
import torch
from build import get_engine_name
from datasets import concatenate_datasets, load_dataset
from tensorrt_llm import logger
from tensorrt_llm.runtime import Session, TensorInfo
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


def trt_dtype_to_torch(dtype):
    """Map a TensorRT data type to the matching torch dtype.

    Only float16, float32 and int32 are handled.

    Raises:
        TypeError: for any other TensorRT data type.
    """
    known_pairs = (
        (trt.float16, torch.float16),
        (trt.float32, torch.float32),
        (trt.int32, torch.int32),
    )
    for trt_type, torch_type in known_pairs:
        if dtype == trt_type:
            return torch_type
    raise TypeError("%s is not supported" % dtype)


def parse_arguments():
    """Parse command-line options for the engine-vs-Hugging-Face comparison.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--log_level", dict(type=str, default="info")),
        ("--engine_dir", dict(type=str, required=True)),
        ("--dataset", dict(type=str, default="shreyasharma/sentences_truth")),
        ("--batch_size", dict(type=int, default=256)),
        ("--run_hf", dict(action="store_true")),
        ("--run_trtllm", dict(action="store_true")),
        ("--remove_columns", dict(default="labels", type=str)),
        ("--target_column", dict(default="sentences", type=str)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def main():
    """Embed a dataset with the TRT-LLM engine and/or the HF model and compare.

    Reads ``config.json`` from ``--engine_dir`` and loads the matching engine.
    It then tokenizes ``--target_column`` of every split of ``--dataset`` in
    batches and runs the engine (``--run_trtllm``) and/or the Hugging Face
    model (``--run_hf``), CLS-pooling the first output. When both run, it
    prints the mean cosine distance per batch. Finally it prints the mean
    per-batch wall time (seconds) of each backend that ran.
    """
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)

    config_path = os.path.join(args.engine_dir, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    dtype = config["builder_config"]["precision"]
    world_size = config["builder_config"]["tensor_parallel"]
    assert (
        world_size == tensorrt_llm.mpi_world_size()
    ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"

    model_name = config["builder_config"]["name"]
    runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0

    runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank, tp_size=world_size)
    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

    serialize_path = get_engine_name(model_name, dtype, world_size, runtime_rank)
    serialize_path = os.path.join(args.engine_dir, serialize_path)

    stream = torch.cuda.current_stream().cuda_stream
    logger.info(f"Loading engine from {serialize_path}")
    with open(serialize_path, "rb") as f:
        engine_buffer = f.read()
    logger.info(f"Creating session from engine")
    session = Session.from_serialized_engine(engine_buffer)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if args.run_hf:
        # Run the reference at the engine's precision so the comparison is not
        # skewed by a float16-vs-float32 mismatch.
        hf_dtype = torch.float16 if dtype == "float16" else torch.float32
        hf_model = (
            AutoModel.from_pretrained(
                model_name,
                low_cpu_mem_usage=True,
                torch_dtype=hf_dtype,
            )
            .eval()
            .cuda()
        )

    dataset = load_dataset(args.dataset)
    dataset = concatenate_datasets(list(dataset.values())).remove_columns(
        args.remove_columns
    )

    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    total_time_trtllm = 0
    total_trtllm = 0

    total_time_hf = 0
    total_hf = 0

    output_name = "hidden_states"
    for batch in tqdm(dataloader, unit_scale=args.batch_size, unit=" samples"):
        encoded_input = tokenizer(
            batch[args.target_column],
            padding="longest",
            truncation=True,
            return_tensors="pt",
        )

        hf_input_ids = encoded_input["input_ids"].cuda()
        hf_token_type_ids = encoded_input["token_type_ids"].cuda()
        attention_mask = encoded_input["attention_mask"].cuda()

        # The engine inputs are declared INT32 (build script and the TensorInfo
        # list below), while the tokenizer returns int64 tensors. session.run
        # receives these tensors directly, so cast them to match.
        input_ids = hf_input_ids.to(torch.int32)
        token_type_ids = hf_token_type_ids.to(torch.int32)
        # Real (unpadded) length of each sequence. Taking seq.shape[-1] of the
        # padded batch gave every row the padded length, so padding tokens
        # were treated as real input. NOTE(review): assumes right padding
        # (the BERT tokenizer default) — confirm for other tokenizers.
        input_lengths = attention_mask.sum(dim=1, dtype=torch.int32)

        inputs = {
            "input_ids": input_ids,
            "input_lengths": input_lengths,
            "token_type_ids": token_type_ids,
        }

        output_info = session.infer_shapes(
            [
                TensorInfo("input_ids", trt.DataType.INT32, input_ids.shape),
                TensorInfo("input_lengths", trt.DataType.INT32, input_lengths.shape),
                TensorInfo("token_type_ids", trt.DataType.INT32, token_type_ids.shape),
            ]
        )

        outputs = {
            t.name: torch.empty(
                tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda"
            )
            for t in output_info
        }

        assert (
            output_name in outputs
        ), f"{output_name} not found in outputs, check if build.py set the name correctly"

        if args.run_trtllm:
            start = time.time()
            ok = session.run(inputs, outputs, stream)
            # session.run enqueues work on `stream`; wait for the GPU before
            # stopping the clock, otherwise only the launch is timed.
            torch.cuda.synchronize()
            total_time_trtllm += time.time() - start
            total_trtllm += 1
            assert ok, "Runtime execution failed"
            res = outputs[output_name]
            trtllm_embeddings = res[:, 0, :]  # perform cls pooling

        if args.run_hf:
            with torch.inference_mode():
                start = time.time()
                # Pass the mask so the reference ignores padding, matching the
                # engine's input_lengths.
                model_output = hf_model(
                    input_ids=hf_input_ids,
                    token_type_ids=hf_token_type_ids,
                    attention_mask=attention_mask,
                )
                torch.cuda.synchronize()
                total_time_hf += time.time() - start
                total_hf += 1
                hf_embeddings = model_output[0][:, 0, :]  # perform cls pooling

        if args.run_hf and args.run_trtllm:
            # Compare in float32 so the two outputs always share a dtype.
            similarity = cosine_similarity(
                trtllm_embeddings.float(), hf_embeddings.float()
            )
            print(f"Average Cosine Distance: {(1 - similarity.mean().item()):.2f}")

    # Guard against an empty dataset (no batches -> division by zero).
    if args.run_hf and total_hf:
        print(total_time_hf / total_hf)
    if args.run_trtllm and total_trtllm:
        print(total_time_trtllm / total_trtllm)


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
  4. Run this command to launch the run script:
# Compare engine and Hugging Face embeddings on the med_paragraph_simplification
# queries (no trailing backslash on the last option, so the command ends there).
python3 run.py \
    --engine_dir trtllm_bge/ \
    --batch_size 512 \
    --dataset lighteval/med_paragraph_simplification \
    --remove_columns answer \
    --target_column query \
    --run_hf \
    --run_trtllm
  5. Observe that the measured cosine distance is not close to 0.

Expected behavior

When the Hugging Face model and the TRT-LLM engine are run on the same inputs, the cosine distance between their embeddings is expected to be close to 0.

actual behavior

The measured cosine distance between the Hugging Face and TRT-LLM outputs is approximately 0.5, not close to 0. Random high-dimensional vectors are nearly orthogonal, so their cosine distance concentrates near 1. A distance of 0.5 therefore means the embeddings are far more similar than chance (see the distribution below). It is still puzzling that the distance between two implementations of the same model is not much closer to 0.

cosine_distance

additional notes

I suspect some kind of numerical error in the BERT model. I have tried the model in float32 and float16, with the various plugins both enabled and disabled. I also compared an uninitialized TRT-LLM model against the initialized Hugging Face model, and in that case the cosine distance is higher.

Broyojo avatar Feb 27 '24 23:02 Broyojo

I am currently having the same issue.

timmytwoteeth avatar Mar 01 '24 23:03 timmytwoteeth

I am facing the same issue with a fine-tuned BERT model for sequence classification: the logits are wrong.

parikshitsaikia1619 avatar Mar 03 '24 19:03 parikshitsaikia1619

I am encountering the same issue with BertForSequenceClassification. When I debug the pooler output, the difference looks like a floating-point precision problem. Does anyone have suggestions for further debugging? @byshiue @symphonylyh

huggingface(pytorch) pooler output:
pooled_output=tensor([[ 8.7109e-01,  8.5156e-01, -6.5918e-01, -9.4678e-01, -9.9756e-01,
          6.0010e-01, -8.6914e-01, -9.8438e-01, -9.9951e-01,  2.9297e-01,
         -5.8301e-01, -9.4971e-01,  9.9414e-01, -8.8818e-01, -9.9268e-01,
         -8.5791e-01, -8.3594e-01,  9.1602e-01,  9.9170e-01,  9.7998e-01,
         -2.5708e-01,  9.9805e-01,  9.7070e-01,  9.7754e-01, -9.3945e-01,
          5.1660e-01, -9.9902e-01, -9.9854e-01,  7.8125e-01, -3.3081e-01,
         -7.6953e-01, -9.6631e-01, -5.7373e-01, -9.6729e-01,  8.8135e-01,
          9.8926e-01, -9.8193e-01,  9.1748e-01, -8.1982e-01,  8.4778e-02,
         -3.1104e-01,  9.6289e-01,  6.6602e-01,  9.9756e-01, -6.9629e-01,
          9.5508e-01,  9.6826e-01,  9.9707e-01, -9.3555e-01, -9.9561e-01,
          9.5654e-01,  9.5117e-01, -9.9414e-01,  7.2607e-01, -9.9365e-01,
          6.7139e-01, -7.5146e-01, -9.6094e-01, -9.9023e-01,  7.2510e-01,
          3.6987e-02,  9.8389e-01,  4.5386e-01,  4.1797e-01, -5.1318e-01,
         -6.9336e-01, -9.0430e-01, -9.8096e-01,  8.4668e-01, -9.3896e-01,
          6.7480e-01,  9.4092e-01, -9.0527e-01,  4.3506e-01, -7.7637e-01,
          4.8682e-01,  4.2651e-01,  8.8086e-01, -2.6587e-01, -9.9170e-01,
         -9.9316e-01, -9.9609e-01,  9.9365e-01,  1.4819e-01,  9.8926e-01,
          8.3057e-01, -9.9609e-01,  5.6250e-01, -9.6533e-01,  8.0225e-01,
          9.9365e-01,  9.7900e-01,  9.7998e-01, -3.0225e-01,  7.6221e-01,
          9.8779e-01, -9.9023e-01, -4.3384e-01,  9.9805e-01,  9.0381e-01,
          7.4463e-01, -1.9946e-01,  1.3721e-01,  8.5791e-01,  9.1455e-01,
         -9.7705e-01,  9.9707e-01,  4.0039e-01, -3.3447e-02, -9.8779e-01,
          8.7891e-01, -6.2061e-01,  1.0000e+00, -9.7949e-01, -9.9512e-01,
         -3.2544e-01, -4.5972e-01,  9.6143e-01, -4.7778e-01,  1.0000e+00,
          5.9912e-01,  4.6704e-01, -2.3779e-01, -9.9902e-01,  9.5996e-01,
         -9.7607e-01, -8.7793e-01, -3.2886e-01, -9.9951e-01, -9.9902e-01,
          5.2979e-01, -7.2998e-01, -7.7051e-01,  9.9316e-01, -9.5312e-01,
         -9.9023e-01, -2.7759e-01,  7.6416e-01, -9.6289e-01, -9.9902e-01,
          2.6440e-01,  9.9854e-01,  9.8193e-01,  2.1033e-01, -9.6973e-01,
          9.6191e-01, -9.5020e-01, -5.9082e-01,  7.3193e-01,  9.8926e-01,
          3.9276e-02, -6.1523e-01,  4.3213e-01,  1.9250e-01, -4.0454e-01,
         -9.8340e-01, -8.7744e-01,  4.1870e-01, -9.1943e-01,  9.9707e-01,
          9.1309e-01, -7.5098e-01, -6.7627e-01,  6.0303e-01,  9.6533e-01,
         -8.7695e-01, -9.8389e-01, -9.9512e-01, -9.7705e-01, -9.4629e-01,
         -1.6797e-01, -9.7412e-01,  8.2031e-01,  8.8916e-01,  7.0898e-01,
          9.0332e-01,  5.8594e-01,  3.3521e-01,  9.9561e-01, -8.2471e-01,
          6.8604e-01, -4.4189e-01,  6.0059e-01,  9.7314e-01,  3.5205e-01,
          8.7695e-01,  8.6377e-01,  3.6438e-02, -3.2300e-01, -9.0137e-01,
          2.9077e-01, -3.6987e-01, -6.9238e-01,  8.0225e-01, -3.7793e-01,
         -9.5850e-01, -6.2256e-01,  9.6436e-01,  9.7461e-01,  9.7363e-01,
          8.8184e-01, -2.3877e-01,  9.8193e-01,  8.6523e-01, -8.2617e-01,
          3.3179e-01, -1.0000e+00,  9.7949e-01,  9.9805e-01,  9.2285e-01,
          9.6777e-01,  8.4570e-01,  9.8828e-01, -8.3447e-01,  9.9902e-01,
          9.8438e-01,  7.5293e-01, -9.9707e-01, -9.9902e-01,  7.4658e-01,
         -9.9658e-01,  9.4824e-01, -9.9951e-01,  5.3857e-01,  9.8828e-01,
         -9.9658e-01, -4.8120e-01, -6.7627e-01, -9.9854e-01,  6.0400e-01,
          8.8525e-01,  1.7242e-02,  6.3965e-01,  9.1748e-01, -8.8965e-01,
         -7.4219e-01,  9.3750e-01, -9.9121e-01,  9.9561e-01,  3.5767e-02,
          9.3555e-01, -9.8535e-01, -9.2920e-01, -4.6338e-01,  9.2334e-01,
          7.6562e-01, -9.2871e-01,  9.9268e-01, -5.9814e-01, -8.9551e-01,
          9.0576e-01, -9.9365e-01, -9.9707e-01, -9.7559e-01, -9.3115e-01,
         -9.7510e-01, -9.9902e-01, -8.5596e-01, -3.9233e-01,  5.8154e-01,
         -3.6377e-01,  3.8916e-01,  9.5752e-01, -7.6465e-01,  7.7734e-01,
         -6.8066e-01, -6.9775e-01, -4.7729e-01, -9.4189e-01, -9.7852e-01,
          9.7412e-01, -2.7271e-01,  1.0754e-01,  9.1650e-01, -7.3877e-01,
         -9.0430e-01, -5.8887e-01,  8.2422e-01,  9.9951e-01,  4.7388e-01,
          8.1689e-01, -6.4819e-02,  7.4316e-01,  9.9512e-01,  6.9580e-01,
          5.5713e-01, -9.6289e-01,  9.6387e-01, -7.8662e-01,  9.9805e-01,
         -1.4307e-01,  9.8926e-01,  8.7988e-01,  8.6914e-01, -7.8857e-01,
          4.1431e-01,  9.9365e-01,  6.7773e-01, -6.3379e-01, -9.6729e-01,
          9.9902e-01, -9.6826e-01,  6.7285e-01, -9.5996e-01,  9.1504e-01,
          7.5635e-01, -9.6826e-01, -3.7817e-01, -8.3789e-01,  8.2324e-01,
         -9.9951e-01,  6.2451e-01, -9.2383e-01, -9.7656e-01, -9.4873e-01,
          9.7705e-01, -8.4814e-01, -9.9951e-01, -9.5459e-01, -8.9062e-01,
          9.9805e-01, -9.9414e-01, -7.2083e-02,  9.9023e-01, -8.7646e-01,
          7.7197e-01, -1.0931e-01, -9.0527e-01, -9.9902e-01, -7.4414e-01,
         -8.0371e-01,  9.5654e-01, -9.9951e-01,  3.3472e-01, -9.9023e-01,
         -8.7061e-01, -6.3184e-01,  4.5972e-01, -1.0000e+00, -6.6797e-01,
          9.2041e-01, -9.9854e-01, -9.9902e-01, -9.2334e-01, -9.3604e-01,
          9.9609e-01,  9.3701e-01,  9.9854e-01, -1.5649e-01, -9.5215e-01,
          9.9902e-01,  8.3984e-01, -9.9951e-01,  9.9951e-01, -9.9365e-01,
          6.0150e-02, -9.6387e-01,  9.6533e-01, -9.9609e-01, -9.8975e-01,
         -9.2188e-01, -8.6572e-01, -2.1912e-01, -1.4404e-01, -7.1582e-01,
          1.0000e+00, -3.2959e-01,  9.4385e-01,  9.8779e-01,  9.8340e-01,
          7.3242e-01,  9.8926e-01, -7.3682e-01, -9.9951e-01, -9.8730e-01,
          8.5010e-01, -9.9316e-01,  9.3799e-01, -7.1729e-01, -9.6240e-01,
         -3.8623e-01,  8.9941e-01,  9.9805e-01,  9.9414e-01,  9.9121e-01,
         -9.8535e-01,  6.0596e-01, -9.9707e-01,  4.4434e-01, -2.9663e-01,
         -9.0820e-01, -2.1484e-01, -9.8535e-01, -9.8535e-01,  5.3009e-02,
         -8.5010e-01,  9.2627e-01, -9.9414e-01, -8.7354e-01,  9.7473e-02,
          9.5557e-01, -9.9609e-01,  1.0000e+00,  4.2261e-01,  9.2871e-01,
          9.2383e-01,  7.5439e-01, -6.9385e-01, -9.4385e-01,  4.3018e-01,
         -9.1553e-01, -9.3408e-01,  6.5332e-01, -9.4629e-01, -9.2822e-01,
          8.1006e-01,  8.6328e-01, -3.2080e-01,  6.9971e-01, -6.7285e-01,
          4.7803e-01, -1.3293e-01, -9.6924e-01, -1.2842e-01,  9.6240e-01,
         -9.1846e-01, -9.8730e-01, -9.8877e-01, -9.9121e-01,  9.9902e-01,
         -6.2207e-01, -7.5195e-01, -8.9258e-01,  9.1553e-01,  9.8828e-01,
         -9.6680e-01,  9.9902e-01,  9.4043e-01,  8.6670e-01, -9.9902e-01,
         -4.9634e-01,  7.4902e-01, -6.8750e-01, -6.4392e-02, -9.5215e-01,
         -9.9805e-01,  1.3708e-01, -4.8267e-01,  8.5205e-01,  9.1357e-01,
         -6.0449e-01,  9.0137e-01,  2.6465e-01,  5.1221e-01, -9.9268e-01,
          4.5044e-01, -9.4482e-01,  9.5654e-01, -9.1162e-01,  9.5312e-01,
         -9.4043e-01,  9.7949e-01,  8.3691e-01,  9.6338e-01, -5.0146e-01,
         -6.4502e-01, -2.7832e-01,  9.7803e-01, -9.3750e-01,  9.7559e-01,
         -1.9638e-02,  9.5557e-01, -9.8340e-01,  9.9463e-01, -9.3018e-01,
         -3.9380e-01,  2.6514e-01, -8.4400e-04, -7.6221e-01,  5.2100e-01,
          2.7612e-01, -8.8184e-01, -4.9292e-01, -2.0923e-01, -9.7607e-01,
          9.9756e-01,  9.7461e-01,  4.5801e-01,  9.4336e-01,  9.9219e-01,
          2.1497e-01,  9.4116e-02, -3.9551e-01, -1.3989e-01, -8.9844e-01,
         -5.2197e-01,  4.0259e-01, -8.3496e-01, -8.0859e-01, -8.6572e-01,
         -8.1934e-01, -4.1895e-01, -3.8892e-01, -3.3838e-01, -9.2627e-01,
         -2.1216e-01, -9.6973e-01,  2.0593e-01,  9.3750e-01,  3.2788e-01,
         -2.3346e-02, -1.0000e+00, -7.3096e-01, -5.5542e-02,  9.8145e-01,
          9.0088e-01, -8.8623e-01,  9.8340e-01,  9.6729e-01, -7.9590e-01,
         -4.6753e-01,  9.9902e-01, -5.0928e-01,  8.5400e-01, -9.5361e-01,
          7.2705e-01,  9.4043e-01, -9.5117e-01,  7.0215e-01, -9.1895e-01,
         -9.9658e-01,  9.7754e-01,  9.9707e-01, -9.9170e-01,  4.2505e-01,
         -9.8926e-01, -3.0664e-01,  5.9473e-01,  7.7197e-01, -9.6484e-01,
          3.9551e-01, -9.8340e-01, -5.8740e-01,  9.9902e-01,  9.5020e-01,
          9.9170e-01,  9.7168e-01,  9.7852e-01,  8.5742e-01,  6.1963e-01,
          6.0742e-01,  1.8396e-01, -8.8818e-01,  9.9902e-01, -8.9453e-01,
         -9.9365e-01,  9.4629e-01,  7.2852e-01, -9.5996e-01,  2.6587e-01,
          9.6045e-01, -8.3057e-01,  9.4922e-01,  9.9658e-01, -1.0000e+00,
          1.0000e+00,  6.3086e-01,  6.3770e-01, -5.7526e-02,  8.8037e-01,
          2.4841e-01,  9.8145e-01,  8.2373e-01,  9.4336e-01, -9.9805e-01,
          9.6826e-01, -8.6963e-01,  9.1211e-01,  8.5547e-01,  9.9756e-01,
          9.7705e-01, -4.4751e-01,  8.0859e-01,  9.3652e-01,  9.0625e-01,
         -8.2031e-01, -5.5786e-02, -8.3057e-01, -9.9902e-01,  9.8828e-01,
         -7.9736e-01, -5.0098e-01, -1.0000e+00,  3.8208e-01,  9.8877e-01,
         -8.3838e-01, -9.9561e-01, -9.9951e-01, -7.6660e-01, -9.6094e-01,
          7.8223e-01,  9.0430e-01,  7.6758e-01, -9.4336e-01,  9.8779e-01,
          9.7754e-01,  5.6885e-01,  5.7471e-01,  9.6533e-01,  9.9512e-01,
         -7.6660e-01, -9.8926e-01, -9.4580e-01, -9.9951e-01, -6.5186e-01,
         -1.4417e-01, -9.2090e-01,  9.4775e-01, -5.8203e-01,  9.9365e-01,
          8.3740e-01,  7.1094e-01,  6.0791e-01,  9.0088e-01, -9.9170e-01,
         -4.5874e-01, -9.7852e-01, -2.2327e-01,  9.9268e-01, -2.5488e-01,
          9.7559e-01,  2.5415e-01,  9.1016e-01,  9.0088e-01, -9.9756e-01,
         -3.6914e-01,  8.7500e-01, -9.9170e-01, -9.7998e-01, -9.8975e-01,
         -8.3301e-01,  9.2432e-01, -4.0381e-01,  1.0000e+00,  6.9775e-01,
         -6.8115e-01,  9.8096e-01, -9.3652e-01,  9.2139e-01,  1.6858e-01,
         -9.9365e-01, -9.9072e-01, -6.9580e-01, -9.9854e-01,  7.6807e-01,
          9.9512e-01, -9.5557e-01, -9.9951e-01,  8.7012e-01, -9.1553e-01,
          8.5889e-01,  9.1455e-01, -9.9561e-01, -3.9673e-01,  9.5947e-02,
          9.8389e-01, -8.8281e-01,  9.8535e-01,  8.7842e-01, -7.6709e-01,
          1.0000e+00, -8.0859e-01, -9.9707e-01,  7.5830e-01, -9.7461e-01,
         -9.9854e-01,  7.5073e-02, -7.7881e-01,  9.9756e-01, -9.9414e-01,
         -8.2373e-01, -6.9189e-01, -9.6924e-01, -9.9512e-01, -5.2539e-01,
          5.4138e-02, -9.9658e-01,  9.6924e-01, -9.8926e-01,  4.6899e-01,
         -4.9121e-01, -3.7476e-01, -9.9707e-01, -9.3506e-01, -9.9951e-01,
         -1.8005e-01,  8.3887e-01, -6.4844e-01,  1.2917e-02,  8.7500e-01,
          7.5830e-01, -8.8135e-01, -8.1445e-01, -8.8330e-01,  9.1699e-01,
         -2.6929e-01,  9.6338e-01,  5.8057e-01,  6.8604e-01, -9.9756e-01,
         -9.3359e-01, -5.0439e-01, -5.9668e-01, -1.0000e+00, -7.4854e-01,
         -9.1846e-01,  9.0186e-01,  9.2871e-01, -8.5938e-01, -9.5166e-01,
         -9.8926e-01, -9.7754e-01,  9.9951e-01, -9.9951e-01, -9.9023e-01,
         -8.5083e-02, -9.9707e-01, -9.9707e-01, -2.3206e-01,  6.0547e-01,
          9.9756e-01, -9.8291e-01, -1.7609e-02, -9.5410e-01,  9.6484e-01,
         -9.8975e-01, -8.0811e-01, -8.7695e-01,  3.1226e-01,  8.6230e-01,
         -9.7754e-01,  8.4766e-01, -9.9902e-01, -9.3408e-01, -9.0820e-01,
         -9.5068e-01, -4.0088e-01,  9.9951e-01, -9.9854e-01, -7.4219e-01,
         -9.3945e-01, -8.6621e-01, -1.2549e-01, -9.9707e-01, -9.7803e-01,
         -1.0000e+00, -9.8535e-01,  6.7041e-01, -9.9805e-01, -9.6973e-01,
          8.7061e-01, -1.0000e+00,  1.9214e-01, -8.9404e-01, -9.1748e-01,
          8.4863e-01, -9.8730e-01, -8.6279e-01]], dtype=torch.float16)
trt pooler output:
tensor([[ 0.8721,  0.8516, -0.6592, -0.9468, -0.9980,  0.5962, -0.8687, -0.9844,
         -1.0000,  0.2915, -0.5830, -0.9507,  0.9941, -0.8882, -0.9922, -0.8574,
         -0.8335,  0.9155,  0.9922,  0.9800, -0.2546,  0.9980,  0.9707,  0.9771,
         -0.9395,  0.5225, -0.9995, -0.9985,  0.7822, -0.3308, -0.7681, -0.9658,
         -0.5767, -0.9673,  0.8809,  0.9888, -0.9819,  0.9175, -0.8188,  0.0870,
         -0.3081,  0.9629,  0.6665,  0.9976, -0.6968,  0.9546,  0.9673,  0.9966,
         -0.9355, -0.9956,  0.9575,  0.9517, -0.9941,  0.7256, -0.9941,  0.6709,
         -0.7534, -0.9619, -0.9902,  0.7241,  0.0383,  0.9839,  0.4536,  0.4172,
         -0.5083, -0.6934, -0.9033, -0.9805,  0.8481, -0.9380,  0.6753,  0.9414,
         -0.9058,  0.4341, -0.7769,  0.4851,  0.4202,  0.8809, -0.2686, -0.9922,
         -0.9927, -0.9961,  0.9937,  0.1519,  0.9888,  0.8315, -0.9961,  0.5635,
         -0.9653,  0.8027,  0.9937,  0.9785,  0.9800, -0.3008,  0.7607,  0.9883,
         -0.9902, -0.4282,  0.9980,  0.9048,  0.7446, -0.2002,  0.1418,  0.8584,
          0.9141, -0.9771,  0.9966,  0.3979, -0.0364, -0.9878,  0.8794, -0.6177,
          1.0000, -0.9790, -0.9946, -0.3279, -0.4536,  0.9619, -0.4797,  1.0000,
          0.5986,  0.4734, -0.2367, -0.9995,  0.9600, -0.9766, -0.8789, -0.3315,
         -0.9995, -0.9995,  0.5293, -0.7285, -0.7715,  0.9937, -0.9536, -0.9902,
         -0.2771,  0.7627, -0.9629, -0.9985,  0.2561,  0.9985,  0.9819,  0.2083,
         -0.9692,  0.9619, -0.9507, -0.5913,  0.7295,  0.9888,  0.0452, -0.6172,
          0.4304,  0.1924, -0.4055, -0.9829, -0.8770,  0.4163, -0.9209,  0.9966,
          0.9121, -0.7505, -0.6733,  0.6035,  0.9653, -0.8770, -0.9839, -0.9946,
         -0.9771, -0.9468, -0.1676, -0.9751,  0.8188,  0.8882,  0.7095,  0.9033,
          0.5874,  0.3372,  0.9956, -0.8252,  0.6841, -0.4419,  0.6001,  0.9731,
          0.3528,  0.8770,  0.8633,  0.0338, -0.3247, -0.9014,  0.2927, -0.3694,
         -0.6938,  0.8013, -0.3708, -0.9585, -0.6216,  0.9648,  0.9751,  0.9736,
          0.8823, -0.2338,  0.9819,  0.8662, -0.8271,  0.3333, -1.0000,  0.9790,
          0.9980,  0.9224,  0.9678,  0.8452,  0.9883, -0.8350,  0.9995,  0.9844,
          0.7563, -0.9966, -0.9985,  0.7476, -0.9961,  0.9482, -0.9995,  0.5366,
          0.9883, -0.9966, -0.4829, -0.6758, -0.9985,  0.6064,  0.8843,  0.0119,
          0.6382,  0.9170, -0.8911, -0.7402,  0.9370, -0.9907,  0.9956,  0.0325,
          0.9355, -0.9849, -0.9287, -0.4646,  0.9229,  0.7656, -0.9282,  0.9922,
         -0.6016, -0.8965,  0.9053, -0.9941, -0.9976, -0.9756, -0.9307, -0.9751,
         -0.9995, -0.8569, -0.3904,  0.5825, -0.3635,  0.3857,  0.9575, -0.7642,
          0.7793, -0.6816, -0.6924, -0.4766, -0.9419, -0.9785,  0.9751, -0.2751,
          0.1073,  0.9160, -0.7388, -0.9033, -0.5898,  0.8232,  1.0000,  0.4717,
          0.8188, -0.0664,  0.7432,  0.9946,  0.6997,  0.5562, -0.9629,  0.9634,
         -0.7881,  0.9980, -0.1428,  0.9897,  0.8804,  0.8687, -0.7876,  0.4143,
          0.9941,  0.6748, -0.6333, -0.9673,  0.9995, -0.9673,  0.6753, -0.9600,
          0.9160,  0.7568, -0.9678, -0.3818, -0.8398,  0.8223, -0.9995,  0.6279,
         -0.9233, -0.9771, -0.9487,  0.9771, -0.8481, -0.9995, -0.9541, -0.8911,
          0.9980, -0.9941, -0.0793,  0.9902, -0.8760,  0.7744, -0.1143, -0.9053,
         -0.9985, -0.7446, -0.8032,  0.9575, -1.0000,  0.3369, -0.9902, -0.8706,
         -0.6304,  0.4546, -1.0000, -0.6650,  0.9194, -0.9985, -0.9995, -0.9233,
         -0.9360,  0.9961,  0.9370,  0.9980, -0.1613, -0.9521,  0.9985,  0.8398,
         -0.9995,  0.9995, -0.9937,  0.0573, -0.9634,  0.9653, -0.9961, -0.9897,
         -0.9219, -0.8667, -0.2177, -0.1460, -0.7173,  1.0000, -0.3320,  0.9429,
          0.9878,  0.9839,  0.7310,  0.9888, -0.7354, -1.0000, -0.9878,  0.8521,
         -0.9937,  0.9375, -0.7183, -0.9629, -0.3796,  0.8994,  0.9980,  0.9941,
          0.9917, -0.9849,  0.6064, -0.9966,  0.4434, -0.2942, -0.9087, -0.2111,
         -0.9849, -0.9849,  0.0546, -0.8501,  0.9263, -0.9941, -0.8735,  0.0949,
          0.9561, -0.9961,  1.0000,  0.4231,  0.9282,  0.9248,  0.7534, -0.6909,
         -0.9434,  0.4331, -0.9146, -0.9341,  0.6519, -0.9468, -0.9282,  0.8105,
          0.8633, -0.3210,  0.6997, -0.6699,  0.4756, -0.1262, -0.9692, -0.1296,
          0.9624, -0.9194, -0.9868, -0.9883, -0.9907,  0.9985, -0.6230, -0.7505,
         -0.8921,  0.9160,  0.9883, -0.9658,  0.9995,  0.9409,  0.8672, -0.9985,
         -0.4983,  0.7495, -0.6870, -0.0656, -0.9521, -0.9980,  0.1368, -0.4819,
          0.8516,  0.9126, -0.6035,  0.9009,  0.2666,  0.5112, -0.9927,  0.4546,
         -0.9448,  0.9565, -0.9121,  0.9521, -0.9409,  0.9785,  0.8384,  0.9634,
         -0.4983, -0.6465, -0.2756,  0.9785, -0.9360,  0.9756, -0.0225,  0.9556,
         -0.9829,  0.9941, -0.9302, -0.3904,  0.2634, -0.0047, -0.7627,  0.5195,
          0.2766, -0.8809, -0.4961, -0.2092, -0.9756,  0.9980,  0.9751,  0.4597,
          0.9424,  0.9922,  0.2122,  0.0908, -0.3999, -0.1343, -0.8984, -0.5225,
          0.3979, -0.8335, -0.8091, -0.8652, -0.8203, -0.4172, -0.3828, -0.3403,
         -0.9268, -0.2136, -0.9697,  0.2062,  0.9360,  0.3284, -0.0222, -1.0000,
         -0.7310, -0.0500,  0.9810,  0.9019, -0.8857,  0.9829,  0.9673, -0.7954,
         -0.4648,  0.9995, -0.5083,  0.8545, -0.9536,  0.7280,  0.9409, -0.9512,
          0.7026, -0.9175, -0.9961,  0.9775,  0.9976, -0.9917,  0.4272, -0.9888,
         -0.3137,  0.5938,  0.7715, -0.9648,  0.3923, -0.9829, -0.5874,  0.9995,
          0.9502,  0.9917,  0.9712,  0.9785,  0.8574,  0.6187,  0.6089,  0.1808,
         -0.8877,  0.9985, -0.8950, -0.9937,  0.9468,  0.7271, -0.9600,  0.2649,
          0.9604, -0.8301,  0.9487,  0.9966, -1.0000,  1.0000,  0.6318,  0.6372,
         -0.0603,  0.8809,  0.2457,  0.9810,  0.8237,  0.9434, -0.9980,  0.9688,
         -0.8687,  0.9121,  0.8564,  0.9980,  0.9771, -0.4453,  0.8096,  0.9355,
          0.9067, -0.8203, -0.0590, -0.8315, -0.9995,  0.9883, -0.7979, -0.5039,
         -1.0000,  0.3772,  0.9888, -0.8379, -0.9961, -1.0000, -0.7642, -0.9614,
          0.7822,  0.9048,  0.7686, -0.9429,  0.9883,  0.9775,  0.5645,  0.5757,
          0.9653,  0.9956, -0.7642, -0.9897, -0.9463, -0.9995, -0.6543, -0.1444,
         -0.9209,  0.9482, -0.5830,  0.9937,  0.8364,  0.7104,  0.6074,  0.9004,
         -0.9917, -0.4556, -0.9785, -0.2238,  0.9927, -0.2542,  0.9756,  0.2576,
          0.9106,  0.9004, -0.9980, -0.3662,  0.8735, -0.9922, -0.9800, -0.9902,
         -0.8330,  0.9229, -0.4023,  1.0000,  0.7012, -0.6826,  0.9805, -0.9355,
          0.9214,  0.1700, -0.9937, -0.9902, -0.6968, -0.9980,  0.7681,  0.9956,
         -0.9556, -1.0000,  0.8701, -0.9160,  0.8584,  0.9146, -0.9961, -0.3940,
          0.0917,  0.9839, -0.8843,  0.9849,  0.8784, -0.7676,  1.0000, -0.8091,
         -0.9976,  0.7627, -0.9751, -0.9980,  0.0776, -0.7764,  0.9980, -0.9941,
         -0.8237, -0.6914, -0.9692, -0.9956, -0.5239,  0.0559, -0.9966,  0.9692,
         -0.9888,  0.4661, -0.4932, -0.3745, -0.9966, -0.9355, -0.9995, -0.1835,
          0.8389, -0.6475,  0.0099,  0.8755,  0.7563, -0.8828, -0.8140, -0.8843,
          0.9175, -0.2693,  0.9634,  0.5767,  0.6890, -0.9976, -0.9341, -0.5059,
         -0.5952, -1.0000, -0.7495, -0.9194,  0.9014,  0.9282, -0.8584, -0.9517,
         -0.9897, -0.9771,  1.0000, -1.0000, -0.9902, -0.0882, -0.9966, -0.9966,
         -0.2330,  0.6064,  0.9976, -0.9824, -0.0184, -0.9536,  0.9653, -0.9902,
         -0.8076, -0.8760,  0.3142,  0.8623, -0.9775,  0.8481, -0.9995, -0.9341,
         -0.9072, -0.9507, -0.3970,  1.0000, -0.9980, -0.7412, -0.9395, -0.8662,
         -0.1250, -0.9976, -0.9780, -1.0000, -0.9849,  0.6733, -0.9980, -0.9707,
          0.8701, -1.0000,  0.1927, -0.8945, -0.9180,  0.8481, -0.9868, -0.8633]],
       device='cuda:0', dtype=torch.float16)

calico-niko avatar Apr 01 '24 07:04 calico-niko

Following up on this again.

timmytwoteeth avatar May 21 '24 02:05 timmytwoteeth

how to solve it? hf_bge_embedding is completely different from trtllm_bge_embedding.

longshuicui avatar Jun 05 '24 06:06 longshuicui

how to solve it? hf_bge_embedding is completely different from trtllm_bge_embedding.

Oops: the TRT-LLM engine inputs are trt.int32, but AutoTokenizer returns torch.long (int64) tensors. Inconsistent input data types lead to inconsistent results.

longshuicui avatar Jun 06 '24 06:06 longshuicui

Same issue here; any update?

Lzhang-hub avatar Oct 30 '24 11:10 Lzhang-hub

As @longshuicui mentioned above, torch.long and torch.int32 do make a difference; see https://github.com/NVIDIA/TensorRT-LLM/blob/f6821ee393be6ec92234f9bb47a4b09f6738050b/examples/enc_dec/run.py#L209

The run.py linked above may also suffer from this. Does this help?

symphonylyh avatar Oct 30 '24 20:10 symphonylyh

@symphonylyh I cast input_ids to int32 with `input_ids = input_ids.to(torch.int32)` and got the same result.

Lzhang-hub avatar Oct 31 '24 02:10 Lzhang-hub

I used torch.randint() instead of the tokenizer to generate input_ids. Even for identical input_ids, the outputs of the TRT-LLM model and the HF model are still different.

import argparse
import json
import os
import time
# isort: off
import torch
import tensorrt as trt
# isort: on
import string
import random

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.runtime import Session, TensorInfo

from build import get_engine_name  # isort:skip

from transformers import AutoTokenizer

def trt_dtype_to_torch(dtype):
    """Map a TensorRT ``DataType`` to the equivalent ``torch.dtype``.

    Used to allocate torch output buffers that match the dtypes reported by
    ``session.infer_shapes``.

    Args:
        dtype: a TensorRT ``DataType`` value (e.g. ``trt.float16``).

    Returns:
        The matching ``torch.dtype``.

    Raises:
        TypeError: if ``dtype`` has no torch equivalent in this mapping.
    """
    mapping = {
        trt.float16: torch.float16,
        trt.float32: torch.float32,
        trt.int32: torch.int32,
    }
    # These dtypes may not be exposed by older TensorRT releases (e.g. bf16 /
    # int64), so register them only when the installed `trt` module has them.
    optional = (
        ('bfloat16', torch.bfloat16),
        ('int8', torch.int8),
        ('int64', torch.int64),
        ('bool', torch.bool),
    )
    for trt_name, torch_dtype in optional:
        trt_dtype = getattr(trt, trt_name, None)
        if trt_dtype is not None:
            mapping[trt_dtype] = torch_dtype
    try:
        return mapping[dtype]
    except KeyError:
        raise TypeError("%s is not supported" % dtype) from None


def parse_arguments():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Options (both strings):
        --log_level: level handed to the TensorRT-LLM logger (default 'info').
        --engine_dir: directory containing config.json and the serialized
            engine file (default 'bert_outputs').
    """
    cli = argparse.ArgumentParser()
    option_defaults = (
        ('--log_level', 'info'),
        ('--engine_dir', 'bert_outputs'),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()


if __name__ == '__main__':
    # Compare a TensorRT-LLM BERT engine against the Hugging Face reference
    # model on the same random token ids, printing the cosine similarity of
    # the position-0 vectors.
    args = parse_arguments()

    tensorrt_llm.logger.set_level(args.log_level)

    config_path = os.path.join(args.engine_dir, 'config.json')
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Explicit checks rather than `assert`, so they still run under `python -O`.
    if config["plugin_config"]["remove_input_padding"]:
        raise ValueError(
            "Please refer to run_remove_input_padding.py for running BERT models with remove_input_padding enabled"
        )

    dtype = config['builder_config']['precision']
    world_size = config['builder_config']['tensor_parallel']
    if world_size != tensorrt_llm.mpi_world_size():
        raise RuntimeError(
            f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
        )

    model_name = config['builder_config']['name']
    # Engine output tensor to read for each supported architecture. Resolved
    # once, up front, so an unknown model fails before the engine is loaded.
    output_names_by_model = {
        'BertModel': 'hidden_states',
        'RobertaModel': 'hidden_states',
        'BertForQuestionAnswering': 'logits',
        'RobertaForQuestionAnswering': 'logits',
        'BertForSequenceClassification': 'logits',
        'RobertaForSequenceClassification': 'logits',
    }
    if model_name not in output_names_by_model:
        raise ValueError(f"Unknown BERT model {model_name}")
    output_name = output_names_by_model[model_name]

    runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0

    runtime_mapping = tensorrt_llm.Mapping(world_size,
                                           runtime_rank,
                                           tp_size=world_size)
    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

    serialize_path = os.path.join(
        args.engine_dir,
        get_engine_name(model_name, dtype, world_size, runtime_rank))

    stream = torch.cuda.current_stream().cuda_stream
    logger.info(f'Loading engine from {serialize_path}')
    with open(serialize_path, 'rb') as f:
        engine_buffer = f.read()
    logger.info('Creating session from engine')
    session = Session.from_serialized_engine(engine_buffer)

    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')

    # Hugging Face reference model, loaded once (not per iteration) with
    # from_pretrained's default dtype. The local path is environment-specific.
    from transformers import AutoModel
    hf_model_dir = "/data1/nfs15/nfs/bigdata/zhanglei/ml/inference/model-demo/hf/BAAI/bge-large-zh-v1.5"
    model = AutoModel.from_pretrained(hf_model_dir)
    model.cuda().eval()

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    batch_size = 1
    seq_len = 20
    # Fixed seed so the random input_ids, and therefore the printed cosine,
    # are reproducible between runs.
    torch.manual_seed(0)

    for _ in range(1):
        # int32 inputs match the INT32 TensorInfo bindings declared below; the
        # same tensors are fed to the HF model so both see identical ids.
        input_ids = torch.randint(0, tokenizer.vocab_size,
                                  (batch_size, seq_len)).int().cuda()
        input_lengths = seq_len * torch.ones(
            (batch_size, ), dtype=torch.int32, device='cuda')
        token_type_ids = torch.zeros((batch_size, seq_len)).int().cuda()

        inputs = {
            'input_ids': input_ids,
            'input_lengths': input_lengths,
            'token_type_ids': token_type_ids
        }
        output_info = session.infer_shapes([
            TensorInfo('input_ids', trt.DataType.INT32, input_ids.shape),
            TensorInfo('input_lengths', trt.DataType.INT32,
                       input_lengths.shape),
            TensorInfo('token_type_ids', trt.DataType.INT32,
                       token_type_ids.shape),
        ])
        # Preallocate one CUDA buffer per engine output, shaped/typed as the
        # engine reports.
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }

        if output_name not in outputs:
            raise KeyError(
                f'{output_name} not found in outputs, check if build.py set the name correctly'
            )

        ok = session.run(inputs, outputs, stream)
        if not ok:
            raise RuntimeError("Runtime execution failed")
        torch.cuda.synchronize()

        # NOTE: slicing position 0 and comparing with HF hidden states is only
        # meaningful for BertModel/RobertaModel ('hidden_states' output).
        trtllm_embeddings = outputs[output_name][:, 0, :]

        # Inference only: no autograd graph is needed for the reference run.
        with torch.no_grad():
            model_output = model(input_ids=input_ids,
                                 token_type_ids=token_type_ids)
        hf_embeddings = model_output[0][:, 0, :]

        # Compare explicitly in fp32, since the engine output is in the
        # engine precision (`dtype`) while the HF reference may not be.
        similarity = cos(hf_embeddings.float(), trtllm_embeddings.float())
        print(f"fp Sentence embeddings cosine: {similarity}")


Lzhang-hub avatar Oct 31 '24 03:10 Lzhang-hub

@Lzhang-hub Is this still an issue for you? If not, we will close it soon.

hello-11 avatar Nov 15 '24 10:11 hello-11