[Bug] Error Exporting RetinaNet for Single Class Case with CrossEntropyLoss in MMDeploy
Checklist
- [X] 1. I have searched related issues but cannot get the expected help.
- [X] 2. I have read the FAQ documentation but cannot get the expected help.
- [X] 3. The bug has not been fixed in the latest version.
Describe the bug
I am facing an issue when exporting RetinaNet from mmdetection (https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/models/retinanet_r50_fpn.py) for a single-class case. An IndexError is raised because the sliced nms_pre_score has a zero-size last dimension.
I changed the classification loss to Cross Entropy (type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), so the effective bbox_head config is as follows:
bbox_head=dict(
    type="RetinaHead",
    num_classes=80,
    in_channels=256,
    stacked_convs=4,
    feat_channels=256,
    anchor_generator=dict(
        type="AnchorGenerator",
        octave_base_scale=4,
        scales_per_octave=3,
        ratios=[0.5, 1.0, 2.0],
        strides=[8, 16, 32, 64, 128],
    ),
    bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]),
    loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
    loss_bbox=dict(type="L1Loss", loss_weight=1.0),
)
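For context, mmdet's AnchorHead constructor (which RetinaHead inherits) derives the number of classification output channels from use_sigmoid. The snippet below is a minimal, framework-free sketch of that rule; `cls_out_channels_for` is a hypothetical helper, not an mmdet API. With use_sigmoid=False, a single-class head ends up with two channels: one foreground class plus background.

```python
def cls_out_channels_for(num_classes: int, loss_cls: dict) -> int:
    """Sketch of the use_sigmoid rule in mmdet's AnchorHead constructor (hypothetical helper)."""
    use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
    # Sigmoid: one independent logit per foreground class.
    # Softmax: one extra background channel, placed after the foreground classes.
    return num_classes if use_sigmoid_cls else num_classes + 1


loss_cls = dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0)
print(cls_out_channels_for(80, loss_cls))  # 81
print(cls_out_channels_for(1, loss_cls))   # 2
```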
To export the model to ONNX, I called the export function from https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/apis/onnx/export.py. As I understand it, before torch.onnx.export is invoked, the model's methods are patched with deployment-specific rewrites; in this case, predict_by_feat() is replaced with base_dense_head__predict_by_feat() (https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L26-L27).
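The idea of swapping a method only while export runs can be illustrated with a toy example. The sketch below uses unittest.mock.patch.object as a stand-in; mmdeploy's FUNCTION_REWRITER has its own registration and context machinery, so this is only an analogy, not its implementation, and ToyHead and toy_predict_by_feat are hypothetical names.

```python
from unittest import mock


class ToyHead:
    def predict_by_feat(self):
        return "original predict_by_feat"


def toy_predict_by_feat(self):
    # Stand-in for a rewritten, export-friendly implementation.
    return "rewritten predict_by_feat"


head = ToyHead()
with mock.patch.object(ToyHead, "predict_by_feat", toy_predict_by_feat):
    # Inside the context (where an exporter would trace the model), calls hit the rewrite.
    print(head.predict_by_feat())  # rewritten predict_by_feat
# Leaving the context restores the original method.
print(head.predict_by_feat())  # original predict_by_feat
```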
After reviewing the code at https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L26-L27, I noticed three places that depend on the use_sigmoid flag configured in CrossEntropyLoss:
- In the RetinaHead constructor (inherited from AnchorHead), which sets cls_out_channels based on use_sigmoid: https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/dense_heads/anchor_head.py#L73-L78
- In base_dense_head, the scores are sliced for the first time. I presume this excludes the background channel (index num_classes): https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L113-L117
- This is the confusing part: there is a second slice when computing max_scores: https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L141-L146
Could you explain the reasoning behind this? It appears that the last foreground class is excluded when computing max_scores. Thank you!
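The effect of the two slices can be reproduced without PyTorch. The sketch below mirrors the shape logic of the linked lines on plain lists, one row of class logits per anchor with the batch dimension dropped, assuming RetinaHead passes no score factors; the function names are hypothetical. Python's built-in max() raises ValueError on an empty row, which corresponds to the IndexError PyTorch raises in the traceback.

```python
import math


def softmax(logits):
    # Numerically stable softmax over one anchor's class logits.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def max_scores_double_slice(cls_logits, use_sigmoid_cls):
    """Mimic the two slices in the rewritten predict_by_feat (hypothetical, batch dim dropped)."""
    if use_sigmoid_cls:
        scores = [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in cls_logits]
    else:
        # First slice: drop the background channel (last column) after softmax.
        scores = [softmax(row)[:-1] for row in cls_logits]
    nms_pre_score = scores  # assumes no score factors, as for RetinaHead
    if use_sigmoid_cls:
        return [max(row) for row in nms_pre_score]
    # Second slice: also drops the last *foreground* class.
    return [max(row[:-1]) for row in nms_pre_score]


# num_classes=3 (+1 background): anchor 0 is confidently class 2,
# but its max score is taken only over classes 0 and 1.
print(max_scores_double_slice([[0.0, 0.0, 5.0, 0.0]], use_sigmoid_cls=False))

# num_classes=1 (+1 background): nothing is left to reduce over.
try:
    max_scores_double_slice([[2.0, 0.0]], use_sigmoid_cls=False)
except ValueError as err:
    print("empty reduction:", err)
```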
Reproduction
from mmdeploy.apis.onnx import export as onnx_export

# model, img, main_file, input_metas, context_info, input_names,
# output_names and dynamic_axes are prepared earlier in my script (not shown).
onnx_export(
    model=model,
    args=img,
    output_path_prefix=str(main_file),
    backend="onnxruntime",
    input_metas=input_metas,
    context_info=context_info,
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
    dynamic_axes=dynamic_axes,
    verbose=False,
    keep_initializers_as_inputs=False,
    optimize=True,
)
Environment
07/02 13:01:39 - mmengine - INFO -
07/02 13:01:39 - mmengine - INFO - **********Environmental information**********
07/02 13:01:44 - mmengine - INFO - sys.platform: linux
07/02 13:01:44 - mmengine - INFO - Python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
07/02 13:01:44 - mmengine - INFO - CUDA available: True
07/02 13:01:44 - mmengine - INFO - MUSA available: False
07/02 13:01:44 - mmengine - INFO - numpy_random_seed: 2147483648
07/02 13:01:44 - mmengine - INFO - GPU 0: NVIDIA
07/02 13:01:44 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
07/02 13:01:44 - mmengine - INFO - NVCC: Cuda compilation tools, release 12.3, V12.3.107
07/02 13:01:44 - mmengine - INFO - GCC: gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
07/02 13:01:44 - mmengine - INFO - PyTorch: 2.0.1
07/02 13:01:44 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 11.8
- NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37
- CuDNN 8.7
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
07/02 13:01:44 - mmengine - INFO - TorchVision: 0.15.2
07/02 13:01:44 - mmengine - INFO - OpenCV: 4.9.0
07/02 13:01:44 - mmengine - INFO - MMEngine: 0.10.3
07/02 13:01:44 - mmengine - INFO - MMCV: 2.0.1
07/02 13:01:44 - mmengine - INFO - MMCV Compiler: GCC 9.3
07/02 13:01:44 - mmengine - INFO - MMCV CUDA Compiler: 11.8
07/02 13:01:44 - mmengine - INFO - MMDeploy: 1.3.1+bc75c9d
07/02 13:01:44 - mmengine - INFO -
07/02 13:01:44 - mmengine - INFO - **********Backend information**********
07/02 13:01:44 - mmengine - INFO - tensorrt: None
07/02 13:01:44 - mmengine - INFO - ONNXRuntime: None
07/02 13:01:44 - mmengine - INFO - pplnn: None
07/02 13:01:44 - mmengine - INFO - ncnn: None
07/02 13:01:45 - mmengine - INFO - snpe: None
07/02 13:01:45 - mmengine - INFO - openvino: None
07/02 13:01:45 - mmengine - INFO - torchscript: 2.0.1
07/02 13:01:45 - mmengine - INFO - torchscript custom ops: NotAvailable
07/02 13:01:45 - mmengine - INFO - rknn-toolkit: None
07/02 13:01:45 - mmengine - INFO - rknn-toolkit2: None
07/02 13:01:45 - mmengine - INFO - ascend: None
07/02 13:01:45 - mmengine - INFO - coreml: None
07/02 13:01:45 - mmengine - INFO - tvm: None
07/02 13:01:45 - mmengine - INFO - vacc: None
07/02 13:01:45 - mmengine - INFO -
07/02 13:01:45 - mmengine - INFO - **********Codebase information**********
07/02 13:01:45 - mmengine - INFO - mmdet: 3.3.0
07/02 13:01:45 - mmengine - INFO - mmseg: None
07/02 13:01:45 - mmengine - INFO - mmpretrain: 1.2.0
07/02 13:01:45 - mmengine - INFO - mmocr: None
07/02 13:01:45 - mmengine - INFO - mmagic: None
07/02 13:01:45 - mmengine - INFO - mmdet3d: None
07/02 13:01:45 - mmengine - INFO - mmpose: None
07/02 13:01:45 - mmengine - INFO - mmrotate: None
07/02 13:01:45 - mmengine - INFO - mmaction: None
07/02 13:01:45 - mmengine - INFO - mmrazor: None
07/02 13:01:45 - mmengine - INFO - mmyolo: None
Error traceback
│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /detectors/single_stage.py:85 in single_stage_detector__forward │
│ │
│ 82 │ # set the metainfo │
│ 83 │ data_samples = _set_metainfo(data_samples, img_shape) │
│ 84 │ │
│ ❱ 85 │ return __forward_impl(self, batch_inputs, data_samples=data_samples) │
│ 86 │
│ │
│ xxx/lib/python3.10/site-packages/mmdeploy/core/optimizers/funct │
│ ion_marker.py:266 in g │
│ │
│ 263 │ │ │ args = mark_tensors(args, func, func_id, 'input', ctx, attrs, │
│ 264 │ │ │ │ │ │ │ │ is_inspect, args_level) │
│ 265 │ │ │ │
│ ❱ 266 │ │ │ rets = f(*args, **kwargs) │
│ 267 │ │ │ │
│ 268 │ │ │ ctx = Context(output_names) │
│ 269 │ │ │ func_ret = mark_tensors(rets, func, func_id, 'output', ctx, attrs, │
│ │
│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /detectors/single_stage.py:23 in __forward_impl │
│ │
│ 20 │ """ │
│ 21 │ x = self.extract_feat(batch_inputs) │
│ 22 │ │
│ ❱ 23 │ output = self.bbox_head.predict(x, data_samples, rescale=False) │
│ 24 │ return output │
│ 25 │
│ 26 │
│ │
│ xxx/lib/python3.10/site-packages/mmdet/models/dense_heads/base_ │
│ dense_head.py:197 in predict │
│ │
│ 194 │ │ │
│ 195 │ │ outs = self(x) │
│ 196 │ │ │
│ ❱ 197 │ │ predictions = self.predict_by_feat( │
│ 198 │ │ │ *outs, batch_img_metas=batch_img_metas, rescale=rescale) │
│ 199 │ │ return predictions │
│ 200 │
│ │
│ xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models │
│ /dense_heads/base_dense_head.py:145 in base_dense_head__predict_by_feat │
│ │
│ 142 │ │ │ if self.use_sigmoid_cls: │
│ 143 │ │ │ │ max_scores, _ = nms_pre_score.max(-1) │
│ 144 │ │ │ else: │
│ ❱ 145 │ │ │ │ max_scores, _ = nms_pre_score[..., :-1].max(-1) │
│ 146 │ │ │ _, topk_inds = max_scores.topk(pre_topk) │
│ 147 │ │ │ bbox_pred, scores, score_factors = gather_topk( │
│ 148 │ │ │ │ bbox_pred, │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: max(): Expected reduction dim 2 to have non-zero size.
/cc @RunningLeon @grimoire
#2534 and https://github.com/open-mmlab/mmdeploy/commit/a51ee2c76caa1c8080e51981dccf829a23907791#diff-2871f924ffd987597f7ca6d4f5227f6d04fc49d3ca246a9bacd4b797e394206fR109
Hi maintainers, good day! Are there any plans to address this? 👀
Thanks for the notification. This is likely a bug introduced by the large refactor in https://github.com/open-mmlab/mmdeploy/pull/1091. Removing either of the slices should work.
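Following the suggestion that removing either slice should work, the sketch below keeps the first slice (dropping the background channel after softmax) and removes the second, so the maximum is taken over every foreground class. It uses the same plain-list simplification as a stand-in for the tensor code; in base_dense_head__predict_by_feat this would presumably amount to using nms_pre_score.max(-1) in both branches, but that mapping is an interpretation of the linked lines, not a merged patch. The function names are hypothetical.

```python
import math


def softmax(logits):
    # Numerically stable softmax over one anchor's class logits.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def max_scores_single_slice(cls_logits, use_sigmoid_cls):
    """Sketch of the proposed fix: slice off background once, then reduce over all columns."""
    if use_sigmoid_cls:
        scores = [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in cls_logits]
    else:
        # Only slice: the background channel is the last column after softmax.
        scores = [softmax(row)[:-1] for row in cls_logits]
    # Same reduction for both branches, mirroring nms_pre_score.max(-1).
    return [max(row) for row in scores]


# Single class now yields one score per anchor instead of an empty reduction.
print(max_scores_single_slice([[2.0, 0.0]], use_sigmoid_cls=False))
# The confident class-2 anchor keeps its high score.
print(max_scores_single_slice([[0.0, 0.0, 5.0, 0.0]], use_sigmoid_cls=False))
```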
@grimoire, thank you for your response. Could you let me know whether a patch release is planned and, if so, when we might expect it? I can create a PR for this patch too if it helps.
> I can create a PR for this patch too if it helps.
Sure, that would be cool.