[Bug] TRTBatchedRotatedNMS PlugIn CUDA kernels raise a runtime CUDA error - cudaErrorLaunchOutOfResources on NVIDIA Jetson TX2
Checklist
- [ ] 1. I have searched related issues but cannot get the expected help.
- [X] 2. I have read the FAQ documentation but cannot get the expected help.
- [ ] 3. The bug has not been fixed in the latest version.
Describe the bug
We successfully deployed RtmDetR to the ONNX and TensorRT backends. In addition, we successfully converted the generated ONNX model to a TRT engine ourselves using your open-source TRTBatchedRotatedNMS plugin.
During RtmDetR inference on the NVIDIA Jetson TX2 platform with JetPack 4.6.2 (https://developer.nvidia.com/embedded/jetpack-sdk-462), a CUDA error is raised: cudaErrorLaunchOutOfResources.
After many checks we verified that the root cause is the implementation of the CUDA kernel allClassRotatedNMS_kernel.
For reference only, the same inference runs successfully on the following platform: Quadro RTX 3000 GPU, CUDA 11.7, TensorRT 8.5.3.1, cuDNN 8.9.2.
Any help would be much appreciated.
Regards,
Reproduction
Run the deploy script:
```
python ^
  ./MmLab/mmdeploy/tools/deploy.py ^
  ./MmLab/mmdeploy/configs/mmrotate/rotated-detection_tensorrt-fp16_dynamic-320x320-1024x1024.py ^
  ./MmLab/mmrotate/configs/rotated_rtmdet/rotated_rtmdet_m-3x-dota.py ^
  ./MmLab/playground/checkpoints/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth ^
  ./MmLab/playground/Images/100MEDIA_VOC2012_JPEGImages_DJI_0195__1024__974___974.png ^
  --work-dir d:/ThirdParties/MmLab/playground/mmdeploy_model/RtmDetR_Medium_DeltaTrained_Dynamic ^
  --device cuda ^
  --dump-info
```
We are trying to run inference with the generated TRT engine file on several platforms for benchmarking.
We didn't change any MmLab source code.
Environment
We didn't install MmLab, MmRotate & MmDeploy on the NVIDIA Jetson, only on our PC.
After the ONNX model was generated, we registered the plugin from our own Python/C++ code and successfully created the TRT engine both on our PC and on the NVIDIA Jetson TX2.
On the PC everything works well, but not on the TX2.
Error traceback
A cudaErrorLaunchOutOfResources CUDA error is reported by the allClassRotatedNMS_kernel CUDA kernel.
Jetson devices have limited register resources. Try smaller `pre_top_k`, `keep_top_k` and `max_output_boxes_per_class` in https://github.com/open-mmlab/mmdeploy/blob/main/configs/mmrotate/rotated-detection_static.py.
Hello @grimoire,
Thank you for your suggestion. We tried several value combinations; for example, we reduced them drastically to pre_top_k=5, keep_top_k=5, max_output_boxes_per_class=1.
But for every combination we tested, we always get the same CUDA error described above.
Please advise,
Thanks
What is the value of `t_size` in https://github.com/open-mmlab/mmdeploy/blob/2882c64eea8640f913588f6962e66abf2e7b6c86/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu#L431 ?
Hello @grimoire, the originally generated ONNX and plugin source code produce the value 6 for the t_size variable. During our investigation, without fully understanding the impact, just for debugging, we manually set its value to each option from 1 to 6. Unfortunately, the CUDA error is still reproduced for all of them.
Thank you for your support. Regards,
I see. You can try using a smaller BS, which might reduce the register usage of the kernel. Note that a smaller BS means a larger t_size, which will also enlarge the register usage.
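For context, the dependency between BS, top_k and t_size is roughly the following (a minimal standalone sketch, assuming the usual ceil-division dispatch of the batchedNMS-style launcher; not the plugin's verbatim code):

```cpp
#include <cstdio>

// Minimal sketch: each of the BS threads in a block handles TSIZE candidate boxes,
// so TSIZE = ceil(top_k / BS). A larger TSIZE means larger per-thread local arrays,
// hence more registers per thread; a smaller BS shrinks the block but grows TSIZE.
int main() {
  const int top_k = 3000;  // assumed pre-NMS candidate count (pre_top_k)
  for (int BS : {128, 256, 512}) {
    const int t_size = (top_k + BS - 1) / BS;
    std::printf("BS=%3d  top_k=%d  ->  t_size=%d\n", BS, top_k, t_size);
  }
  return 0;
}
```

With top_k=3000 and BS=512 this gives t_size=6, which matches the value reported above.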
We tried that as well. We manually set its value to 1, and the CUDA error is still reproduced.
Based on the tests we have made so far, I think the implementation of the kernel itself causes this CUDA error. I don't know, but maybe some local variables, the number of input arguments, internal loops, etc. cause this issue.
What do you think?
Regards,
> I think that the implementation of the kernel itself causes this CUDA error.

`cudaErrorLaunchOutOfResources` will be raised when the kernel uses too many registers. Both `BS` and `t_size` are parameters that could limit the usage of registers (smaller `pre_top_k`, `keep_top_k` and `max_output_boxes_per_class` lead to a smaller `t_size`). It works for NMS without rotation, but rotated NMS is far more complicated and uses large register blocks to compute the rotated intersection.
TX2 has 32768 registers per block (vs 65536 on non-edge devices; I guess this is the key to this error).
Since I do not have a TX2 device, it is hard to optimize the kernel for now. You can try optimizing the kernel yourself, for example by fusing functions that use large arrays.
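One way to confirm the register-pressure hypothesis on the TX2 is to query the device limit and the kernel's per-thread register count at runtime (a minimal sketch using the CUDA runtime API; `dummyKernel` here is only a placeholder for the real plugin kernel):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel so the sketch compiles; substitute the instantiated
// allClassRotatedNMS_kernel<...> when checking the real plugin.
__global__ void dummyKernel(float* out) { out[threadIdx.x] = static_cast<float>(threadIdx.x); }

int main() {
  // Device-side limit: TX2 (sm_62) reports 32768 registers per block,
  // while most desktop GPUs report 65536.
  int regs_per_block = 0;
  cudaDeviceGetAttribute(&regs_per_block, cudaDevAttrMaxRegistersPerBlock, /*device=*/0);

  // Kernel-side usage: registers per thread as compiled.
  cudaFuncAttributes attr{};
  cudaFuncGetAttributes(&attr, reinterpret_cast<const void*>(dummyKernel));

  std::printf("device limit: %d regs/block, kernel: %d regs/thread\n", regs_per_block, attr.numRegs);
  // A launch fails with cudaErrorLaunchOutOfResources when
  // regs/thread * blockDim.x exceeds the per-block limit.
  return 0;
}
```

Capping the kernel with `__launch_bounds__(512)` or compiling with `nvcc -maxrregcount=<N>` forces the compiler to spill to local memory, which may let the launch succeed at some speed cost.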
Thank you @grimoire, your support is appreciated.
We tried to flatten the NMS call stack into one function without any real implementation changes, but still got the error.
Finally, we decided to bypass the problem by changing the following code:
```cpp
for (int t = 0; t < TSIZE; t++) {
  const int cur_idx = threadIdx.x + blockDim.x * t;
  const int item_idx = offset + cur_idx;

  if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
    // TODO: may need to add bool normalized as argument, HERE true means
    // normalized
    // if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold) {
    //   kept_bboxinfo_flag[cur_idx] = false;
    // }
  }
}
```
We masked the lines shown commented out above and implemented the NMS suppression externally ourselves. This let us bypass the problem.
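Roughly, the external suppression pass is along these lines (a simplified host-side sketch using OpenCV's rotatedRectangleIntersection; a box layout of (cx, cy, w, h, angle in degrees) plus a score is assumed, and this is not our exact code):

```cpp
#include <algorithm>
#include <numeric>
#include <vector>
#include <opencv2/imgproc.hpp>

// Assumed box representation: center, size, angle in degrees, plus a score.
struct RBox { float cx, cy, w, h, angle, score; };

// IoU of two rotated boxes via OpenCV's polygon intersection.
static float rotatedIoU(const RBox& a, const RBox& b) {
  cv::RotatedRect ra({a.cx, a.cy}, {a.w, a.h}, a.angle);
  cv::RotatedRect rb({b.cx, b.cy}, {b.w, b.h}, b.angle);
  std::vector<cv::Point2f> inter;
  if (cv::rotatedRectangleIntersection(ra, rb, inter) == cv::INTERSECT_NONE || inter.size() < 3)
    return 0.f;
  std::vector<cv::Point2f> hull;
  cv::convexHull(inter, hull);  // order the intersection polygon before measuring its area
  const float inter_area = static_cast<float>(cv::contourArea(hull));
  const float union_area = a.w * a.h + b.w * b.h - inter_area;
  return union_area > 0.f ? inter_area / union_area : 0.f;
}

// Greedy rotated NMS: keep the highest-scoring box, suppress overlapping ones.
std::vector<int> rotatedNMS(const std::vector<RBox>& boxes, float iou_thr) {
  std::vector<int> order(boxes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int i, int j) { return boxes[i].score > boxes[j].score; });
  std::vector<int> keep;
  std::vector<bool> suppressed(boxes.size(), false);
  for (int i : order) {
    if (suppressed[i]) continue;
    keep.push_back(i);
    for (int j : order)
      if (!suppressed[j] && j != i && rotatedIoU(boxes[i], boxes[j]) > iou_thr)
        suppressed[j] = true;
  }
  return keep;
}
```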
If you ever update the plugin, I will be happy to hear about it.
Thanks,