mmfewshot
Classification demo code error with baseline++
Describe the bug
I ran "demo/demo_metric_classifier_1shot_inference.py" with the example arguments, and it works. However, when I change the option "proto-net" to "baseline++", the following error occurs:
  File "C:\GitSource\SolVision3.0\AI_Vision_WPF\TaskProcess\bin\x64\Release\PYD\slmmm\packages\mmfewshot\demo\demo_metric_classifier_1shot_inference.py", line 53, in <module>
    main()
  File "XXX\mmfewshot\demo\demo_metric_classifier_1shot_inference.py", line 45, in main
    process_support_images(model, support_images, support_labels)
  File "XXX\mmfewshot\mmfewshot\classification\apis\inference.py", line 71, in process_support_images
    model.before_forward_support()
  File "XXX\mmfewshot\mmfewshot\classification\models\classifiers\base_finetune.py", line 181, in before_forward_support
    assert self.meta_test_head is not None
AssertionError
Reproduction
python demo/demo_metric_classifier_1shot_inference.py demo_classification_images/query_images/Least_Auklet.jpg ../configs/classification/proto_net/cub/proto-net_conv4_1xb105_cub_5way-1shot.py proto-net_conv4_1xb105_cub_5way-1shot_20211120_101211-9ab530c3.pth --support-images-dir demo_classification_images/support_images
What I found
I noticed that "baseline++" inherits from "BaseFinetuneClassifier", whereas "proto-net" inherits from "BaseMetricClassifier". I think this difference is the key reason for the error. Is this behavior expected?
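To illustrate the suspected cause, here is a minimal, self-contained sketch of how the same demo call sequence can succeed for a metric-based model and fail for a finetune-based one. The classes `MetricClassifierSketch` and `FinetuneClassifierSketch` and the helper `try_before_forward_support` are hypothetical stand-ins, not mmfewshot's real classes. The only behavior taken from the traceback is the `assert self.meta_test_head is not None` check in `before_forward_support`. The sketch also assumes that a finetune-based model reaching the demo path without a meta-test head leaves that attribute as `None`; that assumption has not been confirmed against the mmfewshot source.

```python
from typing import Optional


class MetricClassifierSketch:
    """Hypothetical stand-in for a metric-based model such as proto-net.

    Query features are compared with support features directly,
    so no extra head is required before processing the support set.
    """

    def before_forward_support(self) -> None:
        self.support_feats = []  # reset cached support features


class FinetuneClassifierSketch:
    """Hypothetical stand-in for a finetune-based model such as baseline++.

    Assumption: a separate meta-test head must be attached before the
    support set is processed; if it was never set, it stays None.
    """

    def __init__(self, meta_test_head: Optional[object] = None) -> None:
        self.meta_test_head = meta_test_head

    def before_forward_support(self) -> None:
        # Mirrors the check reported in base_finetune.py, line 181.
        assert self.meta_test_head is not None
        self.support_feats = []


def try_before_forward_support(model) -> str:
    """Run the first step of the demo's support processing and report the result."""
    try:
        model.before_forward_support()
        return "ok"
    except AssertionError:
        return "AssertionError: meta_test_head is None"


if __name__ == "__main__":
    print(try_before_forward_support(MetricClassifierSketch()))    # ok
    print(try_before_forward_support(FinetuneClassifierSketch()))  # AssertionError: meta_test_head is None
    print(try_before_forward_support(FinetuneClassifierSketch(meta_test_head=object())))  # ok
```

If this matches the real behavior, the demo would need to make sure the finetune-based model has its meta-test head in place before calling `process_support_images`, whereas metric-based models need no such step.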