The accuracy on HRSC2016 has been very low during training.

Open makabakasu opened this issue 2 years ago • 8 comments

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

master branch https://github.com/open-mmlab/mmrotate

Environment

sys.platform: win32
Python: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 05:59:00) [MSC v.1929 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1
NVCC: Cuda compilation tools, release 11.1, V11.1.74
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.16.27045 for x64
GCC: n/a
PyTorch: 1.8.0
PyTorch compiling details: PyTorch built with:

  • C++ Version: 199711
  • MSVC 192829337
  • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191125 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  • OpenMP 2019
  • CPU capability usage: AVX2
  • CUDA Runtime 11.1
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  • CuDNN 8.0.5
  • Magma 2.5.4
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=C:/cb/pytorch_1000000000000/work/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -DNDEBUG -DUSE_FBGEMM -DUSE_XNNPACK, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON,

TorchVision: 0.9.0
OpenCV: 4.5.3
MMCV: 1.7.1
MMCV Compiler: MSVC 191627045
MMCV CUDA Compiler: 11.1
MMRotate: 0.3.4+7755aa5

Reproduces the problem - code sample

The official code.

Reproduces the problem - command or script

The official configs.

Reproduces the problem - error message

The accuracy on HRSC2016 has been very low during training.

Additional information

I downloaded the HRSC2016 dataset, extracted the part.01 archive directly, and then added the path in 'hrsc.py'. But the accuracy has been very low during training. Is any preprocessing required for HRSC2016? [image]
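
For reference, a rough sketch of how the 0.x hrsc.py dataset config is usually pointed at the extracted archive. The FullDataSet/ and ImageSets/ folder names match the standard HRSC2016 download, but the exact keys can vary between mmrotate versions, so treat this as an assumption rather than the official file:

# Sketch only: adjust data_root to the folder that contains FullDataSet/ and ImageSets/.
dataset_type = 'HRSCDataset'
data_root = '/path/to/HRSC2016/'

data = dict(
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'ImageSets/trainval.txt',
        ann_subdir=data_root + 'FullDataSet/Annotations/',
        img_subdir=data_root + 'FullDataSet/AllImages/',
        img_prefix=data_root + 'FullDataSet/AllImages/'))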

makabakasu avatar Feb 16 '23 05:02 makabakasu

Hi @makabakasu, please try mmrotate 1.x (https://github.com/open-mmlab/mmrotate/tree/1.x). We no longer maintain 0.x.

zytx121 avatar Feb 24 '23 00:02 zytx121

Did you solve this problem?

sudoAimer avatar Mar 02 '23 09:03 sudoAimer

Did you solve this problem?

Have you solved it? I met the same problem: using the HRSC2016 dataset, the training process was interrupted after 47 epochs (I set it to 140 epochs in total), and the precision is very low (only 0.008).

soHardToHaveAName avatar Apr 13 '23 09:04 soHardToHaveAName

Did you solve this problem?

Have you solved it? I met the same problem: using the HRSC2016 dataset, the training process was interrupted after 47 epochs (I set it to 140 epochs in total), and the precision is very low (only 0.008).

I noticed that my learning rate might have been causing some issues, so I switched from SGD to AdamW on version 0.3.2, and everything is working fine on version 1.x. I hope it is helpful to you.
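
For reference, a minimal sketch of that optimizer swap in an mmrotate 0.x (mmcv-runner style) config; the learning rate and weight decay below are assumptions and usually need tuning for the batch size and number of GPUs:

# Sketch only: replace the SGD optimizer with AdamW in the 0.x config.
optimizer = dict(type='AdamW', lr=1e-4, weight_decay=0.05)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))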

sudoAimer avatar Apr 13 '23 09:04 sudoAimer

Did you solve this problem?

Have you solved it? I met the same problem: using the HRSC2016 dataset, the training process was interrupted after 47 epochs (I set it to 140 epochs in total), and the precision is very low (only 0.008).

I noticed that my learning rate might have been causing some issues, so I switched from SGD to AdamW on version 0.3.2, and everything is working fine on version 1.x. I hope it is helpful to you.

Thanks. I am already using Adam, but the AP is still 0, and the grad_norm keeps growing after a brief decrease.

Also, I found that the input image is not square when using "Pad, size_divisor=32", so I changed it to "Pad, pad_to_square=True". However, nothing changed.
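
For reference, the two padding variants mentioned above, written as the pipeline entries they correspond to (only the Pad entries are shown; the full 1.x config is posted further down in this thread):

# Pads height and width up to the next multiple of 32; the result is not necessarily square.
pad_by_divisor = dict(type='Pad', size_divisor=32)
# Pads the shorter side so the padded image becomes square.
pad_square = dict(type='Pad', pad_to_square=True)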

soHardToHaveAName avatar Apr 14 '23 12:04 soHardToHaveAName

Did you solve this problem?

Have you solved it? I met the same problem: using the HRSC2016 dataset, the training process was interrupted after 47 epochs (I set it to 140 epochs in total), and the precision is very low (only 0.008).

I noticed that my learning rate might have been causing some issues, so I switched from SGD to AdamW on version 0.3.2, and everything is working fine on version 1.x. I hope it is helpful to you.

Thanks. I am already using Adam, but the AP is still 0, and the grad_norm keeps growing after a brief decrease.

Also, I found that the input image is not square when using "Pad, size_divisor=32", so I changed it to "Pad, pad_to_square=True". However, nothing changed.

Have you tried using mmrotate 1.x?

sudoAimer avatar Apr 14 '23 14:04 sudoAimer

Did you solve this problem?

Have you solved it? I met the same problem: using the HRSC2016 dataset, the training process was interrupted after 47 epochs (I set it to 140 epochs in total), and the precision is very low (only 0.008).

I noticed that my learning rate might have been causing some issues, so I switched from SGD to AdamW on version 0.3.2, and everything is working fine on version 1.x. I hope it is helpful to you.

Thanks. I am already using Adam, but the AP is still 0, and the grad_norm keeps growing after a brief decrease. Also, I found that the input image is not square when using "Pad, size_divisor=32", so I changed it to "Pad, pad_to_square=True". However, nothing changed.

Have you tried using mmrotate 1.x?

Hi! I tried mmrotate 1.x today, but that still didn't solve the problem.

The model I'm using is Rotated CenterNet (I modified CenterNet into a rotated variant).

Here are the config file and model file:

config file:

default_scope = 'mmrotate'
angle_version = 'le90'

model = dict(
    # type='RCenterNet',
    type='mmdet.CenterNet',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='mmdet.ResNet',
        depth=18,
        norm_eval=False,
        norm_cfg=dict(type='BN'),
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
    neck=dict(
        type='mmdet.CTResNetNeck',
        in_channels=512,
        num_deconv_filters=(256, 128, 64),
        num_deconv_kernels=(4, 4, 4),
        use_dcn=True),
    bbox_head=dict(
        type='RCenterNetHead',
        # num_classes=15,
        num_classes=1,
        in_channels=64,
        feat_channels=64,
        loss_center_heatmap=dict(type='mmdet.GaussianFocalLoss', loss_weight=1.0),
        loss_wh=dict(type='mmdet.L1Loss', loss_weight=0.1),
        loss_offset=dict(type='mmdet.L1Loss', loss_weight=1.0),
        loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.1)),
    train_cfg=None,
    # test_cfg=dict(topk=100, local_maximum_kernel=3, max_per_img=100))
    test_cfg=dict(
        show_points=False,
        topk=200, 
        local_maximum_kernel=3, 
        max_per_img=200,
        score_thr=0.3,
        nms=None))
        # nms=dict(iou_thr=0.2)))



#========================================== dataset ==========================================
# dataset settings
dataset_type = 'HRSCDataset'
data_root = '/DATA/HRSC2016_dataset/HRSC2016/'
# dataset_type = 'DOTADataset'
# data_root = 'data/split_ss_dota/'
backend_args = None

train_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    # dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
    # dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
    dict(type='Pad', pad_to_square=True),
    dict(
        type='mmdet.RandomFlip',
        prob=0.75,
        direction=['horizontal', 'vertical', 'diagonal']),
    dict(type='mmdet.PackDetInputs')
]
val_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    # dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
    dict(type='Pad', pad_to_square=True),
    # avoid bboxes being resized
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
test_pipeline = [
    dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
    # dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
    dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
    dict(type='Pad', pad_to_square=True),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    # batch_size=2,
    # num_workers=2,
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='ImageSets/trainval.txt',
        data_prefix=dict(sub_data_root='FullDataSet/'),
        filter_cfg=dict(filter_empty_gt=True),
        pipeline=train_pipeline))
val_dataloader = dict(
    # batch_size=1,
    # num_workers=2,
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='ImageSets/test.txt',
        data_prefix=dict(sub_data_root='FullDataSet/'),
        test_mode=True,
        pipeline=val_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='DOTAMetric', metric='mAP')
test_evaluator = val_evaluator












#========================================== runtime ==========================================

default_scope = 'mmrotate'

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    # logger=dict(type='LoggerHook', interval=50),
    logger=dict(type='LoggerHook', interval=10),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='mmdet.DetVisualizationHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [
    dict(type='LocalVisBackend'), 
    dict(type='TensorboardVisBackend')]
visualizer = dict(
    type='RotLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False







#========================================== training schedule ==========================================

# optimizer
# Based on the default settings of modern detectors, the SGD effect is better
# than the Adam in the source code, so we use SGD default settings and
# if you use adam+lr5e-4, the map is 29.1.
# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
# optimizer = dict(type='Adam', lr=5e-4, weight_decay=0.0001) # 8 GPU * 16 img
optim_wrapper = dict(
    type='OptimWrapper',
    # optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
    optimizer=dict(type='Adam', lr=5e-4 / 8, weight_decay=0.0001),  # 1 GPU * 16 img
    clip_grad=dict(max_norm=35, norm_type=2))
# optim_wrapper = dict(clip_grad=dict(max_norm=35, norm_type=2))

# max_epochs = 28
max_epochs = 140
# learning policy
# Based on the default settings of modern detectors, we added warmup settings.

# learning rate
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=1000),
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        # milestones=[18, 24],  # the real step is [18*5, 24*5]
        milestones=[90, 120],
        gamma=0.1)
]

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
# train_cfg = dict(max_epochs=max_epochs)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')


model file:

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmrotate.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, OptMultiConfig)
from mmdet.models.utils import (gaussian_radius, gen_gaussian_target, get_local_maximum,
                     get_topk_from_heatmap, multi_apply,
                     transpose_and_gather_feat)
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.dense_heads import CenterNetHead

from mmrotate.structures.bbox import RotatedBoxes, qbox2rbox

@MODELS.register_module()
class RCenterNetHead(CenterNetHead):
    """Objects as Points Head. CenterHead use center_point to indicate object's
    position. Paper link <https://arxiv.org/abs/1904.07850>

    Args:
        in_channels (int): Number of channel in the input feature map.
        feat_channels (int): Number of channel in the intermediate feature map.
        num_classes (int): Number of categories excluding the background
            category.
        loss_center_heatmap (:obj:`ConfigDict` or dict): Config of center
            heatmap loss. Defaults to
            dict(type='GaussianFocalLoss', loss_weight=1.0)
        loss_wh (:obj:`ConfigDict` or dict): Config of wh loss. Defaults to
             dict(type='L1Loss', loss_weight=0.1).
        loss_offset (:obj:`ConfigDict` or dict): Config of offset loss.
            Defaults to dict(type='L1Loss', loss_weight=1.0).
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Useless in CenterNet, but we keep this variable for
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
            of CenterNet.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization
            config dict.
    """

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 num_classes: int,
                 loss_center_heatmap: ConfigType = dict(
                     type='mmdet.GaussianFocalLoss',
                     loss_weight=1.0),
                 loss_wh: ConfigType = dict(type='mmdet.L1Loss', loss_weight=0.1),
                 loss_offset: ConfigType = dict(
                     type='mmdet.L1Loss', loss_weight=1.0),
                 loss_angle: ConfigType = dict(type='mmdet.L1Loss', loss_weight=0.1),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            init_cfg=init_cfg,
            in_channels=in_channels,
            feat_channels=feat_channels,
            num_classes=num_classes,
            loss_center_heatmap=loss_center_heatmap,
            loss_wh=loss_wh,
            loss_offset=loss_offset,
            )
        self.num_classes = num_classes
        self.heatmap_head = self._build_head(in_channels, feat_channels,
                                             num_classes)
        self.wh_head = self._build_head(in_channels, feat_channels, 2)
        self.offset_head = self._build_head(in_channels, feat_channels, 2)
        self.angle_head = self._build_head(in_channels, feat_channels, 1)

        self.loss_center_heatmap = MODELS.build(loss_center_heatmap)
        self.loss_wh = MODELS.build(loss_wh)
        self.loss_offset = MODELS.build(loss_offset)
        self.loss_angle = MODELS.build(loss_angle)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False

    def _build_head(self, in_channels: int, feat_channels: int,
                    out_channels: int) -> nn.Sequential:
        """Build head for each branch."""
        layer = nn.Sequential(
            nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_channels, out_channels, kernel_size=1))
        return layer

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        bias_init = bias_init_with_prob(0.1)
        self.heatmap_head[-1].bias.data.fill_(bias_init)
        for head in [self.wh_head, self.offset_head, self.angle_head]:
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)

    def forward(self, x: Tuple[Tensor, ...]) -> Tuple[List[Tensor]]:
        """Forward features. Notice CenterNet head does not use FPN.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            center_heatmap_preds (list[Tensor]): center predict heatmaps for
                all levels, the channels number is num_classes.
            wh_preds (list[Tensor]): wh predicts for all levels, the channels
                number is 2.
            offset_preds (list[Tensor]): offset predicts for all levels, the
               channels number is 2.
            angle_preds (list[Tensor]): angle predicts for all levels, the
               channels number is 1.
        """
        return multi_apply(self.forward_single, x)

    def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward feature of a single level.

        Args:
            x (Tensor): Feature of a single level.

        Returns:
            center_heatmap_pred (Tensor): center predict heatmaps, the
               channels number is num_classes.
            wh_pred (Tensor): wh predicts, the channels number is 2.
            offset_pred (Tensor): offset predicts, the channels number is 2.
            angle_pred (Tensor): angle predicts, the channels number is 1.
        """
        center_heatmap_pred = self.heatmap_head(x).sigmoid()
        wh_pred = self.wh_head(x)
        offset_pred = self.offset_head(x)
        angle_pred = self.angle_head(x)
        return center_heatmap_pred, wh_pred, offset_pred, angle_pred

    def loss_by_feat(
            self,
            center_heatmap_preds: List[Tensor],
            wh_preds: List[Tensor],
            offset_preds: List[Tensor],
            angle_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            center_heatmap_preds (list[Tensor]): center predict heatmaps for
               all levels with shape (B, num_classes, H, W).
            wh_preds (list[Tensor]): wh predicts for all levels with
               shape (B, 2, H, W).
            offset_preds (list[Tensor]): offset predicts for all levels
               with shape (B, 2, H, W).
            angle_preds (list[Tensor]): angle predicts for all levels
               with shape (B, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: which has components below:
                - loss_center_heatmap (Tensor): loss of center heatmap.
                - loss_wh (Tensor): loss of hw heatmap
                - loss_offset (Tensor): loss of offset heatmap.
        """
        assert len(center_heatmap_preds) == len(wh_preds) == len(
            offset_preds) == 1
        center_heatmap_pred = center_heatmap_preds[0]
        wh_pred = wh_preds[0]
        offset_pred = offset_preds[0]
        angle_pred = angle_preds[0]

        gt_bboxes = [
            gt_instances.bboxes for gt_instances in batch_gt_instances
        ]
        gt_labels = [
            gt_instances.labels for gt_instances in batch_gt_instances
        ]
        img_shape = batch_img_metas[0]['batch_input_shape']
        target_result, avg_factor = self.get_targets(gt_bboxes, gt_labels,
                                                     center_heatmap_pred.shape,
                                                     img_shape)

        center_heatmap_target = target_result['center_heatmap_target']
        wh_target = target_result['wh_target']
        offset_target = target_result['offset_target']
        angle_target = target_result['angle_target']
        wh_offset_target_weight = target_result['wh_offset_target_weight']

        # Since the channel of wh_target and offset_target is 2, the avg_factor
        # of loss_center_heatmap is always 1/2 of loss_wh and loss_offset.
        loss_center_heatmap = self.loss_center_heatmap(
            center_heatmap_pred, center_heatmap_target, avg_factor=avg_factor)
        loss_wh = self.loss_wh(
            wh_pred,
            wh_target,
            wh_offset_target_weight,
            avg_factor=avg_factor * 2)
        loss_offset = self.loss_offset(
            offset_pred,
            offset_target,
            wh_offset_target_weight,
            avg_factor=avg_factor * 2)
        loss_angle = self.loss_angle(
            angle_pred,
            angle_target,
            wh_offset_target_weight,
            avg_factor=avg_factor)
        return dict(
            loss_center_heatmap=loss_center_heatmap,
            loss_wh=loss_wh,
            loss_offset=loss_offset,
            loss_angle=loss_angle)

    def get_targets(self, gt_bboxes: List[Tensor], gt_labels: List[Tensor],
                    feat_shape: tuple, img_shape: tuple) -> Tuple[dict, int]:
        """Compute regression and classification targets in multiple images.

        Args:
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            feat_shape (tuple): feature map shape with value [B, _, H, W]
            img_shape (tuple): image shape.

        Returns:
            tuple[dict, float]: The float value is mean avg_factor, the dict
            has components below:
               - center_heatmap_target (Tensor): targets of center heatmap, \
                   shape (B, num_classes, H, W).
               - wh_target (Tensor): targets of wh predict, shape \
                   (B, 2, H, W).
               - offset_target (Tensor): targets of offset predict, shape \
                   (B, 2, H, W).
               - wh_offset_target_weight (Tensor): weights of wh and offset \
                   predict, shape (B, 2, H, W).
        """
        img_h, img_w = img_shape[:2]
        bs, _, feat_h, feat_w = feat_shape

        width_ratio = float(feat_w / img_w)
        height_ratio = float(feat_h / img_h)

        center_heatmap_target = gt_bboxes[-1].new_zeros(
            [bs, self.num_classes, feat_h, feat_w])
        wh_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
        center_offset_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
        wh_offset_target_weight = gt_bboxes[-1].new_zeros(
            [bs, 2, feat_h, feat_w])
        angle_target = gt_bboxes[-1].new_zeros([bs, 1, feat_h, feat_w])

        for batch_id in range(bs):
            gt_bbox = gt_bboxes[batch_id]
            gt_label = gt_labels[batch_id]
            # center_x = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2
            # center_y = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2
            center_x = gt_bbox[:, [0]] * width_ratio
            center_y = gt_bbox[:, [1]] * height_ratio
            gt_centers = torch.cat((center_x, center_y), dim=1)

            for j, ct in enumerate(gt_centers):
                ctx_int, cty_int = ct.int()
                ctx, cty = ct
                # scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
                # scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
                gt_angle = gt_bbox[j][4]

                rotated_height_ratio = height_ratio
                rotated_width_ratio = width_ratio
                # rotated_height_ratio = torch.sqrt((height_ratio**2 + torch.tan(gt_angle)**2 * width_ratio**2)
                #                                   /(1 + torch.tan(gt_angle)**2))
                # rotated_width_ratio = torch.sqrt((width_ratio**2 + torch.tan(gt_angle)**2 * height_ratio**2)
                #                                   /(1 + torch.tan(gt_angle)**2))
                scale_box_h = gt_bbox[j][3] * rotated_height_ratio
                scale_box_w = gt_bbox[j][2] * rotated_width_ratio
                radius = gaussian_radius([scale_box_h, scale_box_w],
                                         min_overlap=0.3)
                radius = max(0, int(radius))
                ind = gt_label[j]
                gen_gaussian_target(center_heatmap_target[batch_id, ind],
                                    [ctx_int, cty_int], radius)

                wh_target[batch_id, 0, cty_int, ctx_int] = scale_box_w
                wh_target[batch_id, 1, cty_int, ctx_int] = scale_box_h

                angle_target[batch_id, 0, cty_int, ctx_int] = gt_angle
                center_offset_target[batch_id, 0, cty_int, ctx_int] = ctx - ctx_int
                center_offset_target[batch_id, 1, cty_int, ctx_int] = cty - cty_int

                wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1

        avg_factor = max(1, center_heatmap_target.eq(1).sum())
        target_result = dict(
            center_heatmap_target=center_heatmap_target,
            wh_target=wh_target,
            offset_target=center_offset_target,
            angle_target=angle_target,
            wh_offset_target_weight=wh_offset_target_weight)
        return target_result, avg_factor

    def predict_by_feat(self,
                        center_heatmap_preds: List[Tensor],
                        wh_preds: List[Tensor],
                        offset_preds: List[Tensor],
                        angle_preds: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        rescale: bool = True,
                        with_nms: bool = False) -> InstanceList:
        """Transform network output for a batch into bbox predictions.

        Args:
            center_heatmap_preds (list[Tensor]): Center predict heatmaps for
                all levels with shape (B, num_classes, H, W).
            wh_preds (list[Tensor]): WH predicts for all levels with
                shape (B, 2, H, W).
            offset_preds (list[Tensor]): Offset predicts for all levels
                with shape (B, 2, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Instance segmentation
            results of each image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(center_heatmap_preds) == len(wh_preds) == len(
            offset_preds) == 1
        result_list = []
        for img_id in range(len(batch_img_metas)):
            result_list.append(
                self._predict_by_feat_single(
                    center_heatmap_preds[0][img_id:img_id + 1, ...],
                    wh_preds[0][img_id:img_id + 1, ...],
                    offset_preds[0][img_id:img_id + 1, ...],
                    angle_preds[0][img_id:img_id + 1, ...],
                    batch_img_metas[img_id],
                    rescale=rescale,
                    with_nms=with_nms))
        return result_list

    def _predict_by_feat_single(self,
                                center_heatmap_pred: Tensor,
                                wh_pred: Tensor,
                                offset_pred: Tensor,
                                angle_pred: Tensor,
                                img_meta: dict,
                                rescale: bool = True,
                                with_nms: bool = False) -> InstanceData:
        """Transform outputs of a single image into bbox results.

        Args:
            center_heatmap_pred (Tensor): Center heatmap for current level with
                shape (1, num_classes, H, W).
            wh_pred (Tensor): WH heatmap for current level with shape
                (1, num_classes, H, W).
            offset_pred (Tensor): Offset for current level with shape
                (1, corner_offset_channels, H, W).
            img_meta (dict): Meta information of current image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        batch_det_bboxes, batch_labels, batch_points = self._decode_heatmap(
            center_heatmap_pred,
            wh_pred,
            offset_pred,
            angle_pred,
            img_meta['batch_input_shape'],
            k=self.test_cfg.topk,
            kernel=self.test_cfg.local_maximum_kernel)

        # det_bboxes = batch_det_bboxes.view([-1, 5])
        det_bboxes = batch_det_bboxes.view([-1, 6])
        det_labels = batch_labels.view(-1)
        det_points = batch_points

        # batch_border = det_bboxes.new_tensor(img_meta['border'])[..., [2, 0, 2, 0]]
        # det_bboxes[..., :4] -= batch_border

        if rescale and 'scale_factor' in img_meta:
            det_bboxes[..., :4] /= det_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))
            for i in range(len(det_points)):
                det_points[i][...,:2] /= det_points[0].new_tensor(img_meta['scale_factor'])[0]
        if self.test_cfg.score_thr is not None:
            score_indx = det_bboxes[:,-1] > self.test_cfg.score_thr
            det_bboxes = det_bboxes[score_indx]
            det_labels = det_labels[score_indx]
        if with_nms:
            det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
                                                      self.test_cfg)
        results = InstanceData()

        results.bboxes = RotatedBoxes(det_bboxes[..., :5])
        results.scores = det_bboxes[..., 5]
        results.labels = det_labels
        # results.points = det_points
        
        return results

    def _decode_heatmap(self,
                        center_heatmap_pred: Tensor,
                        wh_pred: Tensor,
                        offset_pred: Tensor,
                        angle_pred: Tensor,
                        img_shape: tuple,
                        k: int = 100,
                        kernel: int = 3) -> Tuple[Tensor, Tensor, List[Tensor]]:
        """Transform outputs into detections raw bbox prediction.

        Args:
            center_heatmap_pred (Tensor): center predict heatmap,
               shape (B, num_classes, H, W).
            wh_pred (Tensor): wh predict, shape (B, 2, H, W).
            offset_pred (Tensor): offset predict, shape (B, 2, H, W).
            angle_pred (Tensor): angle predict, shape (B, 1, H, W).
            img_shape (tuple): image shape in hw format.
            k (int): Get top k center keypoints from heatmap. Defaults to 100.
            kernel (int): Max pooling kernel for extract local maximum pixels.
               Defaults to 3.

        Returns:
            tuple[Tensor]: Decoded output of CenterNetHead, containing
               the following Tensors:

              - batch_bboxes (Tensor): Coords of each box with shape (B, k, 6), \
                  arranged as (x_ctr, y_ctr, w, h, angle, score)
              - batch_topk_labels (Tensor): Categories of each box with \
                  shape (B, k)
        """
        height, width = center_heatmap_pred.shape[2:]
        inp_h, inp_w = img_shape

        center_heatmap_pred = get_local_maximum(
            center_heatmap_pred, kernel=kernel)

        *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
            center_heatmap_pred, k=k)
        batch_scores, batch_index, batch_topk_labels = batch_dets

        ctr_points = torch.cat((topk_xs, topk_ys, batch_topk_labels, batch_scores), dim=0).permute(1,0)
        wh = transpose_and_gather_feat(wh_pred, batch_index)
        offset = transpose_and_gather_feat(offset_pred, batch_index)
        angle = transpose_and_gather_feat(angle_pred, batch_index)
        topk_xs = topk_xs + offset[..., 0]
        topk_ys = topk_ys + offset[..., 1]

        # ctr_points = torch.cat((topk_xs, topk_ys, batch_topk_labels, batch_scores), dim=0).permute(1,0)
        
        
        width_ratio = width / inp_w
        height_ratio = height / inp_h
        # rotated_height_ratio = torch.sqrt((height_ratio**2 + torch.tan(angle)**2 * width_ratio**2)
        #                                           /(1 + torch.tan(angle)**2))
        # rotated_width_ratio = torch.sqrt((width_ratio**2 + torch.tan(angle)**2 * height_ratio**2)
        #                                           /(1 + torch.tan(angle)**2))
        rotated_height_ratio = height_ratio
        rotated_width_ratio = width_ratio
        x_ctr = topk_xs / width_ratio
        y_ctr = topk_ys / height_ratio
        w = wh[..., 0] / rotated_width_ratio
        h = wh[..., 1] / rotated_height_ratio
        angle = angle[...,0]

        # tl_x = (topk_xs - wh[..., 0] / 2) * (inp_w / width)
        # tl_y = (topk_ys - wh[..., 1] / 2) * (inp_h / height)
        # br_x = (topk_xs + wh[..., 0] / 2) * (inp_w / width)
        # br_y = (topk_ys + wh[..., 1] / 2) * (inp_h / height)

        # batch_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=2)
        batch_bboxes = torch.stack([x_ctr, y_ctr, w, h, angle], dim=2)
        batch_bboxes = torch.cat((batch_bboxes, batch_scores[..., None]),
                                 dim=-1)
        
        scores_thresh = 0.3
        ctr_points = ctr_points[ ctr_points[:,-1] > scores_thresh]
        ctr_points[...,:2] /= width_ratio
        points = [ctr_points]

        return batch_bboxes, batch_topk_labels, points

    def _bboxes_nms(self, bboxes: Tensor, labels: Tensor,
                    cfg: ConfigDict) -> Tuple[Tensor, Tensor]:
        """bboxes nms."""
        if labels.numel() > 0:
            max_num = cfg.max_per_img
            bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:,
                                                             -1].contiguous(),
                                       labels, cfg.nms)
            if max_num > 0:
                bboxes = bboxes[:max_num]
                labels = labels[keep][:max_num]

        return bboxes, labels
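
One note on the head above: if RCenterNetHead lives in a standalone file rather than inside the installed mmrotate package, the config usually also needs a custom_imports entry so the MODELS registry can import it at runtime. The module path below is hypothetical and should point at wherever the file is actually saved:

# Hypothetical module path; replace with the actual location of the RCenterNetHead file.
custom_imports = dict(
    imports=['rcenternet_head'],
    allow_failed_imports=False)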

soHardToHaveAName avatar Apr 17 '23 09:04 soHardToHaveAName

Hi @soHardToHaveAName, did you ever manage to train the RCenterNet model?

xavibou avatar May 03 '24 13:05 xavibou