PyCenterNet

AttributeError: 'PyCenterNetHead' object has no attribute 'init_assigner'
Hi,
I'm trying to train PyCenterNet++ with the "[R2-101]-DCN" backbone. Below are the error message and environment details for reference:
Environment details:
- `torch.__version__`: 1.4.0
- `torch.cuda.is_available()`: True
- `mmdet.__version__`: 2.11.0
- `get_compiling_cuda_version()`: 10.1
- `get_compiler_version()`: GCC 7.5
Error message:

```
AttributeError: 'PyCenterNetHead' object has no attribute 'init_assigner'
```
Hi,
Here are the configuration parameters and the training code snippet for reference:

Training reference snippet:
"""Single-GPU training script for PyCenterNet (Res2Net-101 + DCN) on COCO.

Adapted from the MMDetection 2.x training tutorial: load the PyCenterNet
config, override a few settings for one-GPU training, build the dataset and
detector, then hand everything to ``train_detector``.
"""
import os.path as osp

import mmcv
from mmcv import Config
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

CONFIG_FILE = ('/PyCenterNet/code/configs/pycenternet/'
               'pycenternet_res2_101_fpn_dconv_c3-c5_giou_mstrain_2x_coco.py')
PRETRAINED = ('/PyCenterNet/code/mmcv/checkpoints/'
              'res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth')


def main():
    """Configure, build, and train the PyCenterNet detector on one GPU."""
    cfg = Config.fromfile(CONFIG_FILE)
    cfg.model.pretrained = PRETRAINED
    # NOTE(review): in mmcv configs `dataset_type` is substituted into
    # `cfg.data.*` when the file is loaded, so reassigning it here most
    # likely does not change `cfg.data.train.type` -- confirm, and edit
    # `cfg.data.train.type` directly if a different dataset is intended.
    cfg.dataset_type = 'COCODataset'

    # The original learning rate (LR) is set for 8-GPU training.
    # We divide it by 8 since we only use one GPU.
    cfg.optimizer.lr = 0.02 / 8
    cfg.lr_config.warmup = None
    cfg.log_config.interval = 10

    # Reduce how often evaluation runs and checkpoints are written.
    cfg.evaluation.interval = 18
    cfg.checkpoint_config.interval = 18

    # Set the seed so results are more reproducible.
    cfg.seed = 0
    set_random_seed(0, deterministic=False)
    cfg.gpu_ids = range(1)

    # Log to the console/text log and to TensorBoard.
    cfg.log_config.hooks = [
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ]

    # Mirror tools/train.py: derive a work_dir from the config file name
    # when the config does not define one, instead of failing on access.
    if cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join(
            './work_dirs', osp.splitext(osp.basename(CONFIG_FILE))[0])

    # Show the final config used for training.
    print(f'Config:\n{cfg.pretty_text}')

    # NOTE(review): presumably not read by mmdet 2.11; kept for parity with
    # newer MMDetection releases that use it.
    cfg.device = 'cuda'

    print('cfg.data.train', cfg.data.train)
    datasets = [build_dataset(cfg.data.train)]
    print(datasets)

    # Pass train_cfg/test_cfg explicitly, as MMDetection 2.x tools/train.py
    # does. Without train_cfg the head never builds its assigners, which is
    # what raises "'PyCenterNetHead' object has no attribute 'init_assigner'"
    # once training starts. cfg.get() returns None when the config keeps
    # these settings inside cfg.model instead.
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    # Attach class names for visualization convenience.
    model.CLASSES = datasets[0].CLASSES

    work_dir = osp.abspath(cfg.work_dir)
    print('osp.abspath(cfg.work_dir)', work_dir)
    mmcv.mkdir_or_exist(work_dir)

    train_detector(model, datasets, cfg, distributed=False, validate=True)


if __name__ == '__main__':
    main()