🐛 [Bug] Part of the weights are placed on the CPU during compilation
Bug Description
When compiling BERT (transformers.BertModel), a device mismatch error occurs. It appears to be caused by some weights being moved to the CPU during compilation.
To Reproduce
Steps to reproduce the behavior:
Run this script:
import torch
import torch_tensorrt as torchtrt
from transformers import BertModel

inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")

enabled_precisions = {torch.float}
debug = True
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    immutable_weights=False,
)
Expected behavior
Compilation completes successfully, with all weights remaining on the GPU.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): main branch
- PyTorch Version (e.g. 1.0): nightly
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
Additional context
If you add a device check inside the torchtrt.dynamo.compile function, the set of devices holding the model's weights changes during compilation from
{device(type='cuda', index=0)}
to
{device(type='cuda', index=0), device(type='cpu')}
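For reference, a minimal sketch of the kind of device check used above; param_devices is a hypothetical helper, not part of Torch-TensorRT:

import torch

def param_devices(mod: torch.nn.Module) -> set:
    # Collect every device that the module's parameters and buffers live on.
    devices = {p.device for p in mod.parameters()}
    devices |= {b.device for b in mod.buffers()}
    return devices

# Printing this set at different points inside torchtrt.dynamo.compile
# shows where the CPU copies first appear.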
The trigger seems to be immutable_weights=False, since the default immutable_weights=True compiles fine. In both cases some weights end up on the CPU; with immutable_weights=False the CPU weights raise a device mismatch error, while with immutable_weights=True they go unreported.
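Until the root cause is fixed, one possible workaround (untested beyond the observation above) is to compile with the default immutable_weights=True, at the cost of giving up weight refitting:

# Workaround sketch: the same compile call as in the repro script,
# but without immutable_weights=False, so the default True is used.
trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
)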