Is it possible to train a MOT model on one class of LaSOT dataset?
Hello, I want to build a model that can track multiple drones in a video using only the drone data from the LaSOT dataset. I'd appreciate it if you could let me know whether this is a feasible idea in the first place.
Preliminaries
- I downloaded only the drone folder from the LaSOT dataset and converted it into a CocoVideoDataset.
- The data format is as follows.
data
└── lasot
    ├── annotations
    │   ├── lasot_train.json
    │   ├── lasot_test.json
    │   ├── lasot_train_infos.txt
    │   └── lasot_test_infos.txt
    └── LaSOTBenchmark
        └── drone
            ├── drone-1
            ├── drone-2
            └── ....
- Created a custom config in configs/mot/bytetrack. Its _base_ also uses two custom modules:
  '../../_base_/datasets/lasot_drone.py'
  '../../_base_/models/custom_yolox_x_8x8.py'
- Changed num_classes=1 and classes=('drone', ) in all configuration files.
===
Error
When training starts, the first epoch trains normally, but then an error at data['category_id'] = self.cat_ids[label]
occurs during the evaluation step.
I've looked for similar issues in MMDetection, but none of them solved it.
Do you know of any solution?
my config
from mmtrack import datasets
_base_ = [
'../../_base_/models/custom_yolox_x_8x8.py',
'../../_base_/datasets/lasot_drone.py', '../../_base_/default_runtime.py'
]
img_scale = (1280, 720)
samples_per_gpu = 2
"""
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
)),
"""
model = dict(
type='ByteTrack',
detector=dict(
input_size=img_scale,
random_size_range=(18, 32),
bbox_head=dict(num_classes=1),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
)),
motion=dict(type='KalmanFilter'),
tracker=dict(
type='ByteTracker',
obj_score_thrs=dict(high=0.6, low=0.1),
init_track_thr=0.7,
weight_iou_with_det_scores=True,
match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
num_frames_retain=30))
train_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
bbox_clip_border=False),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2),
bbox_clip_border=False),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0,
bbox_clip_border=False),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Resize',
img_scale=img_scale,
keep_ratio=True,
bbox_clip_border=False),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
# musma_custom_dataset
# dataset_type = "CocoVideoDataset"
"""
ref_img_sampler=dict(
num_ref_imgs=1,
frame_range=10,
filter_key_img=True,
method='uniform'),
"""
# dataset_type = "LaSOTDataset"
dataset_type = "CocoVideoDataset"
classes = ('drone',)
data_root = 'data/lasot/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
classes=classes,
ann_file=data_root + "annotations/lasot_train.json",
img_prefix=data_root + "LaSOTBenchmark",
pipeline=train_pipeline),
val=dict(
type=dataset_type,
classes=classes,
ann_file=data_root + "annotations/lasot_test.json",
img_prefix=data_root + "LaSOTBenchmark",
pipeline=test_pipeline),
test=dict(
type=dataset_type,
classes=classes,
ann_file=data_root + "annotations/lasot_test.json",
img_prefix=data_root + "LaSOTBenchmark",
pipeline=test_pipeline))
# optimizer
# default 8 gpu
optimizer = dict(
type='SGD',
lr=0.0001 / 2 * samples_per_gpu,
momentum=0.9,
# 5e-4
weight_decay=0.0001,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
# some hyper parameters
total_epochs = 80
num_last_epochs = 10
resume_from = None
interval = 5
# learning policy
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=0.5,
warmup_iters=2000,
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)
custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]
checkpoint_config = dict(interval=1)
evaluation = dict(metric=['bbox', 'track'], interval=1)
search_metrics = ['MOTA', 'IDF1', 'FN', 'FP', 'IDs', 'MT', 'ML']
# you need to set mode='dynamic' if you are using pytorch<=1.5.0
fp16 = dict(loss_scale=dict(init_scale=512.))
Training command: bash ./tools/dist_train.sh ./configs/mot/bytetrack/bytetrack_yolox_x_drone_lasot.py 2
Result:
/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launch.py:186: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions
FutureWarning,
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
/home/sysadmin/Desktop/mmtracking/mmtrack/core/utils/misc.py:35: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
f'Setting MKL_NUM_THREADS environment variable for each process '
/home/sysadmin/Desktop/mmtracking/mmtrack/core/utils/misc.py:35: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
f'Setting MKL_NUM_THREADS environment variable for each process '
2022-05-27 05:54:24,473 - mmtrack - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]
CUDA available: True
GPU 0,1: NVIDIA RTX A5000
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.7, V11.7.64
GCC: gcc (Ubuntu 8.4.0-1ubuntu1~18.04) 8.4.0
PyTorch: 1.11.0
PyTorch compiling details: PyTorch built with:
- GCC 7.3
- C++ Version: 201402
- Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.5.2 (Git Hash a9302535553c73243c632ad3c4c80beec3d19a1e)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 11.3
- NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
- CuDNN 8.2
- Magma 2.5.2
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
TorchVision: 0.12.0
OpenCV: 4.5.5
MMCV: 1.5.1
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.3
MMTracking: 0.13.0+2542848
------------------------------------------------------------
2022-05-27 05:54:24,473 - mmtrack - INFO - Distributed training: True
2022-05-27 05:54:24,994 - mmtrack - INFO - Config:
img_scale = (1280, 720)
model = dict(
detector=dict(
type='YOLOX',
input_size=(1280, 720),
random_size_range=(18, 32),
random_size_interval=10,
backbone=dict(
type='CSPDarknet', deepen_factor=1.33, widen_factor=1.25),
neck=dict(
type='YOLOXPAFPN',
in_channels=[320, 640, 1280],
out_channels=320,
num_csp_blocks=4),
bbox_head=dict(
type='YOLOXHead',
num_classes=1,
in_channels=320,
feat_channels=320),
train_cfg=dict(
assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
init_cfg=dict(
type='Pretrained',
checkpoint=
'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth'
)),
type='ByteTrack',
motion=dict(type='KalmanFilter'),
tracker=dict(
type='ByteTracker',
obj_score_thrs=dict(high=0.6, low=0.1),
init_track_thr=0.7,
weight_iou_with_det_scores=True,
match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
num_frames_retain=30))
dataset_type = 'CocoVideoDataset'
classes = ('drone', )
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(
type='Mosaic',
img_scale=(1280, 720),
pad_val=114.0,
bbox_clip_border=False),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-640, -360),
bbox_clip_border=False),
dict(
type='MixUp',
img_scale=(1280, 720),
ratio_range=(0.8, 1.6),
pad_val=114.0,
bbox_clip_border=False),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Resize',
img_scale=(1280, 720),
keep_ratio=True,
bbox_clip_border=False),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1280, 720),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data_root = 'data/lasot/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='CocoVideoDataset',
ann_file='data/lasot/annotations/lasot_train.json',
img_prefix='data/lasot/LaSOTBenchmark',
ref_img_sampler=dict(
num_ref_imgs=1,
frame_range=10,
filter_key_img=True,
method='uniform'),
pipeline=[
dict(
type='Mosaic',
img_scale=(1280, 720),
pad_val=114.0,
bbox_clip_border=False),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-640, -360),
bbox_clip_border=False),
dict(
type='MixUp',
img_scale=(1280, 720),
ratio_range=(0.8, 1.6),
pad_val=114.0,
bbox_clip_border=False),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Resize',
img_scale=(1280, 720),
keep_ratio=True,
bbox_clip_border=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(
type='FilterAnnotations',
min_gt_bbox_wh=(1, 1),
keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
],
classes=('drone', )),
val=dict(
type='CocoVideoDataset',
ann_file='data/lasot/annotations/lasot_test.json',
img_prefix='data/lasot/LaSOTBenchmark',
ref_img_sampler=None,
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1280, 720),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
],
classes=('drone', )),
test=dict(
type='CocoVideoDataset',
ann_file='data/lasot/annotations/lasot_test.json',
img_prefix='data/lasot/LaSOTBenchmark',
ref_img_sampler=None,
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1280, 720),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
],
classes=('drone', )))
optimizer = dict(
type='SGD',
lr=0.0001,
momentum=0.9,
weight_decay=0.0001,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
samples_per_gpu = 2
total_epochs = 80
num_last_epochs = 10
interval = 5
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=0.5,
warmup_iters=2000,
num_last_epochs=10,
min_lr_ratio=0.05)
custom_hooks = [
dict(type='YOLOXModeSwitchHook', num_last_epochs=10, priority=48),
dict(type='SyncNormHook', num_last_epochs=10, interval=5, priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=None,
momentum=0.0001,
priority=49)
]
evaluation = dict(metric=['bbox', 'track'], interval=1)
search_metrics = ['MOTA', 'IDF1', 'FN', 'FP', 'IDs', 'MT', 'ML']
fp16 = dict(loss_scale=dict(init_scale=512.0))
work_dir = './work_dirs/bytetrack_yolox_x_drone_lasot'
gpu_ids = [0]
2022-05-27 05:54:27,436 - mmtrack - INFO - Set random seed to 1405118894, deterministic: False
2022-05-27 05:54:28,366 - mmdet - INFO - image shape: height=1280, width=720 in YOLOX.__init__
2022-05-27 05:54:28,402 - mmtrack - INFO - initialize YOLOX with init_cfg {'type': 'Pretrained', 'checkpoint': 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth'}
2022-05-27 05:54:28,402 - mmcv - INFO - load model from: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth
2022-05-27 05:54:28,402 - mmcv - INFO - load checkpoint from http path: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth
2022-05-27 05:54:28,674 - mmcv - WARNING - The model and loaded state dict do not match exactly
size mismatch for bbox_head.multi_level_conv_cls.0.weight: copying a param with shape torch.Size([80, 320, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 320, 1, 1]).
size mismatch for bbox_head.multi_level_conv_cls.0.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for bbox_head.multi_level_conv_cls.1.weight: copying a param with shape torch.Size([80, 320, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 320, 1, 1]).
size mismatch for bbox_head.multi_level_conv_cls.1.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).
size mismatch for bbox_head.multi_level_conv_cls.2.weight: copying a param with shape torch.Size([80, 320, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 320, 1, 1]).
size mismatch for bbox_head.multi_level_conv_cls.2.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).
loading annotations into memory...
loading annotations into memory...
Done (t=0.19s)
creating index...
index created!
self.cat_ids : []
self.cat2label : {}
Done (t=0.19s)
creating index...
2022-05-27 05:54:29,109 - mmdet - INFO - image shape: height=1280, width=720 in Mosaic.__init__
2022-05-27 05:54:29,110 - mmdet - INFO - image shape: height=1280, width=720 in MixUp.__init__
index created!
self.cat_ids : []
self.cat2label : {}
checkpoint classes : ('drone',)
model classes : ('drone',)
loading annotations into memory...
Done (t=0.03s)
creating index...
checkpoint classes : ('drone',)
model classes : ('drone',)
index created!
self.cat_ids : []
self.cat2label : {}
2022-05-27 05:54:29,826 - mmtrack - INFO - Start running, host: sysadmin@pre5820, work_dir: /home/sysadmin/Desktop/mmtracking/work_dirs/bytetrack_yolox_x_drone_lasot
2022-05-27 05:54:29,826 - mmtrack - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH ) YOLOXLrUpdaterHook
(ABOVE_NORMAL) Fp16OptimizerHook
(49 ) ExpMomentumEMAHook
(NORMAL ) CheckpointHook
(NORMAL ) DistEvalHook
(VERY_LOW ) TextLoggerHook
--------------------
before_train_epoch:
(VERY_HIGH ) YOLOXLrUpdaterHook
(48 ) YOLOXModeSwitchHook
(48 ) SyncNormHook
(49 ) ExpMomentumEMAHook
(NORMAL ) DistSamplerSeedHook
(NORMAL ) DistEvalHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook
--------------------
before_train_iter:
(VERY_HIGH ) YOLOXLrUpdaterHook
(NORMAL ) DistEvalHook
(LOW ) IterTimerHook
--------------------
after_train_iter:
(ABOVE_NORMAL) Fp16OptimizerHook
(49 ) ExpMomentumEMAHook
(NORMAL ) CheckpointHook
(NORMAL ) DistEvalHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook
--------------------
after_train_epoch:
(48 ) SyncNormHook
(49 ) ExpMomentumEMAHook
(NORMAL ) CheckpointHook
(NORMAL ) DistEvalHook
(VERY_LOW ) TextLoggerHook
--------------------
before_val_epoch:
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook
--------------------
before_val_iter:
(LOW ) IterTimerHook
--------------------
after_val_iter:
(LOW ) IterTimerHook
--------------------
after_val_epoch:
(VERY_LOW ) TextLoggerHook
--------------------
after_run:
(VERY_LOW ) TextLoggerHook
--------------------
2022-05-27 05:54:29,827 - mmtrack - INFO - workflow: [('train', 1)], max: 80 epochs
2022-05-27 05:54:29,933 - mmtrack - INFO - Checkpoints will be saved to /home/sysadmin/Desktop/mmtracking/work_dirs/bytetrack_yolox_x_drone_lasot by HardDiskBackend.
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
self.cat_ids : []
self.cat2label : {}
2022-05-27 05:54:32,227 - mmtrack - INFO - Saving checkpoint at 1 epochs
[ ] 0/11027, elapsed: 0s, ETA:/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1646755953518/work/aten/src/ATen/native/TensorShape.cpp:2228.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/models/dense_heads/yolox_head.py:286: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1646755953518/work/torch/csrc/utils/tensor_new.cpp:210.)
scale_factors).unsqueeze(1)
/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1646755953518/work/aten/src/ATen/native/TensorShape.cpp:2228.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/models/dense_heads/yolox_head.py:286: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1646755953518/work/torch/csrc/utils/tensor_new.cpp:210.)
scale_factors).unsqueeze(1)
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ] 10154/11027, 52.1 task/s, elapsed: 195s, ETA: 17s
2022-05-27 05:58:12,960 - mmtrack - INFO - ---CLEAR MOT Evaluation---
2022-05-27 05:58:12,960 - mmtrack - INFO - Accumulating...
2022-05-27 05:58:13,711 - mmtrack - INFO - Evaluating...
2022-05-27 05:58:14,053 - mmtrack - INFO - Rendering...
2022-05-27 05:58:14,058 - mmtrack - INFO -
IDF1 MOTA MOTP FP FN IDSw Rcll Prcn MT PT ML FM
drone 0.0% -inf% NaN 23244 0 0 NaN 0.0% 0 0 0 0
OVERALL 0.0% -inf% NaN 23244 0 0 NaN 0.0% 0 0 0 0
AVERAGE 0.0% -inf% 0.000 23244 0 0 0.0% 0.0% 0 0 0 0
2022-05-27 05:58:14,058 - mmtrack - INFO - Evaluation finishes with 1.10 s.
Traceback (most recent call last):
File "./tools/train.py", line 212, in <module>
main()
File "./tools/train.py", line 208, in main
meta=meta)
File "/home/sysadmin/Desktop/mmtracking/mmtrack/apis/train.py", line 175, in train_model
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
epoch_runner(data_loaders[i], **kwargs)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 54, in train
self.call_hook('after_train_epoch')
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 309, in call_hook
getattr(hook, fn_name)(self)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmcv/runner/hooks/evaluation.py", line 267, in after_train_epoch
self._do_evaluate(runner)
File "/home/sysadmin/Desktop/mmtracking/mmtrack/core/evaluation/eval_hooks.py", line 62, in _do_evaluate
key_score = self.evaluate(runner, results)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmcv/runner/hooks/evaluation.py", line 364, in evaluate
results, logger=runner.logger, **self.eval_kwargs)
File "/home/sysadmin/Desktop/mmtracking/mmtrack/datasets/coco_video_dataset.py", line 451, in evaluate
**bbox_kwargs)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/datasets/coco.py", line 641, in evaluate
result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/datasets/coco.py", line 383, in format_results
result_files = self.results2json(results, jsonfile_prefix)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/datasets/coco.py", line 315, in results2json
json_results = self._det2json(results)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/mmdet/datasets/coco.py", line 252, in _det2json
data['category_id'] = self.cat_ids[label]
IndexError: list index out of range
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4541 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4540) of binary: /home/sysadmin/miniconda3/envs/mm/bin/python
Traceback (most recent call last):
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launch.py", line 193, in <module>
main()
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launch.py", line 189, in main
launch(args)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launch.py", line 174, in launch
run(args)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/run.py", line 718, in run
)(*cmd_args)
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/sysadmin/miniconda3/envs/mm/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 247, in launch_agent
failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
./tools/train.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2022-05-27_05:58:16
host : pre5820
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 4540)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Hi, moey920,
In my opinion, there is likely something wrong with your annotation file. I suggest you check the "categories" field of the annotation file.
In addition, you should also check mmdet/datasets/coco.py to see if the categories are loaded as expected.
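For example, a quick check along these lines (a diagnostic sketch, assuming pycocotools is installed and using the annotation paths from this thread) will show whether the category names line up with CLASSES:

# Diagnostic sketch: if getCatIds() returns [], the category names in the
# json do not match CLASSES = ('drone',), so self.cat_ids stays empty and
# results2json() later fails with the IndexError shown above.
from pycocotools.coco import COCO

coco = COCO('data/lasot/annotations/lasot_test.json')
print(coco.dataset['categories'])        # expect: [{'id': 0, 'name': 'drone'}]
print(coco.getCatIds(catNms=['drone']))  # expect a non-empty list of ids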
Thanks for your comment. I will report back after checking.
self.cat_ids : []
self.cat2label : {}
The beginning of lasot_test.json looks like this.
{"categories": [{"id": 0, "name": 0}], "videos": [{"id": 1, "name": "drone-13"}, {"id": 2, "name": "drone-15"}, {"id": 3, "name": "drone-2"}, {"id": 4, "name": "drone-7"}], "images": [{"file_name": "drone/drone-13/img/00000001.jpg", "height": 720, "width": 1280, "id": 1, "frame_id": 0, "video_id": 1},......
Is there a problem with the category name being 0? I used the lasot2coco script, but I don't know why the name ends up as an int.
Another error occurs when I change the name to "drone" and train again:
assert 'mix_results' in results
Also, I modified CLASSES = ("drone", ) in mmtrack/datasets/coco_video_dataset.py.
Where is the mmdet/datasets/coco.py you mentioned?
Well, according to my understanding, you want to track UAVs, that is, there is only one type of target.
Do I understand correctly?
Yes, that's right.
I see you use ByteTrack, so I don't recommend using the CocoVideoDataset dataset type.
You can take a look here https://github.com/open-mmlab/mmtracking/blob/2542848f4b441e92b5b7c54c32285d0e675dac1a/configs/mot/bytetrack/bytetrack_yolox_x_crowdhuman_mot17-private-half.py#L90 to adapt your own configuration file: for ByteTrack we only need to train a detection model, which does not use the video-specific properties of the dataset.
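To make that concrete, here is a minimal sketch of the train dataset section following the structure of the linked config. The paths, classes and train_pipeline are the ones from this thread, so treat it as an assumption-laden sketch rather than a drop-in file:

# Train the detector on a plain CocoDataset; the Mosaic/MixUp transforms in
# train_pipeline expect the dataset to be wrapped in MultiImageMixDataset.
train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type='CocoDataset',
        classes=('drone', ),
        ann_file='data/lasot/annotations/lasot_train.json',
        img_prefix='data/lasot/LaSOTBenchmark',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_empty_gt=False),
    pipeline=train_pipeline)  # the Mosaic/MixUp pipeline defined above

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=train_dataset)
    # val/test can stay as in the config above once the annotation
    # categories are fixed.

The MultiImageMixDataset wrapper is also what injects the mix_results that Mosaic/MixUp assert on, which appears to be the source of the "assert 'mix_results' in results" error mentioned earlier.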
In addition, update the json file:
change "categories": [{"id": 0, "name": 0}]
to "categories": [{"id": 0, "name": "drone"}].
You must ensure that the name in the annotation file and the dataset CLASSES are consistent.
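If you prefer to patch the generated files rather than rerun the conversion, a small script along these lines should work (a sketch; the paths are the ones used in this thread):

import json

# Rewrite the integer category name produced by the conversion script so it
# matches CLASSES = ('drone',) in both the train and test annotation files.
for split in ('lasot_train.json', 'lasot_test.json'):
    path = 'data/lasot/annotations/' + split
    with open(path) as f:
        ann = json.load(f)
    ann['categories'] = [{'id': 0, 'name': 'drone'}]
    with open(path, 'w') as f:
        json.dump(ann, f)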
Is it ok to train only the YoloX model using the entire LaSOT dataset and then change the checkpoint of the detector?
Of course, you can take a look at the issue🚀.
https://github.com/open-mmlab/mmtracking/issues/564
> Is it ok to train only the YoloX model using the entire LaSOT dataset and then change the checkpoint of the detector?
@moey920 were you able to succeed? I'm also trying this method
It works well. However, I'm still exploring the detector architecture because the class I'm trying to detect performs poorly. Give it a try.