TensorRT

poor performance of batched matmul for larger batch sizes

Thanduriel opened this issue on Dec 08 '23 · 6 comments

Description

I am currently evaluating machine learning frameworks as tools for GPU-accelerated simulations, and I noticed that TensorRT is orders of magnitude slower than the other backends I tried. Most of this can be attributed to a single operation: batched matrix-matrix multiplication (bmm). Both engine build time and inference time grow steeply with batch size, and for large batches TensorRT inference is slower even than plain PyTorch:

| Batch size | PyTorch [s] | TensorRT [s] | TensorRT compilation [s] |
|---|---|---|---|
| 2^10 | 0.0158 | 0.0001 | 7.10 |
| 2^16 | 0.016 | 0.151 | 38.89 |
| 2^18 | 0.016 | 0.625 | 143.25 |

In my case, the batch dimension is the number of elements in the mesh, which can be much larger than typical deep learning batch sizes, so this use case may be out of scope. However, TorchInductor has given me promising results compared to my handwritten CUDA, so I believe there is potential here.
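Read per batch element, the table shows how the TensorRT cost scales. The sketch below only redoes arithmetic on the numbers reported above (no new measurements; the `timings` dict and the helper names are illustrative): going from 2^16 to 2^18, a 4x larger batch makes TensorRT inference about 4.1x slower and the engine build about 3.7x slower, i.e. roughly 2.3 to 2.4 µs per 3x4 matrix, while the PyTorch column does not change.

```python
# Timings copied from the table above, in seconds.
timings = {
    2**10: {"pytorch": 0.0158, "trt": 0.0001, "build": 7.10},
    2**16: {"pytorch": 0.016, "trt": 0.151, "build": 38.89},
    2**18: {"pytorch": 0.016, "trt": 0.625, "build": 143.25},
}


def per_matrix_us(batch: int, key: str) -> float:
    """Average time per batch element (one 3x4 @ 4x1 product), in microseconds."""
    return timings[batch][key] / batch * 1e6


def growth(key: str, small: int, large: int) -> float:
    """Factor by which a timing grows when the batch grows from `small` to `large`."""
    return timings[large][key] / timings[small][key]


if __name__ == "__main__":
    for batch in sorted(timings):
        exponent = batch.bit_length() - 1  # batch sizes are powers of two
        print(f"batch 2^{exponent}: TensorRT {per_matrix_us(batch, 'trt'):.3f} us per matrix")
    print(
        f"2^16 -> 2^18: TensorRT x{growth('trt', 2**16, 2**18):.2f}, "
        f"build x{growth('build', 2**16, 2**18):.2f}, "
        f"PyTorch x{growth('pytorch', 2**16, 2**18):.2f}"
    )
```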

Environment

TensorRT Version: 8.6.1.post1

NVIDIA GPU: RTX 3090 24GB

NVIDIA Driver Version: 545.23.08

CUDA Version: 12.3

CUDNN Version: 8.9.2.26

Operating System:

Python Version: Python 3.10.12

PyTorch Version: 2.2.0.dev20231207+cu121

Baremetal or Container (if so, version): Baremetal

Relevant Files

This is the code used to generate the results above:

```python
import sys
import time

import tensorrt
import torch


class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, param_map, vecs):
        # Batched matrix multiply: [n, 3, 4] @ [n, 4, 1] -> [n, 3, 1]
        return torch.bmm(param_map, vecs)


def make_engine(model, inputs, path, logger):
    # Export the PyTorch model to ONNX, then let TensorRT parse it and build an engine.
    onnx_path = f"{path}.onnx"
    torch.onnx.export(model, inputs, onnx_path)

    builder = tensorrt.Builder(logger)
    network = builder.create_network(1 << int(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    parser = tensorrt.OnnxParser(network, logger)
    success = parser.parse_from_file(onnx_path)
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))

    if not success:
        print("[Error] Parsing failed.")
        sys.exit(1)

    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("[Error] Engine build failed.")
        sys.exit(1)

    # Save the engine to disk as well (this script does not reload it).
    trt_model_path = f"{path}.trt"
    with open(trt_model_path, "wb") as f:
        f.write(serialized_engine)

    return serialized_engine


if __name__ == '__main__':
    model = TestModel()
    model.eval()

    n = 2**16  # batch size
    param_map = torch.rand([n, 3, 4], device='cuda')
    vecs = torch.rand([n, 4, 1], device='cuda')

    # Note: CUDA launches are asynchronous and this is the first bmm call,
    # so this interval need not reflect steady-state kernel time.
    start_ref = time.perf_counter()
    res = model(param_map, vecs)
    end_ref = time.perf_counter()

    logger = tensorrt.Logger()

    start_compile = time.perf_counter()
    serialized_engine = make_engine(model, (param_map, vecs), "testmodel", logger)
    end_compile = time.perf_counter()

    runtime = tensorrt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()
    res_trt = torch.zeros_like(res)

    # execute_v2 runs synchronously; bindings are raw device pointers
    # in binding order (here: the two inputs, then the output).
    start_trt = time.perf_counter()
    context.execute_v2([param_map.data_ptr(), vecs.data_ptr(), res_trt.data_ptr()])
    end_trt = time.perf_counter()

    # verify result
    print(f"norm(res) = {torch.norm(res).item()}, norm(res-res_trt) = {torch.norm(res - res_trt).item()}")
    # timings
    print(f"t ref     {end_ref - start_ref}s")
    print(f"t trt     {end_trt - start_trt}s")
    print(f"t compile {end_compile - start_compile}s")
```
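CUDA kernel launches are asynchronous, so wrapping `model.forward(...)` in `time.perf_counter()` without a `torch.cuda.synchronize()` may time only the launch (plus one-time setup on the first call), whereas `execute_v2` is synchronous. A fairer comparison warms up first and synchronizes before reading the clock. The `benchmark` helper below is a hypothetical sketch, not part of PyTorch or TensorRT; it takes the workload and a synchronization callback so that it runs with the standard library alone.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              warmup: int = 3,
              iters: int = 10) -> float:
    """Return the median wall-clock time of `fn()` in seconds.

    `sync` is called right before starting and right before stopping the clock,
    so asynchronous work (e.g. queued CUDA kernels) is included in each sample.
    """
    for _ in range(warmup):  # keep one-time setup costs out of the samples
        fn()
    samples = []
    for _ in range(iters):
        sync()  # make sure earlier work is not counted
        start = time.perf_counter()
        fn()
        sync()  # wait until the work launched by fn() has finished
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


if __name__ == "__main__":
    # Stand-in workload; on a GPU, pass torch.cuda.synchronize as `sync`.
    print(f"median: {benchmark(lambda: time.sleep(0.01)):.4f} s")
```

In the script above this would be called as `benchmark(lambda: model(param_map, vecs), torch.cuda.synchronize)` for PyTorch and `benchmark(lambda: context.execute_v2([param_map.data_ptr(), vecs.data_ptr(), res_trt.data_ptr()]), torch.cuda.synchronize)` for TensorRT. Reporting the median rather than the mean keeps a single slow outlier from dominating the result.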

Steps To Reproduce

Run the Python script above; set `n` to 2**10, 2**16 or 2**18 to reproduce the corresponding rows of the table.

Thanduriel, Dec 08 '23 13:12

@nvpohanh any comments? ^ ^

zerollzeng, Dec 12 '23 14:12

What if you add an extra batch dimension, so that the inputs look like 1 x old_batch x len x ...?

zerollzeng, Dec 12 '23 14:12

@zerollzeng Could you file an internal tracker for this if we can repro this?

I think the problem is that each GEMM is too small (only 3x4 times 4x1), which is not a typical GEMM size in deep learning workloads. That's why we have never tried to optimize this case.

That said, we would be interested in seeing whether PyTorch uses a special kernel for this case, and which one, to see if we can integrate it into TensorRT. Thanks
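To put numbers on "too small": each batch element is a 3x4 by 4x1 product, i.e. 12 multiply-adds (24 FLOPs) against 19 float32 values (76 bytes) that must be read or written, so the operation is limited by memory bandwidth rather than compute. The sketch below estimates a lower bound on runtime from that traffic alone. It assumes each input is read once and each output written once, and it uses 936 GB/s, the RTX 3090's nominal peak memory bandwidth; both are assumptions of the estimate, not measurements from this issue.

```python
from typing import Tuple

FLOAT32_BYTES = 4


def bmm_cost(n: int, m: int = 3, k: int = 4, p: int = 1) -> Tuple[int, int]:
    """FLOPs and memory traffic (bytes) for n float32 products (m x k) @ (k x p),
    assuming every input is read once and every output is written once."""
    flops = 2 * n * m * k * p            # one multiply and one add per term
    elems = n * (m * k + k * p + m * p)  # A, B and the result
    return flops, elems * FLOAT32_BYTES


def lower_bound_seconds(n: int, bandwidth_bytes_per_s: float = 936e9) -> float:
    """Time to move the data once at the assumed peak bandwidth."""
    _, nbytes = bmm_cost(n)
    return nbytes / bandwidth_bytes_per_s


if __name__ == "__main__":
    n = 2**18
    flops, nbytes = bmm_cost(n)
    print(f"FLOPs per byte: {flops / nbytes:.2f}")
    print(f"bytes moved: {nbytes / 1e6:.1f} MB")
    print(f"bandwidth lower bound: {lower_bound_seconds(n) * 1e6:.1f} us")
```

At 2^18 this comes to about 19.9 MB and roughly 21 µs. Compared with the 0.625 s reported for TensorRT in the table, the bound is about 30,000 times smaller, which suggests the cost comes from how the tiny GEMMs are executed rather than from the amount of data.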

nvpohanh, Dec 12 '23 14:12


Switching to a standard matmul and adding an extra leading dimension makes no difference to the runtimes; it appears to map to the same operation.

Thanduriel, Dec 13 '23 10:12

There is a hack that should make TRT much faster for this bmm. Instead of torch.bmm(param_map, vecs), do this:

```python
return torch.sum(param_map * vecs.view(n, 1, 4), dim=2)
```
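The hack works because, when the right-hand side is a single column, each matrix-vector product is just an elementwise multiply followed by a sum over k: out[b][i] = sum_k A[b][i][k] * v[b][k]. The broadcast-and-reduce form therefore computes the same numbers as bmm without going through a GEMM kernel. The dependency-free sketch below checks that identity on nested Python lists; `bmm` and `mul_then_sum` are hypothetical names, and the code illustrates the math only, not what TensorRT generates.

```python
from typing import List

Matrix = List[List[float]]


def bmm(a: List[Matrix], b: List[Matrix]) -> List[Matrix]:
    """Reference batched matmul: a is [n][m][k], b is [n][k][p]."""
    return [
        [[sum(ai[i][t] * bi[t][j] for t in range(len(bi)))
          for j in range(len(bi[0]))]
         for i in range(len(ai))]
        for ai, bi in zip(a, b)
    ]


def mul_then_sum(a: List[Matrix], v: List[Matrix]) -> List[Matrix]:
    """The workaround for a k x 1 right-hand side: treat v[b] as a 1 x k row,
    multiply it elementwise into every row of a[b], then sum over k.
    A trailing axis of size 1 is kept so the result matches bmm's shape."""
    out = []
    for ai, vi in zip(a, v):
        row = [x[0] for x in vi]                              # like vecs.view(n, 1, k)
        prod = [[x * y for x, y in zip(r, row)] for r in ai]  # broadcast multiply
        out.append([[sum(r)] for r in prod])                  # reduce over k
    return out


if __name__ == "__main__":
    a = [[[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0], [3.0, 3.0, 3.0, 3.0]]]
    v = [[[1.0], [0.0], [2.0], [1.0]]]
    print(bmm(a, v), mul_then_sum(a, v))
```

In the PyTorch one-liner, `n` is the batch size from the reproduction script (`vecs.view(-1, 1, 4)` would not depend on it), and `torch.sum(..., dim=2)` drops the reduced axis, so the result has shape [n, 3] instead of bmm's [n, 3, 1]. The values and memory layout are the same, so writing into the `res_trt` buffer through a raw pointer would still work; pass `keepdim=True` if the shapes themselves need to agree.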

nvpohanh, Dec 14 '23 02:12

Thanks, this does indeed help. I will include this alternative in my comparison.

Thanduriel, Dec 21 '23 16:12

Closing since there has been no activity for more than 3 weeks, per our policy. Thanks all!

ttyio, May 07 '24 18:05