poor performance of batched matmul for larger batch sizes
Description
I am currently evaluating machine learning frameworks as tools for GPU-accelerated simulations, and I noticed that TensorRT is orders of magnitude slower than the other backends I tried. This can be largely attributed to a single operation, batched matrix-matrix multiply (bmm). Both compilation time and inference time increase massively with batch size, even relative to basic PyTorch:
| batch size | PyTorch [s] | TensorRT inference [s] | TensorRT compilation [s] |
|---|---|---|---|
| 2^10 | 0.0158 | 0.0001 | 7.10 |
| 2^16 | 0.016 | 0.151 | 38.89 |
| 2^18 | 0.016 | 0.625 | 143.25 |
In my case the batch dimension is the number of elements in the mesh, which can be much larger than common batch sizes in deep learning, so this issue may be out of scope. However, I have gotten some promising results with TorchInductor compared to my handwritten CUDA, so I believe this approach has potential.
Environment
TensorRT Version: 8.6.1.post1
NVIDIA GPU: RTX 3090 24GB
NVIDIA Driver Version: 545.23.08
CUDA Version: 12.3
CUDNN Version: 8.9.2.26
Operating System:
Python Version: Python 3.10.12
PyTorch Version: 2.2.0.dev20231207+cu121
Baremetal or Container (if so, version): Baremetal
Relevant Files
This is the code used to generate the results above:
```python
import tensorrt
import torch
import time


class TestModel(torch.nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, param_map, vecs):
        # Batched matrix-matrix multiply: (n, 3, 4) @ (n, 4, 1) -> (n, 3, 1)
        return torch.bmm(param_map, vecs)


def make_engine(model, inputs, path, logger):
    # Export the model to ONNX and parse it into a TensorRT network.
    onnx_path = f"{path}.onnx"
    torch.onnx.export(model, inputs, onnx_path)
    builder = tensorrt.Builder(logger)
    network = builder.create_network(1 << int(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = tensorrt.OnnxParser(network, logger)
    success = parser.parse_from_file(onnx_path)
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))
    if not success:
        print("[Error] Parsing failed.")
        exit()
    # Build and serialize the engine with the default builder configuration.
    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)
    trt_model_path = f"{path}.trt"
    with open(trt_model_path, "wb") as f:
        f.write(serialized_engine)
    return serialized_engine


if __name__ == '__main__':
    model = TestModel()
    model.eval()
    n = 2**16
    param_map = torch.rand([n, 3, 4], device='cuda')
    vecs = torch.rand([n, 4, 1], device='cuda')
    # PyTorch reference timing.
    start_ref = time.perf_counter()
    res = model.forward(param_map, vecs)
    end_ref = time.perf_counter()
    # Build the TensorRT engine and time the build.
    logger = tensorrt.Logger()
    start_compile = time.perf_counter()
    serialized_engine = make_engine(model, (param_map, vecs), "testmodel", logger)
    end_compile = time.perf_counter()
    # Deserialize the engine and run it on the same device buffers.
    runtime = tensorrt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()
    res_trt = torch.zeros_like(res)
    start_trt = time.perf_counter()
    context.execute_v2([param_map.data_ptr(), vecs.data_ptr(), res_trt.data_ptr()])
    end_trt = time.perf_counter()
    # Verify the result against the PyTorch reference.
    print(f"norm(res) = {torch.norm(res).item()}, norm(res-res_trt) = {torch.norm(res-res_trt).item()}")
    # Timings.
    print(f"t ref {end_ref - start_ref}s")
    print(f"t trt {end_trt - start_trt}s")
    print(f"t compile {end_compile - start_compile}s")
```
Steps To Reproduce
Just run the Python script above.
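PyTorch dispatches CUDA kernels asynchronously, so for completeness the reference timing can also be taken with explicit synchronization. A minimal sketch, reusing model, param_map, and vecs from the script above:

```python
# Sketch only: synchronize before reading the clock so the reference timing
# reflects completed GPU work rather than just the kernel launch.
torch.cuda.synchronize()
start_ref = time.perf_counter()
res = model(param_map, vecs)
torch.cuda.synchronize()
end_ref = time.perf_counter()
print(f"t ref (synchronized) {end_ref - start_ref}s")
```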
@nvpohanh any comments? ^ ^
What if you add an extra batch dimension, so the inputs would be like 1 x old_batch x len x ...?
@zerollzeng Could you file an internal tracker for this if we can repro this?
I think the problem is that the GEMM size is too small (only 3x4 by 4x1) and is not a typical GEMM size in deep learning workloads. That's why we have never tried to optimize this case.
That said, we would be interested in seeing what special kernel (if any) PyTorch uses for this case and whether we can integrate it into TensorRT. Thanks
> What if you add an extra batch dimension, so the inputs would be like 1 x old_batch x len x ...?
Switching to a standard matmul and adding an extra dimension (roughly as sketched below) makes no difference to the runtimes. It looks like it maps to the same operation.
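A sketch of what I mean, swapping TestModel.forward for:

```python
# Sketch: the same product expressed as a standard matmul with a leading
# batch dimension of 1, i.e. shapes (1, n, 3, 4) @ (1, n, 4, 1) -> (1, n, 3, 1).
def forward(self, param_map, vecs):
    return torch.matmul(param_map.unsqueeze(0), vecs.unsqueeze(0)).squeeze(0)
```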
There is a hack that should make TRT much faster for this bmm. Instead of torch.bmm(param_map, vecs), do this:

```python
return torch.sum(param_map * vecs.view(n, 1, 4), dim=2)
```
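Note that this returns shape (n, 3) rather than the (n, 3, 1) that bmm produces, so an unsqueeze(-1) may be needed if downstream code expects the trailing dimension. A quick sanity check, as a sketch using n, param_map, and vecs from the script above:

```python
# Sketch: the broadcast-multiply-and-sum form matches torch.bmm up to the
# trailing singleton dimension, which it drops.
ref = torch.bmm(param_map, vecs).squeeze(-1)             # shape (n, 3)
alt = torch.sum(param_map * vecs.view(n, 1, 4), dim=2)   # shape (n, 3)
print(torch.allclose(ref, alt))
```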
Thanks, this does indeed help. I will include this alternative in my comparison.
Closing since there has been no activity for more than 3 weeks, per our policy. Thanks all!