
Error while loading a pre-trained wav2vec2 model

Aaryan369 opened this issue.

System Info

  • transformers version: 4.20.1
  • Platform: Linux-5.4.0-1085-azure-x86_64-with-glibc2.10
  • Python version: 3.8.13
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.9.1+cu111 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@patrickvonplaten , @anton-l

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

I have been using the code from this to pre-train a wav2vec2 large model (`model_name_or_path`: `"facebook/wav2vec2-large-lv60"`).

After training completes and the model is saved, I try to load the saved model with `model = Wav2Vec2ForPreTraining.from_pretrained("/path/to/model", local_files_only=True)`.

This results in an error:

RuntimeError: Error(s) in loading state_dict for Wav2Vec2ForPreTraining:
	size mismatch for wav2vec2.feature_extractor.conv_layers.1.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).
	size mismatch for wav2vec2.feature_extractor.conv_layers.2.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).
	size mismatch for wav2vec2.feature_extractor.conv_layers.3.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).
	size mismatch for wav2vec2.feature_extractor.conv_layers.4.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 3]).
	size mismatch for wav2vec2.feature_extractor.conv_layers.5.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 2]).
	size mismatch for wav2vec2.feature_extractor.conv_layers.6.conv.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 512, 2]).
	size mismatch for wav2vec2.feature_projection.projection.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
	size mismatch for wav2vec2.encoder.pos_conv_embed.conv.weight_v: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 64, 128]).
	size mismatch for wav2vec2.encoder.layers.0.attention.k_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for wav2vec2.encoder.layers.0.attention.v_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for wav2vec2.encoder.layers.0.attention.q_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for wav2vec2.encoder.layers.0.attention.out_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for wav2vec2.encoder.layers.0.feed_forward.intermediate_dense.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096, 1024]).
	size mismatch for wav2vec2.encoder.layers.0.feed_forward.output_dense.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
	size mismatch for wav2vec2.encoder.layers.1.attention.k_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	.
	.
	.
	size mismatch for wav2vec2.encoder.layers.23.feed_forward.output_dense.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
	size mismatch for quantizer.codevectors: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1, 640, 384]).
	size mismatch for quantizer.weight_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([640, 512]).
	size mismatch for project_hid.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([768, 1024]).
	size mismatch for project_q.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
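Every mismatched entry in the log has shape `torch.Size([0])` on the checkpoint side, so the file appears to store empty tensors rather than weights of a different size. A quick way to confirm this is to list every zero-element entry in the saved state dict. The helpers below are a hypothetical sketch, not part of transformers, and assume the weights were written as `pytorch_model.bin` inside the model directory. One possible cause, not confirmed in this thread, is a sharded training setup such as DeepSpeed ZeRO stage 3, which can leave empty placeholder tensors when parameters are not gathered before saving.

```python
import math
import os
from typing import Dict, List, Mapping, Sequence, Tuple


def find_empty_params(shapes: Mapping[str, Sequence[int]]) -> List[str]:
    """Return the names of entries whose stored tensor has zero elements.

    A 0-d scalar has shape () and math.prod(()) == 1, so scalars are not flagged.
    """
    return [name for name, shape in shapes.items() if math.prod(shape) == 0]


def checkpoint_shapes(model_dir: str) -> Dict[str, Tuple[int, ...]]:
    """Map every key of a saved PyTorch checkpoint to its shape tuple.

    Assumption: the weights file is named `pytorch_model.bin`; adjust the
    file name if your checkpoint was saved differently.
    """
    import torch  # imported here so find_empty_params works without torch

    state_dict = torch.load(
        os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu"
    )
    return {name: tuple(tensor.shape) for name, tensor in state_dict.items()}
```

For example, `find_empty_params(checkpoint_shapes("/path/to/model"))` would be expected to include names such as `wav2vec2.feature_extractor.conv_layers.1.conv.weight` for a checkpoint that produces the error above.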

I am using a custom dataset. In the training script, I load a pre-trained model with `model = Wav2Vec2ForPreTraining.from_pretrained(model_name_or_path)` instead of initializing a new model from a config file.

Expected behavior

The model is supposed to load properly with all layers and weights.

Aaryan369, Jul 26 '22

Hey @Aaryan369,

Is there any way you could upload the checkpoint to the Hub (maybe as a private one if the weights are sensitive)? Happy to take a deeper look then, but in short, the above error message shouldn't happen. There seems to be a mismatch between the config and the model weights.
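One way to check whether the config and the weights disagree is to build an untrained model from the saved `config.json` and compare its parameter shapes with those stored in the checkpoint. The helpers below are a hypothetical sketch, assuming the model directory contains the `config.json` written by `save_pretrained`. If every mismatch has a zero-element checkpoint shape, the config describes the intended architecture and the problem lies in the saved weights; mismatches with non-zero shapes would point at the config instead.

```python
from typing import Dict, List, Mapping, Sequence, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(
    expected: Mapping[str, Sequence[int]], found: Mapping[str, Sequence[int]]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, expected_shape) for every key present
    in both mappings whose shapes differ, sorted by name."""
    mismatches = []
    for name in sorted(expected.keys() & found.keys()):
        exp, got = tuple(expected[name]), tuple(found[name])
        if exp != got:
            mismatches.append((name, got, exp))
    return mismatches


def expected_shapes(model_dir: str) -> Dict[str, Shape]:
    """Build a randomly initialised Wav2Vec2ForPreTraining from the saved
    config and record the shape of every entry in its state_dict."""
    from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining

    config = Wav2Vec2Config.from_pretrained(model_dir)
    model = Wav2Vec2ForPreTraining(config)
    return {name: tuple(t.shape) for name, t in model.state_dict().items()}
```

The `found` mapping can be built from the weights file with `{k: tuple(v.shape) for k, v in torch.load(weights_path, map_location="cpu").items()}`, where `weights_path` is a placeholder for the saved checkpoint file, and then passed as `shape_mismatches(expected_shapes("/path/to/model"), found)`.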

patrickvonplaten, Aug 23 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Sep 17 '22