temporal-shift-module

How to fine-tune the online MobileNetV2 TSM for a Jester-like use case

[Open] FabianHertwig opened this issue 5 years ago • 4 comments

Hi,

Thank you for sharing your great work and the code!

I would like to use TSM in an online scenario for gesture recognition, but my camera angle differs a bit from the one used in the Jester dataset. So my idea is to create my own dataset and fine-tune the model on it.

I see that the model architecture for online inference is a bit different from the one used for training: some layers are `InvertedResidualWithShift` layers instead of `InvertedResidual` layers. So I guess I cannot fine-tune from the weights provided for the online demo? I actually tried that, but I get state dict key errors that I cannot resolve.

Would you be willing to share the weights of the mobilenet_v2 online model trained on the Jester dataset?

If that is not possible, should I fine-tune from the Kinetics weights instead? When I try that, I also get state dict key errors. At first they occur on almost all keys; I think that is because `main.py` calls `model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()` before loading the weights (`DataParallel` adds a `module.` prefix to every parameter name). When I move that line to after the weights are loaded, I get the following error:

```
=> fine-tuning from 'pretrained_models/TSM_kinetics_RGB_mobilenetv2_shift8_blockres_avg_segment8_e100_dense.pth'
#### Notice: keys that failed to load: {'new_fc.bias', 'base_model.classifier.bias', 'new_fc.weight', 'base_model.classifier.weight'}
=> New dataset, do not load fc weights
Traceback (most recent call last):
  File "/datadrive/mit-han-lab/temporal-shift-module/main.py", line 379, in <module>
    main()
  File "/datadrive/mit-han-lab/temporal-shift-module/main.py", line 122, in main
    model.load_state_dict(model_dict)
  File "/anaconda/envs/temporal-shift-module/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TSN:
	Unexpected key(s) in state_dict: "base_model.classifier.weight", "base_model.classifier.bias".
```
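A minimal sketch of an alternative to `strict=False`, assuming plain PyTorch state dicts: instead of merging every checkpoint entry into `model_dict`, keep only the entries whose name (ignoring a leading `module.`) and shape match the model. `strip_prefix` and `filter_checkpoint` are hypothetical helpers, not part of the repository.

```python
from typing import Any, Dict, List, Mapping, Tuple


def strip_prefix(key: str, prefix: str = "module.") -> str:
    """Remove a leading DataParallel-style prefix if present (hypothetical helper)."""
    return key[len(prefix):] if key.startswith(prefix) else key


def filter_checkpoint(
    ckpt_sd: Mapping[str, Any],
    model_sd: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Keep only checkpoint entries whose name and shape match the model.

    Names are compared with any leading "module." removed on both sides, so it
    does not matter which side was created under torch.nn.DataParallel.
    Returned keys use the model's own naming, so they can be merged into
    model.state_dict() and loaded with strict checking left on.
    """
    # Normalized name -> the key name the model actually uses.
    model_keys = {strip_prefix(k): k for k in model_sd}
    loaded: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in ckpt_sd.items():
        target = model_keys.get(strip_prefix(key))
        if target is None:
            # Not present in the model at all, e.g. an old classifier layer.
            skipped.append(key)
        elif tuple(value.shape) != tuple(model_sd[target].shape):
            # Same name but different shape, e.g. a different number of classes.
            skipped.append(key)
        else:
            loaded[target] = value
    return loaded, skipped
```

Hypothetical usage, with `sd` being the loaded checkpoint's state dict: `model_dict = model.state_dict()`, `loaded, skipped = filter_checkpoint(sd, model_dict)`, `model_dict.update(loaded)`, then `model.load_state_dict(model_dict)`. Printing `skipped` lists exactly which weights (such as `base_model.classifier.*`) were left out, which is easier to audit than silently ignoring them with `strict=False`.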

If I use `model.load_state_dict(model_dict, strict=False)` there instead, the script gets further but then fails in the data loader. I guess the Jester dataset has changed a bit: the annotation files are now .csv files and no longer contain the number of frames per video, which I assume the data loader expects.

What is the expected format for the dataset?
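A sketch of one possible conversion, under several assumptions that should be checked against the repository's dataset code and the actual download: that the TSM data loader reads a list file with one space-separated `<frame_dir> <num_frames> <label_index>` line per video, that the new Jester annotation CSVs use `;` as separator (`<video_id>;<label name>`), that the labels file has one class name per line, and that frames were extracted into one folder per video id. `read_labels` and `build_list_file` are hypothetical names.

```python
import csv
import os
from typing import Dict, List


def read_labels(labels_csv: str) -> Dict[str, int]:
    """Map each class name (assumed one per line) to a 0-based index."""
    with open(labels_csv, newline="") as f:
        names = [line.strip() for line in f if line.strip()]
    return {name: i for i, name in enumerate(names)}


def build_list_file(annotations_csv: str, labels_csv: str,
                    frames_root: str, out_path: str) -> int:
    """Write '<video_id> <num_frames> <label_index>' lines; return the line count.

    Hypothetical helper: the output format is an assumption about what the
    TSM loader expects, not a confirmed specification.
    """
    label_to_idx = read_labels(labels_csv)
    lines: List[str] = []
    with open(annotations_csv, newline="") as f:
        # Assumed row layout: "<video_id>;<label name>".
        for row in csv.reader(f, delimiter=";"):
            if len(row) < 2:
                continue  # skip blank or malformed rows
            video_id, label = row[0].strip(), row[1].strip()
            frame_dir = os.path.join(frames_root, video_id)
            # Count extracted frame files, ignoring hidden files.
            num_frames = len([n for n in os.listdir(frame_dir) if not n.startswith(".")])
            lines.append(f"{video_id} {num_frames} {label_to_idx[label]}")
    with open(out_path, "w") as f:
        f.writelines(line + "\n" for line in lines)
    return len(lines)
```

The frame directory is written relative to `frames_root`, on the assumption that the loader joins it with a dataset root path; if it expects absolute paths, write `frame_dir` instead of `video_id`.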

Looking forward to your reply.

FabianHertwig, Feb 13 '20 17:02

@FabianHertwig did you find a solution?

MattVil, Sep 08 '20 11:09

@MattVil I moved away from this project and started using the hand tracking from Google's MediaPipe project.

FabianHertwig, Sep 08 '20 11:09

@FabianHertwig I still have the same problem as you. How do I train the model on a new dataset? And how did you do it with MediaPipe?

aureosun, Jul 28 '22 08:07

The key (weight-naming) conventions of the training model are a little different from those of the model used for online recognition, so make sure the layer weight names in your checkpoint match the names the online model expects.
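One way to act on this is to rename checkpoint keys with a few ordered rules and then diff the result against the online model's keys. A minimal sketch in plain Python; the rules in `EXAMPLE_RULES` (dropping `module.` and `base_model.`, collapsing a `.net.` wrapper around shifted convolutions) are assumptions about how the two naming schemes differ, and `remap_keys`/`diff_keys` are hypothetical helpers.

```python
import re
from typing import Any, Dict, Iterable, List, Mapping, Tuple

# HYPOTHETICAL rename rules (regex, replacement), applied in order. They are
# guesses at how training-time names differ from the online model's names;
# verify them by printing both models' state_dict keys before relying on them.
EXAMPLE_RULES: List[Tuple[str, str]] = [
    (r"^module\.", ""),      # prefix added by torch.nn.DataParallel
    (r"^base_model\.", ""),  # assumed TSN wrapper around the backbone
    (r"\.net\.", "."),       # assumed temporal-shift wrapper around a conv
]


def remap_keys(sd: Mapping[str, Any],
               rules: List[Tuple[str, str]] = EXAMPLE_RULES) -> Dict[str, Any]:
    """Return a new dict whose keys have every rule applied in order."""
    out: Dict[str, Any] = {}
    for key, value in sd.items():
        new_key = key
        for pattern, repl in rules:
            new_key = re.sub(pattern, repl, new_key)
        if new_key in out:
            # Two source keys collapsing to one name means a rule is wrong.
            raise ValueError(f"two keys map to {new_key!r}")
        out[new_key] = value
    return out


def diff_keys(src_keys: Iterable[str],
              dst_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (keys only in src, keys only in dst), both sorted."""
    src, dst = set(src_keys), set(dst_keys)
    return sorted(src - dst), sorted(dst - src)
```

Any keys still reported by `diff_keys` (for example a training-time `new_fc.*` layer versus the online model's own classifier) would need an extra rule or a freshly initialized layer; whether a given mapping is right can only be confirmed against the actual online model definition.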

ZubairKhan001, Jul 28 '22 08:07