
Loading Wav2Vec2-2-mBart Checkpoints for S2UT

Open · sanchit-gandhi opened this issue 3 years ago · 6 comments

❓ Questions and Help

What is your question?

Many thanks for uploading the fine-tuned model checkpoints for Enhanced Direct Speech-to-Speech Translation in the recent PR https://github.com/facebookresearch/fairseq/pull/4588.

Having downloaded the model weights from:

https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md#finetuned-model-checkpoints

and the config file and vocabulary from:

https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md#training

as config_xm_T_1003.yaml and dict_1003_unitmbart.txt respectively, how does one load the pre-trained model?

What have you tried?

First attempt, with the from_pretrained method:

from fairseq.models.speech_to_speech import S2UTConformerModel
model = S2UTConformerModel.from_pretrained(
    "/Users/sanchitgandhi/Downloads",
    checkpoint_file="w2v2_mbart_LND_w_ASR.pt",
    task="speech_to_text",
    )
Traceback (`FileNotFoundError`):
----> 1 model = S2UTConformerModel.from_pretrained("/Users/sanchitgandhi/Downloads", checkpoint_file="w2v2_mbart_LND_w_ASR.pt", task="speech_to_text")

File ~/fairseq/fairseq/models/fairseq_model.py:267, in BaseFairseqModel.from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
    244 """
    245 Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
    246 file. Downloads and caches the pre-trained model file if needed.
   (...)
    263         model archive path.
    264 """
    265 from fairseq import hub_utils
--> 267 x = hub_utils.from_pretrained(
    268     model_name_or_path,
    269     checkpoint_file,
    270     data_name_or_path,
    271     archive_map=cls.hub_models(),
    272     **kwargs,
    273 )
    274 logger.info(x["args"])
    275 return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

File ~/fairseq/fairseq/hub_utils.py:73, in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
     70 if "user_dir" in kwargs:
     71     utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
---> 73 models, args, task = checkpoint_utils.load_model_ensemble_and_task(
     74     [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
     75     arg_overrides=kwargs,
     76 )
     78 return {
     79     "args": args,
     80     "task": task,
     81     "models": models,
     82 }

File ~/fairseq/fairseq/checkpoint_utils.py:473, in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards, state)
    471 argspec = inspect.getfullargspec(task.build_model)
    472 if "from_checkpoint" in argspec.args:
--> 473     model = task.build_model(cfg.model, from_checkpoint=True)
    474 else:
    475     model = task.build_model(cfg.model)

File ~/fairseq/fairseq/tasks/speech_to_text.py:129, in SpeechToTextTask.build_model(self, args, from_checkpoint)
    127 args.input_channels = self.data_cfg.input_channels
    128 args.speaker_to_id = self.speaker_to_id
--> 129 return super(SpeechToTextTask, self).build_model(args, from_checkpoint)

File ~/fairseq/fairseq/tasks/fairseq_task.py:676, in LegacyFairseqTask.build_model(self, args, from_checkpoint)
    664 """
    665 Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
    666 task.
   (...)
    672     a :class:`~fairseq.models.BaseFairseqModel` instance
    673 """
    674 from fairseq import models, quantization_utils
--> 676 model = models.build_model(args, self, from_checkpoint)
    677 model = quantization_utils.quantize_model_scalar(model, args)
    678 return model

File ~/fairseq/fairseq/models/__init__.py:106, in build_model(cfg, task, from_checkpoint)
     98             ARCH_CONFIG_REGISTRY[model_type](cfg)
    100 assert model is not None, (
    101     f"Could not infer model type from {cfg}. "
    102     "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
    103     + f" Requested model type: {model_type}"
    104 )
--> 106 return model.build_model(cfg, task)

File ~/fairseq/fairseq/models/speech_to_text/xm_transformer.py:609, in XMTransformerModel.build_model(cls, args, task)
    607 base_architecture(args)
    608 if getattr(args, "load_pretrained_decoder_from", None):
--> 609     ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None))
    610     decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"])
    611     args = cls.override_decoder_args(args, decoder_args_dict)

File ~/venv/lib/python3.8/site-packages/torch/serialization.py:709, in load(f, map_location, pickle_module, **pickle_load_args)
    706 if 'encoding' not in pickle_load_args.keys():
    707     pickle_load_args['encoding'] = 'utf-8'
--> 709 with _open_file_like(f, 'rb') as opened_file:
    710     if _is_zipfile(opened_file):
    711         # The zipfile reader is going to advance the current file position.
    712         # If we want to actually tail call to torch.jit.load, we need to
    713         # reset back to the original position.
    714         orig_position = opened_file.tell()

File ~/venv/lib/python3.8/site-packages/torch/serialization.py:240, in _open_file_like(name_or_buffer, mode)
    238 def _open_file_like(name_or_buffer, mode):
    239     if _is_path(name_or_buffer):
--> 240         return _open_file(name_or_buffer, mode)
    241     else:
    242         if 'w' in mode:

File ~/venv/lib/python3.8/site-packages/torch/serialization.py:221, in _open_file.__init__(self, name, mode)
    220 def __init__(self, name, mode):
--> 221     super(_open_file, self).__init__(open(name, mode))

FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'

Trying checkpoint_utils.load_model_ensemble_and_task instead:

import fairseq

cp = "/Users/sanchitgandhi/Downloads/w2v2_mbart_LND_w_ASR.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp], arg_overrides={"data": "/Users/sanchitgandhi/Downloads/", "task": "speech_to_text"})
Traceback (`FileNotFoundError`):
----> 1 model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp], arg_overrides={"data": "/Users/sanchitgandhi/Downloads/", "task":"speech_to_text"})

(The remaining frames, from checkpoint_utils.py:473 down to torch/serialization.py:221, are identical to the first traceback above.)

FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'

It seems that an additional file is required, at the path:

/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt
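The stray path is not one of the downloaded files: the traceback shows it is read from args.load_pretrained_decoder_from, i.e. a training-time path stored inside the checkpoint's own config. A minimal sketch for listing every such path up front, assuming the checkpoint's "cfg" entry is a nested mapping whose model section may be an argparse.Namespace (as the attribute-style edits later in this thread suggest). The helper name find_missing_paths is hypothetical, and treating every absolute-path string as a file reference is only a heuristic.

```python
import argparse
import os
from collections.abc import Mapping
from typing import Any, Iterator, Tuple


def find_missing_paths(cfg: Any, prefix: str = "") -> Iterator[Tuple[str, str]]:
    """Yield (dotted_key, path) for absolute-path strings in cfg that do not exist locally."""
    if isinstance(cfg, argparse.Namespace):
        cfg = vars(cfg)  # legacy (non-dataclass) model configs are argparse Namespaces
    if isinstance(cfg, Mapping):  # plain dicts and OmegaConf DictConfigs
        for key, value in cfg.items():
            yield from find_missing_paths(value, f"{prefix}{key}.")
    elif isinstance(cfg, str) and os.path.isabs(cfg) and not os.path.exists(cfg):
        yield prefix.rstrip("."), cfg


# Hypothetical usage on the downloaded checkpoint (requires torch):
# import torch
# ckpt = torch.load("/path/to/w2v2_mbart_LND_w_ASR.pt", map_location="cpu")
# for key, path in find_missing_paths(ckpt["cfg"]):
#     print(f"{key}: {path}")
```

Relative paths and None values are skipped, so keys such as data or config_yaml only show up when they hold an absolute path that is missing on the current machine.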

What's your environment?

  • fairseq Version (e.g., 1.0 or main): main
  • PyTorch Version (e.g., 1.0): 1.13.0
  • OS (e.g., Linux): MacOS, Linux
  • How you installed fairseq (pip, source): source
  • Build command you used (if compiling from source): https://github.com/facebookresearch/fairseq#requirements-and-installation
  • Python version: 3.8.9
  • CUDA/cuDNN version: -
  • GPU models and configuration: -
  • Any other relevant information: -

cc @sravyapopuri388

sanchit-gandhi (Jul 21 '22)

Hi, thanks for reaching out. It looks like the issue is that the checkpoint references pretrained models that are not available on your machine. A quick fix is to update the paths stored in the checkpoint: point w2v_path (used by the encoder) at your local copy, and set load_pretrained_decoder_from to None. Something like:

import torch

model = torch.load(MODEL_PATH)
state = model["cfg"]["model"]
state.w2v_path = W2V_PATH
state.load_pretrained_decoder_from = None

sravyapopuri388 (Jul 21 '22)

Hi @sravyapopuri388, thanks for your prompt response!

To confirm, should w2v_path be set to the local path? And how does one then load the pre-trained model? I've tried:

import torch
import fairseq

DATA_PATH = "/Users/sanchitgandhi/Downloads"
MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR.pt"
NEW_MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR_new.pt"

model = torch.load(MODEL_PATH)

model["cfg"]["model"].w2v_path = MODEL_PATH
model["cfg"]["model"].load_pretrained_decoder_from = None

torch.save(model, NEW_MODEL_PATH)

fairseq_model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([NEW_MODEL_PATH], arg_overrides={"data": DATA_PATH, "task": "speech_to_text"})

But I'm encountering the same error as before:

FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'

Thank you for your help, it's very much appreciated!

sanchit-gandhi (Jul 22 '22)

@sanchit-gandhi I think you have to point w2v_path at a model that has load_pretrained_decoder_from=None, i.e. NEW_MODEL_PATH (even though that file does not exist yet at the moment you assign it).

I believe this is because setting w2v_path to the original (bugged) w2v2_mbart_LND_w_ASR.pt makes the program load that checkpoint again, and its load_pretrained_decoder_from still refers to a non-existent checkpoint.
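The reasoning above can be made concrete with a toy model of the lookup. This is a hypothetical simplification, not fairseq's code: it assumes the encoder takes its wav2vec 2.0 args from w2v_args when present and otherwise loads them from the checkpoint at w2v_path, and it uses a plain dict as a stand-in for files on disk. The function name resolve_w2v_args is illustrative only.

```python
from typing import Any, Dict, Optional, Set

# Hypothetical stand-in for files on disk: path -> checkpoint dict.
Checkpoints = Dict[str, Dict[str, Any]]


def resolve_w2v_args(
    model_cfg: Dict[str, Any],
    checkpoints: Checkpoints,
    _seen: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Toy model (not fairseq code) of where the encoder gets its wav2vec 2.0 args."""
    seen = set() if _seen is None else _seen
    # Embedded args win: no checkpoint file is read at all.
    if model_cfg.get("w2v_args") is not None:
        return model_cfg["w2v_args"]
    path = model_cfg.get("w2v_path")
    if path in seen:
        raise RecursionError(f"w2v_path {path!r} leads back to a checkpoint already being loaded")
    if path not in checkpoints:
        raise FileNotFoundError(path)
    seen.add(path)
    loaded = checkpoints[path]["cfg"]["model"]
    # A full S2UT checkpoint drags in its own decoder path, which must also exist.
    decoder = loaded.get("load_pretrained_decoder_from")
    if decoder is not None and decoder not in checkpoints:
        raise FileNotFoundError(decoder)
    # A plain wav2vec 2.0 checkpoint has no w2v_path: its config is the encoder args.
    if "w2v_path" not in loaded:
        return loaded
    return resolve_w2v_args(loaded, checkpoints, seen)
```

In this toy, pointing w2v_path at the original checkpoint fails on its stale decoder path, a w2v_path that leads back to the same checkpoint would never terminate on its own (hence the visited set), and an embedded w2v_args avoids reading any file.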

gmryu (Jul 23 '22)

Thanks @gmryu! I've made the change you suggested, and also updated the config_yaml path, since it was hardcoded to a local path too.

Now I no longer get any errors when loading the model, but the Python process is killed during execution (after ~30 s). Were you able to load the pre-trained checkpoint? If so, would you mind sharing your code, even as a direct copy-and-paste with your absolute paths? :)

sanchit-gandhi (Jul 25 '22)

Hi Sanchit, thanks for your patience. Could you try replacing the w2v_path in the w2v2_mbart_LND_w_ASR.pt state with the original w2v2 model args (w2v_args) instead? Something like:

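import torch
from omegaconf import OmegaConf
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

MODEL_PATH = "path/to/checkpoint"  # e.g. w2v2_mbart_LND_w_ASR.pt
W2V_PATH = "path/to/w2v2/model"  # the original w2v2 checkpoint
NEW_MODEL_PATH = "path/to/updated/model"

# map_location="cpu" lets GPU-saved tensors load on a CPU-only machine
model = torch.load(MODEL_PATH, map_location="cpu")
w2v = torch.load(W2V_PATH, map_location="cpu")
# fairseq checkpoints store their config under "cfg"; older ones may only have an argparse "args"
if w2v.get("cfg") is not None:
    w2v_args = OmegaConf.create(w2v["cfg"])
else:
    w2v_args = convert_namespace_to_omegaconf(w2v["args"])

state = model["cfg"]["model"]
state.w2v_path = None  # not needed once w2v_args is embedded
state.load_pretrained_decoder_from = None  # skip the missing decoder checkpoint
state.w2v_args = w2v_args

torch.save(model, NEW_MODEL_PATH)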

Please let me know if this doesn't work, in which case I will update the checkpoints on my end before sharing.
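Re-running the full load just to find out whether the patch took effect can be slow, so a cheap check on the saved config may help first. A sketch with a hypothetical helper, patch_problems, that checks the two conditions discussed in this thread: load_pretrained_decoder_from is None and w2v_args is filled in. It assumes the model config is an argparse.Namespace or a mapping; the note about falling back to w2v_path is an assumption, labeled in the code.

```python
import argparse
from typing import Any, List


def patch_problems(model_cfg: Any) -> List[str]:
    """List reasons a patched model config might still reach for missing files (heuristic)."""
    if isinstance(model_cfg, argparse.Namespace):
        model_cfg = vars(model_cfg)
    problems = []
    if model_cfg.get("load_pretrained_decoder_from") is not None:
        problems.append("load_pretrained_decoder_from is still set")
    if model_cfg.get("w2v_args") is None:
        # Assumption: without embedded args the encoder falls back to reading w2v_path.
        problems.append("w2v_args is empty, so w2v_path would be read instead")
    return problems


# Hypothetical usage on the re-saved file (requires torch):
# import torch
# patched = torch.load(NEW_MODEL_PATH, map_location="cpu")
# print(patch_problems(patched["cfg"]["model"]) or "looks patched")
```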

sravyapopuri388 (Jul 28 '22)

Hey @sravyapopuri388, thanks for getting back to me! I actually went in and 'hardcoded' all the paths in the fairseq repo just to cover all bases. The erroneous path issue seems to be resolved. I'm now attempting to load the checkpoint; given its size (~10 GB), I've been doing this on a large CPU machine. Will let you know if there are any further issues!
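If memory is the bottleneck, one option is to drop state that only matters for resuming training. fairseq checkpoints can carry optimizer state under a "last_optimizer_state" key; whether these released checkpoints include it is not stated in this thread, so treat that as an assumption to verify by inspecting the keys. strip_training_state is a hypothetical helper.

```python
from typing import Any, Dict

# Assumption: this key holds optimizer state that inference does not need.
TRAINING_ONLY_KEYS = ("last_optimizer_state",)


def strip_training_state(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of a checkpoint dict without training-only entries."""
    return {k: v for k, v in ckpt.items() if k not in TRAINING_ONLY_KEYS}


# Hypothetical usage (requires torch; paths are placeholders):
# import torch
# ckpt = torch.load("w2v2_mbart_LND_w_ASR.pt", map_location="cpu")
# print(sorted(ckpt.keys()))  # check whether "last_optimizer_state" is present at all
# torch.save(strip_training_state(ckpt), "w2v2_mbart_LND_w_ASR_slim.pt")
```

The full file still has to be loaded once to write the slimmer copy, so that step itself needs the large machine; later loads would only get lighter if optimizer state was actually present.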

sanchit-gandhi (Jul 29 '22)

Hi @sravyapopuri388! Thanks for the help. Overriding the W2V2 paths fixed the issue with loading the model. Just to confirm: to run the model correctly, are no other changes needed when instantiating the S2UT model, i.e. can the decoder path be left unchanged, and should the output features default to dim 1007?

Summary of the loaded model:
XMTransformerModel(
  (encoder): Wav2VecEncoderWithAdaptor(
    (w2v_encoder): Wav2VecEncoder(
      (w2v_model): Wav2Vec2Model(
        (feature_extractor): ConvFeatureExtractionModel(
          (conv_layers): ModuleList(
            (0): Sequential(
              (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
              (1): Dropout(p=0.0, inplace=False)
              (2): Sequential(
                (0): TransposeLast()
                (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                (2): TransposeLast()
              )
              (3): GELU(approximate=none)
            )
            ...
            (6): Sequential(
              (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
              (1): Dropout(p=0.0, inplace=False)
              (2): Sequential(
                (0): TransposeLast()
                (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                (2): TransposeLast()
              )
              (3): GELU(approximate=none)
            )
          )
        )
        (post_extract_proj): Linear(in_features=512, out_features=1024, bias=True)
        (dropout_input): Dropout(p=0.1, inplace=False)
        (dropout_features): Dropout(p=0.1, inplace=False)
        (quantizer): None
        (project_q): None
        (encoder): ConformerEncoder(
          (pos_conv): Sequential(
            (0): Conv1d(1024, 1024, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
            (1): SamePad()
            (2): GELU(approximate=none)
          )
          (layers): ModuleList(
            (0): ConformerWav2Vec2EncoderLayer(
              (ffn1): FeedForwardModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (w_1): Linear(in_features=1024, out_features=4096, bias=True)
                (w_2): Linear(in_features=4096, out_features=1024, bias=True)
                (dropout1): Dropout(p=0.0, inplace=False)
                (dropout2): Dropout(p=0.0, inplace=False)
                (activation): SiLU(inplace=True)
              )
              (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (self_attn_dropout): Dropout(p=0.0, inplace=False)
              (self_attn): RelPositionMultiHeadedAttention(
                (linear_q): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_k): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_v): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_out): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (linear_pos): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (conv_module): ConvolutionModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (pointwise_conv1): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
                (glu): GLU(dim=1)
                (depthwise_conv): Conv1d(1024, 1024, kernel_size=(31,), stride=(1,), padding=(15,), groups=1024, bias=False)
                (batch_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (activation): SiLU(inplace=True)
                (pointwise_conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (ffn2): FeedForwardModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (w_1): Linear(in_features=1024, out_features=4096, bias=True)
                (w_2): Linear(in_features=4096, out_features=1024, bias=True)
                (dropout1): Dropout(p=0.0, inplace=False)
                (dropout2): Dropout(p=0.0, inplace=False)
                (activation): SiLU(inplace=True)
              )
              (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
            ...
            (23): ConformerWav2Vec2EncoderLayer(
              (ffn1): FeedForwardModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (w_1): Linear(in_features=1024, out_features=4096, bias=True)
                (w_2): Linear(in_features=4096, out_features=1024, bias=True)
                (dropout1): Dropout(p=0.0, inplace=False)
                (dropout2): Dropout(p=0.0, inplace=False)
                (activation): SiLU(inplace=True)
              )
              (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (self_attn_dropout): Dropout(p=0.0, inplace=False)
              (self_attn): RelPositionMultiHeadedAttention(
                (linear_q): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_k): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_v): Linear(in_features=1024, out_features=1024, bias=True)
                (linear_out): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (linear_pos): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (conv_module): ConvolutionModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (pointwise_conv1): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
                (glu): GLU(dim=1)
                (depthwise_conv): Conv1d(1024, 1024, kernel_size=(31,), stride=(1,), padding=(15,), groups=1024, bias=False)
                (batch_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (activation): SiLU(inplace=True)
                (pointwise_conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (ffn2): FeedForwardModule(
                (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (w_1): Linear(in_features=1024, out_features=4096, bias=True)
                (w_2): Linear(in_features=4096, out_features=1024, bias=True)
                (dropout1): Dropout(p=0.0, inplace=False)
                (dropout2): Dropout(p=0.0, inplace=False)
                (activation): SiLU(inplace=True)
              )
              (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
          (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (embed_positions): RelPositionalEncoding()
        )
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (final_proj): None
      )
      (final_dropout): Dropout(p=0, inplace=False)
    )
    (adaptor): Conv1dAdaptor(
      (layers): ModuleList(
        (0): Conv1d(1024, 2048, kernel_size=(3,), stride=(2,), padding=(1,))
      )
    )
  )
  (decoder): TransformerDecoder(
    (dropout_module): FairseqDropout()
    (embed_tokens): Embedding(1007, 1024, padding_idx=1)
    (embed_positions): SinusoidalPositionalEmbedding()
    (layers): LayerDropModuleList(
      (0): TransformerDecoderLayerBase(
        (dropout_module): FairseqDropout()
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_dropout_module): FairseqDropout()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerDecoderLayerBase(
        (dropout_module): FairseqDropout()
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_dropout_module): FairseqDropout()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      ...
      (11): TransformerDecoderLayerBase(
        (dropout_module): FairseqDropout()
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_dropout_module): FairseqDropout()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (output_projection): Linear(in_features=1024, out_features=1007, bias=False)
  )
)
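As a sanity check on the printout, the decoder's weight count can be worked out from the shapes shown: 12 layers (indices 0 to 11), model dimension 1024, FFN dimension 4096 and a 1007-entry vocabulary. The sketch counts the weights and biases of each Linear and LayerNorm listed above. Whether output_projection shares its weight with embed_tokens cannot be read off the printout, so it is left as a parameter; all function names are illustrative.

```python
def linear(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameters of nn.Linear(n_in, n_out)."""
    return n_in * n_out + (n_out if bias else 0)


def layer_norm(dim: int) -> int:
    """Parameters of LayerNorm(dim) with elementwise_affine=True (weight + bias)."""
    return 2 * dim


def decoder_layer_params(d: int = 1024, ffn: int = 4096) -> int:
    attn = 4 * linear(d, d)  # k_proj, v_proj, q_proj, out_proj
    return (
        2 * attn  # self_attn + encoder_attn
        + linear(d, ffn) + linear(ffn, d)  # fc1, fc2
        + 3 * layer_norm(d)  # self_attn, encoder_attn and final layer norms
    )


def decoder_params(vocab: int = 1007, d: int = 1024, ffn: int = 4096,
                   layers: int = 12, tied_output: bool = False) -> int:
    embed = vocab * d  # embed_tokens; SinusoidalPositionalEmbedding has no learned weights
    output = 0 if tied_output else linear(d, vocab, bias=False)  # output_projection
    return embed + layers * decoder_layer_params(d, ffn) + layer_norm(d) + output


print(f"{decoder_params():,}")  # 203,624,448 with an untied output projection
```

That is roughly 0.8 GB in fp32 for the decoder alone; tying the output projection would remove 1007 × 1024 = 1,031,168 of those parameters.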

sanchit-gandhi (Aug 04 '22)

Hi @sanchit-gandhi, yes, the output feature dimension looks correct to me. Overriding w2v_args with the args from the original w2v2 model and setting load_pretrained_decoder_from to None (to avoid the missing-file errors) should work; no other changes are required. Please let me know if you encounter any other issues. Thanks!
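The 1007 can also be cross-checked against the dictionary. fairseq's Dictionary reserves four special symbols by default (<s>, <pad>, </s>, <unk>, in that order, which is consistent with padding_idx=1 in the printout) and then adds one entry per line of the dict file. A sketch with a hypothetical helper, expected_vocab_size; it ignores any extra symbols a task might append.

```python
from typing import Iterable

# fairseq's Dictionary adds <s>, <pad>, </s>, <unk> (indices 0-3) before the file's entries.
NUM_DEFAULT_SPECIALS = 4


def expected_vocab_size(dict_lines: Iterable[str]) -> int:
    """Vocabulary size implied by a fairseq-style dict file ("symbol count" per line)."""
    return NUM_DEFAULT_SPECIALS + sum(1 for line in dict_lines if line.strip())


# Hypothetical usage with the file from the docs:
# with open("dict_1003_unitmbart.txt") as f:
#     print(expected_vocab_size(f))
```

If dict_1003_unitmbart.txt holds 1003 entries, as its name suggests, this gives 1003 + 4 = 1007, matching embed_tokens and output_projection.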

sravyapopuri388 (Aug 04 '22)

Okay brilliant, thanks @sravyapopuri388! Closing this as the model can be correctly loaded with the aforementioned w2v_args + load_pretrained_decoder_from tricks.

sanchit-gandhi (Aug 10 '22)