w2v2-speaker

releasing fine-tuned model

Open · GuangkeChen opened this issue 3 years ago · 5 comments

Hello, thanks for your work and code, which have helped my research a lot. Could you provide the fine-tuned model?

GuangkeChen · Aug 11 '22 04:08

I found a checkpoint lying around on my disk. It should correspond to one of the three models whose results are reported under 'tri-stage' in Table 3, which have a mean EER of 1.97% on the voxceleb-extended test set (see https://arxiv.org/abs/2109.15053).
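The EER (equal error rate) quoted above is the operating point at which the false acceptance rate equals the false rejection rate on a list of verification trials. The sketch below shows one way to compute it from trial scores, assuming higher scores mean "same speaker" and labels of 1 (target) / 0 (non-target). equal_error_rate is a hypothetical helper, not the repository's evaluator, and the simple threshold sweep only approximates the EER when the two error curves never cross exactly.

```python
from typing import Sequence


def equal_error_rate(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Approximate the EER by sweeping a threshold over the observed scores.

    Hypothetical helper: labels are 1 for target (same-speaker) trials and
    0 for non-target trials; a trial is accepted when score >= threshold.
    """
    n_target = sum(1 for label in labels if label == 1)
    n_nontarget = len(labels) - n_target
    if n_target == 0 or n_nontarget == 0:
        raise ValueError("need both target and non-target trials")

    best_gap, best_eer = float("inf"), 1.0
    for threshold in sorted(set(scores)):
        # false rejections: target trials scored below the threshold
        frr = sum(1 for s, y in zip(scores, labels) if y == 1 and s < threshold) / n_target
        # false acceptances: non-target trials scored at or above the threshold
        far = sum(1 for s, y in zip(scores, labels) if y == 0 and s >= threshold) / n_nontarget
        # keep the threshold where FAR and FRR are closest; report their average
        if abs(far - frr) < best_gap:
            best_gap, best_eer = abs(far - frr), (far + frr) / 2
    return best_eer


# Toy trial list: one target trial (0.5) scores below one non-target trial (0.6).
print(equal_error_rate([0.9, 0.8, 0.5, 0.6, 0.2, 0.1], [1, 1, 1, 0, 0, 0]))
```

The sweep is quadratic in the number of trials, which is fine for illustration; for full trial lists a ROC-based computation (for example with scikit-learn's roc_curve) scales better.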

You should be able to download it here: https://surfdrive.surf.nl/files/index.php/s/1HO2w6WHTqCmfVk

nikvaessen · Aug 11 '22 09:08

Thanks for your reply.

GuangkeChen · Aug 12 '22 07:08

Hello, thanks again for your code and pre-trained model. I tried to load the pre-trained model from the checkpoint you provided for speaker embedding extraction, but got the following error:

  RuntimeError: Error(s) in loading state_dict for Wav2vec2FCModule:
        size mismatch for loss_fn.fc_weights: copying a param with shape torch.Size([5994, 768]) from checkpoint, the shape in current model is torch.Size([5994, 1536]).

Note that 1536 is exactly twice 768.
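A quick way to test a dimension hypothesis like this is to compare parameter shapes from the checkpoint and from the freshly built model directly. The sketch below assumes both state dicts have already been reduced to name → shape tuples; find_shape_mismatches is a hypothetical helper, not part of w2v2-speaker, and the toy dictionaries only mirror the single parameter named in the error.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for every parameter that
    exists in both dicts but has a different shape."""
    return [
        (name, checkpoint_shapes[name], model_shapes[name])
        for name in sorted(checkpoint_shapes.keys() & model_shapes.keys())
        if checkpoint_shapes[name] != model_shapes[name]
    ]


# Toy shapes mirroring the error above (all other parameters omitted).
ckpt = {"loss_fn.fc_weights": (5994, 768)}
model = {"loss_fn.fc_weights": (5994, 1536)}
for name, c, m in find_shape_mismatches(ckpt, model):
    # the ratio of the last dimensions shows the factor of two directly
    print(f"{name}: checkpoint {c} vs model {m} (ratio {m[-1] / c[-1]:g})")
```

With real objects, the dictionaries could be built as {k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["state_dict"].items()} (Lightning checkpoints keep the weights under "state_dict"; recent PyTorch versions may additionally need weights_only=False) and {k: tuple(v.shape) for k, v in network.state_dict().items()}.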

This is my code, adapted from your predict.py:

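  # NOTE: `cfg` is not defined in this snippet. It is assumed to be the Hydra
  # DictConfig composed from the config file below, as in predict.py, where
  # Hydra passes it in (e.g. to a function decorated with @hydra.main).
  import torch
  from hydra.utils import instantiate

  from src.lightning_modules.speaker import Wav2vec2FCModule

  # create evaluator (for speaker recognition), a SpeakerRecognitionEvaluator
  evaluator = instantiate(cfg.evaluator)

  network_cfg = instantiate(cfg.network)

  def loss_fn_constructor():
      # should be instantiated in the network
      # so that potential parameters are properly
      # registered
      return instantiate(cfg.optim.loss)

  validation_pairs = []  # not needed for embedding extraction
  test_pairs = []  # not needed for embedding extraction

  num_speakers = 5994  # hard-coded to match the checkpoint's 5994-class speaker head

  kwargs = {
      "hyperparameters_to_save": cfg,
      "cfg": network_cfg,
      "num_speakers": num_speakers,
      "loss_fn_constructor": loss_fn_constructor,
      "validation_pairs": validation_pairs,
      "test_pairs": test_pairs,
      "evaluator": evaluator,
  }

  network_class = Wav2vec2FCModule

  load_network_from_checkpoint = '/public/home/chengk//w2v2-speaker/n1_3stage.best.ckpt'

  # strict=False tolerates missing/unexpected keys, but a parameter whose shape
  # differs from the model's still raises the size-mismatch error shown above
  network = network_class.load_from_checkpoint(
      load_network_from_checkpoint, strict=False, **kwargs
  )
  network.eval()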

This is my config file (partial):

defaults:
  - data/module: voxceleb1
  - data/pipeline: wav2vec_base_pipeline
  - data/shards: shards_voxceleb
  - data/dataloader: speaker
  - evaluator: cosine_distance
  - network: wav2vec2_fc
  - optim/loss: aam_softmax
  - tokenizer: default
  - trainer: trainer

Do you have any tips for this issue? Many thanks!

GuangkeChen · Oct 02 '22 18:10

You should change network.stat_pooling_type=mean+std to network.stat_pooling_type=mean. Either edit https://github.com/nikvaessen/w2v2-speaker/blob/master/config/network/wav2vec2_fc.yaml#L40 or override it on the command line with python run.py network.stat_pooling_type=mean.
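The factor of two in the error above comes from the pooling step: mean pooling keeps one 768-dimensional vector per utterance, while mean+std pooling (as the name and the doubled size suggest) concatenates the per-dimension mean and standard deviation, giving 1536 values. The speaker head of a mean+std model therefore expects 1536 inputs, whereas the released checkpoint's head has 768. Below is a minimal pure-Python sketch of the two variants, assuming a (num_frames x hidden_dim) list of frame features and a population standard deviation; stat_pool is a hypothetical name and this is not the repository's implementation.

```python
import math
from typing import List


def stat_pool(frames: List[List[float]], pooling: str = "mean") -> List[float]:
    """Collapse (num_frames x hidden_dim) frame features into one utterance vector.

    Hypothetical sketch: "mean" returns hidden_dim values, "mean+std" returns
    the mean concatenated with the (population) standard deviation, i.e.
    2 * hidden_dim values.
    """
    num_frames, dim = len(frames), len(frames[0])
    mean = [sum(frame[d] for frame in frames) / num_frames for d in range(dim)]
    if pooling == "mean":
        return mean
    if pooling == "mean+std":
        std = [
            math.sqrt(sum((frame[d] - mean[d]) ** 2 for frame in frames) / num_frames)
            for d in range(dim)
        ]
        return mean + std  # list concatenation: [mean..., std...]
    raise ValueError(f"unknown pooling type: {pooling}")


# 50 frames of 768-dimensional wav2vec2-base features (dummy values)
frames = [[float(t + d) for d in range(768)] for t in range(50)]
print(len(stat_pool(frames, "mean")))      # 768
print(len(stat_pool(frames, "mean+std")))  # 1536
```

With the released checkpoint, only the 768-dimensional mean variant matches loss_fn.fc_weights, which is why switching to network.stat_pooling_type=mean resolves the size mismatch.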

nikvaessen · Oct 02 '22 20:10

@GuangkeChen, I tried to use your code for speaker embedding extraction, but I noticed that cfg is never defined. Were you able to run it successfully? Could you share a working version? Thanks!

Kinetic-shaun · Oct 27 '23 10:10