
🐛 [Bug] Part of the weights are placed to CPU during compilation

cehongwang opened this issue on Mar 25 '25 · 3 comments

Bug Description

When compiling BERT, a device mismatch error occurs. This appears to be caused by some of the weights being moved to the CPU during compilation.

To Reproduce

Steps to reproduce the behavior:

Run this script:

import torch
import torch_tensorrt as torchtrt

from transformers import BertModel

# Token-ID input for BERT: batch size 1, sequence length 14, on the GPU.
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
enabled_precisions = {torch.float}
debug = True
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))

# Compiling with refittable weights (immutable_weights=False) triggers the
# device-mismatch error described above.
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    immutable_weights=False,
)


Expected behavior

Compilation completes successfully, with all weights remaining on the GPU and no device-mismatch error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): main branch
  • PyTorch Version (e.g. 1.0): nightly
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context

cehongwang · Mar 25 '25 05:03


If you add a device check in the torchtrt.dynamo.compile function, the result is:

{device(type='cuda', index=0)}
{device(type='cuda', index=0), device(type='cpu')}

cehongwang · Mar 25 '25 05:03
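
A minimal sketch of such a device check, assuming it is placed inside torchtrt.dynamo.compile where the graph module is available (param_devices is a hypothetical helper, not part of Torch-TensorRT):

import torch

def param_devices(mod: torch.nn.Module) -> set:
    # Collect the set of devices holding the module's parameters and buffers.
    return {t.device for t in list(mod.parameters()) + list(mod.buffers())}

# Printed once before and once after the lowering passes, this yields the two
# sets shown above: {cuda:0} first, then {cuda:0, cpu} once some of the
# weights have been moved to the CPU.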

It seems the culprit is immutable_weights=False, since compiling with immutable_weights=True (the default) works fine.

HolyWu · Mar 25 '25 15:03
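
For reference, the workaround implied by the comment above is simply to leave immutable_weights at its default; a sketch against the repro script (note this gives up weight refittability):

trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    # Leaving immutable_weights at its default of True avoids the error,
    # but the resulting engine can no longer be refit with new weights.
)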

It's because when immutable_weights=False, the weights that end up on the CPU raise an error. When immutable_weights=True, the weights are still moved to the CPU, but no error is raised.

cehongwang · Mar 25 '25 16:03
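
One way to see this from the outside is to check where the exported program's weights live before compiling; a sketch, assuming ExportedProgram exposes its parameters and buffers via .state_dict as in current PyTorch:

cpu_weights = [
    name
    for name, t in exp_program.state_dict.items()
    if isinstance(t, torch.Tensor) and t.device.type == "cpu"
]
# Empty at this point: the weights are on CUDA going into compile, and are
# only moved to the CPU later, during compilation.
print(cpu_weights)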