mmcv

[Enhance] Remove CUDA part from diff_iou_rotated

Open filaPro opened this issue 2 years ago • 5 comments

Motivation

Fix #1922. Following lilanxiao/Rotated_IoU#39#issuecomment-1146352088, we remove the CUDA and C++ parts of diff_iou_rotated in favour of a pure PyTorch implementation. The op itself becomes slower (see Benchmark), which we expect to matter little in practice, while accuracy in corner cases is much better.

Modification

Replace the CUDA vertex-sorting kernel with a PyTorch implementation. Split the tests into GPU and CPU variants.
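Sorting the candidate vertices of an intersection polygon can be written with ordinary, batched tensor ops. Below is a minimal sketch, not the code from this PR: it assumes each box pair comes with a fixed-size buffer of K candidate vertices plus a boolean validity mask, orders the valid ones counter-clockwise by their angle around the mean valid vertex (torch.atan2 + torch.argsort), and pushes invalid slots to the end. The helper names sort_vertices_ccw and polygon_area and the buffer layout are hypothetical.

```python
import math

import torch


def sort_vertices_ccw(vertices: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: CCW ordering of candidate polygon vertices.

    vertices: (..., K, 2) candidate (x, y) points
    mask:     (..., K) bool, True where the candidate is a real vertex
    Returns (..., K) indices: valid vertices first in CCW order, then padding.
    """
    m = mask.to(vertices.dtype).unsqueeze(-1)             # (..., K, 1)
    count = m.sum(dim=-2).clamp(min=1.0)                  # avoid division by zero
    center = (vertices * m).sum(dim=-2) / count           # mean of valid vertices
    rel = vertices - center.unsqueeze(-2)
    angle = torch.atan2(rel[..., 1], rel[..., 0])         # in [-pi, pi]
    # Give invalid slots a key larger than any real angle so they sort last
    key = torch.where(mask, angle, torch.full_like(angle, 4 * math.pi))
    return torch.argsort(key, dim=-1)


def polygon_area(vertices: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: shoelace area of the masked convex polygon."""
    idx = sort_vertices_ccw(vertices, mask)
    v = torch.gather(vertices, -2, idx.unsqueeze(-1).expand(*idx.shape, 2))
    valid = torch.gather(mask, -1, idx)
    # Replace padding slots with the first vertex so they add zero area
    first = v[..., :1, :].expand_as(v)
    v = torch.where(valid.unsqueeze(-1), v, first)
    x, y = v[..., 0], v[..., 1]
    x_next, y_next = x.roll(-1, dims=-1), y.roll(-1, dims=-1)
    return 0.5 * (x * y_next - x_next * y).sum(dim=-1)
```

Every step except argsort is differentiable with respect to the vertex coordinates, and argsort only chooses an order, so gradients flow back through gather without a custom kernel. That property is what lets a pure PyTorch version run on both CPU and GPU.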

BC-breaking (Optional)

No, the public API is unchanged. However, we may want to move it out of mmcv.ops, since it no longer requires CUDA. mmdet3d does not use it in master; I've checked this PR against my FCAF3D PR mmdetection3d#1547. diff_iou_rotated is also used in mmrotate master. Maybe @ZwwWayne or @Tai-Wang can take a look at the mmdetection3d side, and @zytx121 at mmrotate.

Benchmark

Setup: Ubuntu 18.04.6, NVIDIA driver 470.129.06, NVIDIA GeForce RTX 3090, Docker image pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmcv.ops import diff_iou_rotated_2d


def test():
    # 100 batches of 1000 random rotated boxes (x, y, w, h, angle) each
    np_boxes_2d_1 = np.random.random((100, 1000, 5)).astype(np.float32)
    np_boxes_2d_2 = np.random.random((100, 1000, 5)).astype(np.float32)
    boxes1 = torch.from_numpy(np_boxes_2d_1).cuda()
    boxes2 = torch.from_numpy(np_boxes_2d_2).cuda()
    # Make sure the host-to-device copies are done before timing starts
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    diff_iou_rotated_2d(boxes1, boxes2)
    end.record()
    # Wait for all queued kernels so elapsed_time() is valid
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds


# Drop the first run, which includes warm-up overhead
a = [test() for _ in range(100)][1:]
print(f'mean={np.mean(a):.4f}, std={np.std(a):.4f}')

Time per diff_iou_rotated_2d call in ms (mean and std over 99 runs after one warm-up run):
master  >>> mean=7.1313, std=0.2482
this PR >>> mean=11.1537, std=0.3251
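To put the two reported means in relative terms, here is the arithmetic. It only divides the figures from this benchmark and says nothing about other input sizes, GPUs, or end-to-end training time.

```python
# Mean time per call in ms, copied from the benchmark results above
master_ms = 7.1313
this_pr_ms = 11.1537

ratio = this_pr_ms / master_ms
print(f"this PR / master = {ratio:.2f}x ({(ratio - 1) * 100:.0f}% slower)")
# prints: this PR / master = 1.56x (56% slower)
```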

filaPro, Jun 21 '22

Hi @filaPro, thanks for your contribution. Could you add some benchmark results to the PR description? You can refer to https://github.com/open-mmlab/mmcv/pull/1718#issuecomment-1061531123.

zhouzaida, Jun 21 '22

Hi @zhouzaida, done. I haven't used torch.cuda.Event before, so I hope my script is right. Also, I don't think the decrease in speed matters much here.
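For cross-checking a timing script like the one in the description, the usual recipe is: warm up, put a synchronization barrier on both sides of the timed region, and repeat. Below is a hedged, framework-agnostic sketch of that recipe using only the standard library; the name benchmark and its parameters are hypothetical. With PyTorch one could pass sync=torch.cuda.synchronize. Unlike CUDA events, this measures host wall-clock time, so it also counts Python and kernel-launch overhead.

```python
import statistics
import time
from typing import Callable, Optional, Tuple


def benchmark(fn: Callable[[], object], runs: int = 100, warmup: int = 1,
              sync: Optional[Callable[[], None]] = None) -> Tuple[float, float]:
    """Hypothetical helper: time fn() and return (mean_ms, std_ms).

    sync: optional barrier (e.g. torch.cuda.synchronize) called right before
    starting and right after stopping the clock, so asynchronous work that
    fn() queued is included in the measurement.
    """
    for _ in range(warmup):          # warm-up runs are not timed
        fn()
    times_ms = []
    for _ in range(runs):
        if sync is not None:
            sync()
        t0 = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times_ms.append((time.perf_counter() - t0) * 1000.0)
    # pstdev is the population std, matching np.std's default (ddof=0)
    return statistics.mean(times_ms), statistics.pstdev(times_ms)
```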

filaPro, Jun 21 '22

Only the minor comments need fixing; everything else LGTM.

Tai-Wang, Jul 01 '22

@filaPro Great contribution! By the way, I am curious about the IoU computation used in mmdet3d. I sometimes feel there are corner cases where the BEV IoU cannot be computed precisely, for example two strictly overlapping boxes on KITTI. Do you have any idea about this, and is it related to the problems addressed in this PR?

Tai-Wang, Jul 01 '22

Hi @Tai-Wang. Yes, I think it is a similar problem. The intersection of two rotated boxes is a convex polygon with up to 8 vertices. We first find these vertices and then compute the area, which is simple because the polygon is convex. But when edges of the two boxes are collinear, it is hard to determine the intersection coordinates correctly, and for some reason this numerical instability shows up more often in the CUDA code. Since the PyTorch implementation is more stable, we could consider removing mmcv.ops.box_iou_rotated in favour of mmcv.ops.diff_iou_rotated in the future.
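To make that pipeline concrete, here is a plain-Python sketch: collect candidate vertices (corners of each box lying inside the other, plus edge-edge intersections), sort them by angle around their centroid, and apply the shoelace formula. It is an illustration under stated assumptions, not the mmcv or Rotated_IoU code: boxes are assumed to be (cx, cy, w, h, angle) with the angle in radians, and all function names and the eps tolerance are hypothetical. The tolerance checks are exactly where collinear edges bite: parallel edges are skipped, so a shared edge segment is recovered only through corners that land on the other box's boundary.

```python
import math


def box_corners(box):
    """Corners of a rotated box (cx, cy, w, h, angle), in CCW order."""
    cx, cy, w, h, a = box
    c, s = math.cos(a), math.sin(a)
    offsets = ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2))
    return [(cx + dx * c - dy * s, cy + dx * s + dy * c) for dx, dy in offsets]


def inside(p, poly, eps):
    """True if p is inside or on the boundary of the CCW convex polygon."""
    for i in range(len(poly)):
        ax, ay = poly[i]
        bx, by = poly[(i + 1) % len(poly)]
        # Cross product < 0 means p lies to the right of edge a->b (outside)
        if (bx - ax) * (p[1] - ay) - (by - ay) * (p[0] - ax) < -eps:
            return False
    return True


def seg_intersect(p1, p2, q1, q2, eps):
    """Intersection point of segments p1p2 and q1q2, or None."""
    d1 = (p2[0] - p1[0], p2[1] - p1[1])
    d2 = (q2[0] - q1[0], q2[1] - q1[1])
    den = d1[0] * d2[1] - d1[1] * d2[0]
    if abs(den) < eps:  # parallel or collinear edges are skipped
        return None
    r = (q1[0] - p1[0], q1[1] - p1[1])
    t = (r[0] * d2[1] - r[1] * d2[0]) / den
    u = (r[0] * d1[1] - r[1] * d1[0]) / den
    if -eps <= t <= 1 + eps and -eps <= u <= 1 + eps:
        return (p1[0] + t * d1[0], p1[1] + t * d1[1])
    return None


def intersection_area(box1, box2, eps=1e-9):
    a, b = box_corners(box1), box_corners(box2)
    pts = [p for p in a if inside(p, b, eps)] + [p for p in b if inside(p, a, eps)]
    for i in range(4):
        for j in range(4):
            q = seg_intersect(a[i], a[(i + 1) % 4], b[j], b[(j + 1) % 4], eps)
            if q is not None:
                pts.append(q)
    if len(pts) < 3:
        return 0.0
    # Sort by angle around the centroid; duplicates end up adjacent
    cx = sum(p[0] for p in pts) / len(pts)
    cy = sum(p[1] for p in pts) / len(pts)
    pts.sort(key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
    area = 0.0  # shoelace formula; repeated points contribute zero
    for i in range(len(pts)):
        x1, y1 = pts[i]
        x2, y2 = pts[(i + 1) % len(pts)]
        area += x1 * y2 - x2 * y1
    return abs(area) / 2.0


def iou(box1, box2):
    inter = intersection_area(box1, box2)
    return inter / (box1[2] * box1[3] + box2[2] * box2[3] - inter)
```

For identical boxes every candidate vertex sits exactly on the other box's boundary, so the result hinges on the `eps` comparisons. This sketch runs in float64; with float32 and different operation orders those boundary tests become more fragile, which is consistent with the instability described above.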

filaPro, Jul 06 '22

LGTM. Just need to resolve the conflicts.

Tai-Wang, Oct 19 '22

Hi @filaPro, this PR can be merged once the conflicts are resolved.

zhouzaida, Oct 20 '22

Hi @filaPro, according to https://github.com/open-mmlab/mmcv/issues/2335, there is still a gradient even when the two boxes overlap. Is this expected?
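Assuming "overlap" in the issue means the two boxes coincide exactly, here is why one might expect a zero gradient there, and why autograd can still report a nonzero one. The symbols are introduced only for this note: b is the first box's parameter vector (x, y, w, h, θ) with w, h > 0, and b* is the second box. The last line is a simplified 1D analogue with intervals [a, c₁] and [a, c*] sharing their left end.

```latex
f(b) = \mathrm{IoU}(b, b^\ast), \qquad 0 \le f(b) \le 1, \qquad f(b^\ast) = 1

f \text{ differentiable at } b^\ast
  \;\Longrightarrow\; \nabla_b f(b^\ast) = 0
  \;\Longrightarrow\; \nabla_b \bigl(1 - f(b)\bigr)\big|_{b = b^\ast} = 0

f(c_1) = \frac{\min(c_1, c^\ast) - a}{\max(c_1, c^\ast) - a}, \qquad
\partial^{-}_{c_1} f(c^\ast) = \frac{1}{c^\ast - a}, \qquad
\partial^{+}_{c_1} f(c^\ast) = -\frac{1}{c^\ast - a}
```

Because the one-sided derivatives disagree, IoU has a kink rather than a smooth maximum at exact coincidence: moving an edge outward grows the union, while moving it inward shrinks the intersection. There is then no unique gradient, and autograd returns whatever value the implementation's min/max, clamping, and vertex-selection branches produce. A nonzero gradient there is therefore not necessarily a bug, though whether it is desirable for training is the open question in this comment.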

zytx121, Oct 25 '22