ViTDet object detection + segmentation implementation
This PR implements ViTDet, as per https://github.com/pytorch/vision/issues/7630. I needed this implementation regardless of the feedback from the torchvision maintainers, but I figured it made sense to try to merge it upstream. The implementation borrows heavily from the one in detectron2. There is still some work to do, but since there has been no feedback on whether this will ever be merged, I will pause development at this stage.
Discussion points
- I had to move some weights around and use a different implementation for the Attention layer, making existing weights incompatible.
- Currently I put the ViTDet implementation inside the `mask_rcnn.py` file, since the two are so much alike. Should I put it in a separate `vitdet.py` file instead?
- I have only added a convenience function for a Mask R-CNN with a ViT-B/16 backbone (see the usage sketch after this list). Do we want other backbones? If so, which ones? For ResNet we also only provide convenience functions for ResNet-50, so I am not sure what to do here.
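For illustration, this is roughly how the new convenience function would be used. The builder name comes from this PR's training commands; the exact signature is an assumption, mirroring existing detection builders such as `maskrcnn_resnet50_fpn`:

```python
import torch
from torchvision.models.detection import maskrcnn_vit_b_16_sfpn  # added in this PR

# Hypothetical usage; 91 classes follows the COCO convention of the other
# detection builders.
model = maskrcnn_vit_b_16_sfpn(num_classes=91)
model.eval()

with torch.no_grad():
    # ViTDet uses fixed-size square inputs (1024x1024 in the paper's recipe).
    predictions = model([torch.rand(3, 1024, 1024)])
print(predictions[0]["boxes"].shape)
```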
Current status
Training with the following command:
```
python train.py \
--dataset coco --model maskrcnn_vit_b_16_sfpn --epochs 10 --batch-size 2 \
--aspect-ratio-group-factor -1 --weights-backbone ViT_B_16_Weights.IMAGENET1K_V1 --data-path=/srv/data/coco \
--opt vitdet --lr 8e-5 --wd 0.1 --data-augmentation lsj --lr-steps 3 6 --image-min-size 1024 --image-max-size 1024
```
achieves the following results:
```
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.475
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.691
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.524
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.322
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.512
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.612
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.366
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.579
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.606
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.432
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.424
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.662
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.455
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.236
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.454
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.617
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.337
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.525
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.548
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.590
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
```
The segmentation results are identical to those reported in the paper.
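As an aside, the `--data-augmentation lsj` flag in the command above refers to large-scale jitter: resize the image by a random factor in [0.1, 2.0], then crop or pad to a fixed 1024x1024 canvas. A minimal image-only sketch (the real training pipeline must also rescale boxes and masks accordingly):

```python
import random
import torch
import torchvision.transforms.functional as TF

def large_scale_jitter(image: torch.Tensor, target_size: int = 1024,
                       scale_range: tuple = (0.1, 2.0)) -> torch.Tensor:
    """Large-scale jitter for a CHW image tensor (sketch; images only)."""
    scale = random.uniform(*scale_range)
    h, w = image.shape[-2:]
    image = TF.resize(image, [int(h * scale), int(w * scale)], antialias=True)
    # Crop down to the target canvas if the image got bigger ...
    image = image[..., :target_size, :target_size]
    # ... and zero-pad on the right/bottom if it got smaller.
    pad_w = target_size - image.shape[-1]
    pad_h = target_size - image.shape[-2]
    return TF.pad(image, [0, 0, pad_w, pad_h])

img = large_scale_jitter(torch.rand(3, 800, 600))
assert img.shape == (3, 1024, 1024)
```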
TODOs
- [ ] I broke the implementation for the classification part of ViT. This needs some more work.
- [ ] I removed the previously available argument to set trainable layers.
- [ ] Train a MaskRCNN model + upload weights.
- [ ] Double check all docstrings to make sure they are still correct.
- [ ] Check formatting / unit tests.
- [x] Check conversion to torchscript.
- [ ] Check conversion to ONNX (?).
My main intention in opening this PR is to let the torchvision maintainers provide their feedback and opinion. @fmassa I'm not sure if you are still working on these things, but I'm tagging you since we worked together on the RetinaNet implementation :).
I updated this PR so that the implementation more closely resembles the initial implementation of ViT in torchvision. I have also updated the first post accordingly, to avoid unnecessary reading :p
The only difference this PR makes now is that `pos_embedding` has moved to the `ViT` class instead of the `Encoder` class (so that images of any size are accepted). This means that existing weights files are no longer compatible, for which I added a workaround in the `_vision_transformer` function. Is this acceptable?
(the rest of the TODOs still stand)
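For context, the usual way to make a pretrained positional embedding work at new input sizes (and the likely shape of such a workaround, though the exact code in `_vision_transformer` may differ) is to interpolate the patch-grid portion of the embedding:

```python
import torch
import torch.nn.functional as F

def resize_pos_embedding(pos_embedding: torch.Tensor, new_h: int, new_w: int,
                         has_class_token: bool = True) -> torch.Tensor:
    """Interpolate a (1, seq_len, dim) ViT positional embedding to a new grid.

    Sketch of the standard approach: split off the class token, reshape the
    patch embeddings back to their 2D grid, and resize bicubically.
    """
    if has_class_token:
        class_pos, grid_pos = pos_embedding[:, :1], pos_embedding[:, 1:]
    else:
        class_pos, grid_pos = None, pos_embedding
    dim = grid_pos.shape[-1]
    old = int(grid_pos.shape[1] ** 0.5)  # assumes the pretrained grid was square
    grid_pos = grid_pos.reshape(1, old, old, dim).permute(0, 3, 1, 2)
    grid_pos = F.interpolate(grid_pos, size=(new_h, new_w), mode="bicubic",
                             align_corners=False)
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)
    return grid_pos if class_pos is None else torch.cat([class_pos, grid_pos], dim=1)

# e.g. adapt a 224/16 embedding (14x14 grid) to 1024/16 (64x64 grid):
new = resize_pos_embedding(torch.rand(1, 1 + 14 * 14, 768), 64, 64)
assert new.shape == (1, 1 + 64 * 64, 768)
```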
I trained a COCO model using ViT-B as backbone with the following command:
```
python train.py \
--dataset coco --model maskrcnn_vit_b_16_sfpn --epochs 10 --batch-size 4 \
--aspect-ratio-group-factor -1 --weights-backbone ViT_B_16_Weights.IMAGENET1K_V1 --data-path=/srv/data/coco \
--opt adamw --lr 8e-5 --wd 0.1 --data-augmentation lsj --lr-steps 3 6
```
And got the following results:
```
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.320
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.534
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.331
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.138
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.492
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.285
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.440
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.465
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.258
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.502
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.640
```
This configuration should reach approximately 0.424 mAP according to the ViTDet paper (versus 0.320 in this training). This tells me that it is learning something, so the implementation is correct in general, but something is still missing.
One thing to note is that I trained on a single GPU with batch size 4, whereas they trained on 64 GPUs (1 image per GPU). I'm not sure what the effect of this is, since I don't have 64 GPUs at my disposal. If someone has the resources to train with a batch size of 64, I would be very interested to see how it performs.
In the meantime I will try and use this model some more to see if I can improve on these results.
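One way to approximate that effective batch size of 64 on a single GPU is gradient accumulation; a minimal sketch (a hypothetical helper, and not exactly equivalent to true large-batch training, although the ViT backbone uses LayerNorm, so there are no batch-dependent statistics):

```python
import torch

def train_one_epoch_accumulated(model, optimizer, data_loader,
                                accumulation_steps: int = 16):
    """Emulate an effective batch size of accumulation_steps * per-step batch
    size (16 x 4 = 64 here) by accumulating gradients before each update."""
    model.train()
    optimizer.zero_grad()
    for step, (images, targets) in enumerate(data_loader):
        loss_dict = model(images, targets)  # torchvision detection models return a loss dict
        loss = sum(loss_dict.values()) / accumulation_steps  # average over the virtual batch
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```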
Is there any update on how to fix this? I would really like to have a working ViTDet implementation in torchvision.
None that I have found. I modified the implementation to match that of detectron2 (to the point where both networks output the same features, given the same input and the same RNG seed), but surprisingly the results are even worse. I don't have the numbers at hand at the moment, but I will continue to look into this.
If you're interested, feel free to give it a go and see what performance you get.
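For reference, the parity check mentioned above amounts to something like this (a sketch with hypothetical arguments; it assumes both backbones were loaded with identical weights and return a dict of feature maps):

```python
import torch

def check_feature_parity(backbone_a, backbone_b, image_size: int = 1024):
    """Assert that two backbones produce numerically matching feature maps."""
    torch.manual_seed(0)  # fix RNG so both sides see the same input
    x = torch.rand(1, 3, image_size, image_size)
    with torch.no_grad():
        feats_a, feats_b = backbone_a(x), backbone_b(x)
    for key in feats_a:
        torch.testing.assert_close(feats_a[key], feats_b[key],
                                   rtol=1e-4, atol=1e-5)
```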
I'm slowly making progress on this, but I am not completely there yet. Is there still interest from the torchvision maintainers in merging this at some point?
@pmeier can I ask you for your feedback? Or alternatively, can you let me know who is best to ask?
The latest changes did have an impact on the COCO evaluation score:
```
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.418
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.635
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.458
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.267
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.559
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.334
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.531
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.559
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.377
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.598
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.715
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.380
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.604
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.405
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.194
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.409
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.573
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.313
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.488
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.512
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.316
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.558
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.688
```
Though 0.380 still isn't the expected 0.424. I suspect the relative positional embedding in the multi-head attention might explain this difference (it cannot be expressed with the Attention layer from torch). The easiest solution would be to implement a custom Attention layer in torchvision, à la detectron2.
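For context, detectron2 implements this as a decomposed relative positional bias added to the attention logits. A simplified sketch of the idea, assuming equal query/key grid sizes (detectron2's version also handles unequal sizes and interpolates the tables):

```python
import torch

def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, h, w):
    """Add a decomposed relative positional bias to attention logits.

    attn:      (B, h*w, h*w) attention logits for one head
    q:         (B, h*w, C) queries
    rel_pos_h: (2*h - 1, C) learned table for vertical offsets
    rel_pos_w: (2*w - 1, C) learned table for horizontal offsets
    """
    coords_h = torch.arange(h)
    coords_w = torch.arange(w)
    # Look up the embedding for every pairwise offset along each axis.
    Rh = rel_pos_h[coords_h[:, None] - coords_h[None, :] + (h - 1)]  # (h, h, C)
    Rw = rel_pos_w[coords_w[:, None] - coords_w[None, :] + (w - 1)]  # (w, w, C)

    B, _, dim = q.shape
    r_q = q.reshape(B, h, w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # bias along the height axis
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # bias along the width axis

    attn = attn.view(B, h, w, h, w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn.view(B, h * w, h * w)
```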
Good news: the accuracy has gone up significantly after changing the attention layer. The main difference is that it uses a relative positional embedding. The score I am getting on COCO now is:
```
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.471
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.691
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.519
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.315
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.509
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.608
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.363
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.576
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.604
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.427
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.755
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.421
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.660
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.453
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.224
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.455
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.609
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.334
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.521
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.544
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.355
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.591
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.71
```
That 0.421 is awfully close to the 0.424 reported in the paper. I will update the first post with the TODOs that are still left. Considering there seems to be little to no interest in this, I will stop development here, as this was all I needed (a working ViTDet in torchvision).
I found a bug in the learning rate decay (sketched below, after the results); with that fixed, the results are:
```
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.475
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.691
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.524
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.322
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.512
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.612
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.366
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.579
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.606
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.432
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.645
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
IoU metric: segm
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.424
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.662
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.455
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.236
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.454
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.617
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.337
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.525
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.548
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.590
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.715
```
For segmentation this is identical to the results in the paper; bbox is nearly identical.
:partying_face:
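For reference, the learning rate decay in question is presumably the layer-wise decay from the ViTDet recipe (0.7 per layer for ViT-B). A minimal sketch of building the parameter groups; the name matching follows torchvision's ViT parameter names and is an assumption, not this PR's exact code:

```python
import torch

def layerwise_lr_groups(model, base_lr: float = 8e-5, decay: float = 0.7,
                        num_layers: int = 12):
    """Build optimizer parameter groups with layer-wise lr decay for a ViT-B
    backbone: embeddings get the smallest lr, deeper blocks progressively more."""
    def layer_id(name: str) -> int:
        if any(k in name for k in ("pos_embedding", "conv_proj", "class_token")):
            return 0                      # patch/positional embeddings
        for i in range(num_layers):
            if f"encoder_layer_{i}." in name:
                return i + 1              # transformer block i
        return num_layers + 1             # detection head etc.: full lr

    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        lid = layer_id(name)
        scale = decay ** (num_layers + 1 - lid)
        groups.setdefault(lid, {"params": [], "lr": base_lr * scale})
        groups[lid]["params"].append(param)
    return list(groups.values())

# e.g.: optimizer = torch.optim.AdamW(layerwise_lr_groups(backbone), weight_decay=0.1)
```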