mmpretrain
[Bug] Revise the _remove_state_dict_prefix and _add_state_dict_prefix functions in timm.py to adapt to the case of multiple submodels.
When TimmClassifier is used as the student or teacher model in a Knowledge Distillation algorithm, save_checkpoint and load_checkpoint have the following bugs:
- save_checkpoint: When saving a checkpoint with save_checkpoint(self.state_dict(), 'xxx.pth'), where self is a Knowledge Distillation algorithm that contains the submodels self.student and self.teacher, self.state_dict() recursively calls the state_dict function of each submodel. The _remove_state_dict_prefix function of the TimmClassifier class is registered as a hook that is supposed to modify the original destination. However, for the student and teacher submodels, _remove_state_dict_prefix builds a new_state_dict, a separate object from the original destination, and returns it as the hook_result. The state_dict function of the Knowledge Distillation algorithm never receives this returned dict, so both the object and the contents of its destination remain unchanged. To solve this problem, we change _remove_state_dict_prefix to modify state_dict in place instead of creating a new_state_dict.
- load_checkpoint: When loading a checkpoint of a Knowledge Distillation algorithm whose student and teacher are both TimmClassifier models, the _add_state_dict_prefix function of the TimmClassifier class is used as a hook to modify the state_dict for each submodel. When it runs for the student submodel, _add_state_dict_prefix deletes all keys of the teacher submodel. To solve this problem, we change _add_state_dict_prefix to delete a key only when it differs from its new_key.
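Both failure modes can be reproduced without PyTorch. The following is a minimal pure-Python simulation, assuming (as in nn.Module.state_dict at the time) that a parent module ignores the value returned by each submodule's state_dict() call, and that every submodule's load hook sees the same shared dict. All helper names (child_state_dict, parent_state_dict, remove_prefix_new_dict, remove_prefix_in_place, add_prefix_buggy, add_prefix_fixed) are hypothetical, the hooks take only (state_dict, prefix) rather than PyTorch's full hook signatures, and plain startswith checks stand in for the regular expressions in timm.py. It is an illustration, not the mmpretrain code.

```python
from collections import OrderedDict


def child_state_dict(destination, prefix, params, hook):
    """Hypothetical stand-in for a submodule's state_dict() call.

    Writes the child's entries into the shared `destination`, then runs the
    post-hook. A returned dict only rebinds the local name; the caller's
    dict object is not touched.
    """
    for name, value in params.items():
        destination[prefix + name] = value
    hook_result = hook(destination, prefix)
    if hook_result is not None:
        destination = hook_result
    return destination


def parent_state_dict(children, hook):
    """Hypothetical stand-in for the KD algorithm's state_dict()."""
    destination = OrderedDict()
    for name, params in children.items():
        # The child's return value is discarded, mirroring the recursive call.
        child_state_dict(destination, name + '.', params, hook)
    return destination


def remove_prefix_new_dict(state_dict, prefix):
    """Old behaviour: strip 'model.' into a brand-new dict and return it."""
    old = prefix + 'model.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = prefix + k[len(old):] if k.startswith(old) else k
        new_state_dict[new_key] = v
    return new_state_dict


def remove_prefix_in_place(state_dict, prefix):
    """Fixed behaviour: rename matching keys inside the shared dict."""
    old = prefix + 'model.'
    for k in list(state_dict.keys()):
        if k.startswith(old):
            state_dict[prefix + k[len(old):]] = state_dict.pop(k)


def add_prefix_buggy(state_dict, prefix):
    """Old behaviour: every key is re-inserted and then deleted."""
    new_prefix = prefix + 'model.'
    for k in list(state_dict.keys()):
        new_key = new_prefix + k[len(prefix):] if k.startswith(prefix) else k
        state_dict[new_key] = state_dict[k]
        del state_dict[k]  # also removes k when new_key == k


def add_prefix_fixed(state_dict, prefix):
    """Fixed behaviour: only keys under this submodule's prefix change."""
    new_prefix = prefix + 'model.'
    for k in list(state_dict.keys()):
        if k.startswith(prefix):
            state_dict[new_prefix + k[len(prefix):]] = state_dict.pop(k)


children = {'student': {'model.fc.weight': 1},
            'teacher': {'model.fc.weight': 2}}

# Saving: the returned copy is lost, so 'model.' is never stripped.
print(dict(parent_state_dict(children, remove_prefix_new_dict)))
# {'student.model.fc.weight': 1, 'teacher.model.fc.weight': 2}
print(dict(parent_state_dict(children, remove_prefix_in_place)))
# {'student.fc.weight': 1, 'teacher.fc.weight': 2}

# Loading: the student's hook wipes out the teacher's keys.
ckpt = {'student.fc.weight': 1, 'teacher.fc.weight': 2}
sd = OrderedDict(ckpt)
add_prefix_buggy(sd, 'student.')
print(dict(sd))  # {'student.model.fc.weight': 1}
sd = OrderedDict(ckpt)
add_prefix_fixed(sd, 'student.')
add_prefix_fixed(sd, 'teacher.')
print(dict(sd))
# {'student.model.fc.weight': 1, 'teacher.model.fc.weight': 2}
```

The fixed helpers correspond to the two changes described in this PR: renaming keys inside the shared dict instead of returning a copy, and leaving alone any key that is not under the submodule's own prefix.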
Please sign the CLA so that I can review your PR.
Hello, can you sign the CLA and fix the lint problem? Then we can merge the PR. @wilxy
Thanks for the reminder, I've signed the CLA and fixed the lint problem.
Hi @wilxy, can you migrate this PR to the main branch?