fairseq
Loading Wav2Vec2-2-mBart Checkpoints for S2UT
❓ Questions and Help
What is your question?
Many thanks for uploading the fine-tuned model checkpoints for Enhanced Direct Speech-to-Speech Translation in the recent PR https://github.com/facebookresearch/fairseq/pull/4588.
Having downloaded the model weights from:
https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md#finetuned-model-checkpoints
and the config file and vocabulary from:
https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md#training
as config_xm_T_1003.yaml and dict_1003_unitmbart.txt respectively, how does one load the pre-trained model?
What have you tried?
First try with the from_pretrained method:
from fairseq.models.speech_to_speech import S2UTConformerModel
model = S2UTConformerModel.from_pretrained(
    "/Users/sanchitgandhi/Downloads",
    checkpoint_file="w2v2_mbart_LND_w_ASR.pt",
    task="speech_to_text",
)
`FileNotFoundError`
----> 1 model = S2UTConformerModel.from_pretrained("/Users/sanchitgandhi/Downloads", checkpoint_file="w2v2_mbart_LND_w_ASR.pt", task="speech_to_text")
File ~/fairseq/fairseq/models/fairseq_model.py:267, in BaseFairseqModel.from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
244 """
245 Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
246 file. Downloads and caches the pre-trained model file if needed.
(...)
263 model archive path.
264 """
265 from fairseq import hub_utils
--> 267 x = hub_utils.from_pretrained(
268 model_name_or_path,
269 checkpoint_file,
270 data_name_or_path,
271 archive_map=cls.hub_models(),
272 **kwargs,
273 )
274 logger.info(x["args"])
275 return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])
File ~/fairseq/fairseq/hub_utils.py:73, in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
70 if "user_dir" in kwargs:
71 utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
---> 73 models, args, task = checkpoint_utils.load_model_ensemble_and_task(
74 [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
75 arg_overrides=kwargs,
76 )
78 return {
79 "args": args,
80 "task": task,
81 "models": models,
82 }
File ~/fairseq/fairseq/checkpoint_utils.py:473, in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards, state)
471 argspec = inspect.getfullargspec(task.build_model)
472 if "from_checkpoint" in argspec.args:
--> 473 model = task.build_model(cfg.model, from_checkpoint=True)
474 else:
475 model = task.build_model(cfg.model)
File ~/fairseq/fairseq/tasks/speech_to_text.py:129, in SpeechToTextTask.build_model(self, args, from_checkpoint)
127 args.input_channels = self.data_cfg.input_channels
128 args.speaker_to_id = self.speaker_to_id
--> 129 return super(SpeechToTextTask, self).build_model(args, from_checkpoint)
File ~/fairseq/fairseq/tasks/fairseq_task.py:676, in LegacyFairseqTask.build_model(self, args, from_checkpoint)
664 """
665 Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
666 task.
(...)
672 a :class:`~fairseq.models.BaseFairseqModel` instance
673 """
674 from fairseq import models, quantization_utils
--> 676 model = models.build_model(args, self, from_checkpoint)
677 model = quantization_utils.quantize_model_scalar(model, args)
678 return model
File ~/fairseq/fairseq/models/__init__.py:106, in build_model(cfg, task, from_checkpoint)
98 ARCH_CONFIG_REGISTRY[model_type](cfg)
100 assert model is not None, (
101 f"Could not infer model type from {cfg}. "
102 "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
103 + f" Requested model type: {model_type}"
104 )
--> 106 return model.build_model(cfg, task)
File ~/fairseq/fairseq/models/speech_to_text/xm_transformer.py:609, in XMTransformerModel.build_model(cls, args, task)
607 base_architecture(args)
608 if getattr(args, "load_pretrained_decoder_from", None):
--> 609 ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None))
610 decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"])
611 args = cls.override_decoder_args(args, decoder_args_dict)
File ~/venv/lib/python3.8/site-packages/torch/serialization.py:709, in load(f, map_location, pickle_module, **pickle_load_args)
706 if 'encoding' not in pickle_load_args.keys():
707 pickle_load_args['encoding'] = 'utf-8'
--> 709 with _open_file_like(f, 'rb') as opened_file:
710 if _is_zipfile(opened_file):
711 # The zipfile reader is going to advance the current file position.
712 # If we want to actually tail call to torch.jit.load, we need to
713 # reset back to the original position.
714 orig_position = opened_file.tell()
File ~/venv/lib/python3.8/site-packages/torch/serialization.py:240, in _open_file_like(name_or_buffer, mode)
238 def _open_file_like(name_or_buffer, mode):
239 if _is_path(name_or_buffer):
--> 240 return _open_file(name_or_buffer, mode)
241 else:
242 if 'w' in mode:
File ~/venv/lib/python3.8/site-packages/torch/serialization.py:221, in _open_file.__init__(self, name, mode)
220 def __init__(self, name, mode):
--> 221 super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'
Next, I tried checkpoint_utils.load_model_ensemble_and_task instead:
import fairseq

cp = "/Users/sanchitgandhi/Downloads/w2v2_mbart_LND_w_ASR.pt"
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [cp], arg_overrides={"data": "/Users/sanchitgandhi/Downloads/", "task": "speech_to_text"}
)
`FileNotFoundError`
----> 1 model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp], arg_overrides={"data": "/Users/sanchitgandhi/Downloads/", "task":"speech_to_text"})
(... traceback identical to the first attempt above, from checkpoint_utils.load_model_ensemble_and_task down to torch/serialization.py ...)
FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'
In both cases, it seems an additional file is required, under the hard-coded path:
/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt
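To see where this path comes from, here is a minimal sketch for inspecting the paths baked into the downloaded checkpoint (assuming, per the traceback, that the model config is stored under ckpt["cfg"]["model"]):

import torch

ckpt = torch.load("/Users/sanchitgandhi/Downloads/w2v2_mbart_LND_w_ASR.pt", map_location="cpu")
model_cfg = ckpt["cfg"]["model"]
# Both fields point at internal cluster paths in the released checkpoint
print(model_cfg.load_pretrained_decoder_from)
print(model_cfg.w2v_path)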
What's your environment?
- fairseq Version (e.g., 1.0 or main): main
- PyTorch Version (e.g., 1.0): 1.13.0
- OS (e.g., Linux): MacOS, Linux
- How you installed fairseq (pip, source): source
- Build command you used (if compiling from source): https://github.com/facebookresearch/fairseq#requirements-and-installation
- Python version: 3.8.9
- CUDA/cuDNN version: -
- GPU models and configuration: -
- Any other relevant information: -
cc @sravyapopuri388
Hi, thanks for reaching out. It looks like the issue is with the missing pretrained models. A quick fix is to update the paths stored in the checkpoint: set the encoder's w2v_path to your local path and set load_pretrained_decoder_from to None. You can do something like:
model = torch.load(MODEL_PATH)
state = model["cfg"]["model"]
state.w2v_path = W2V_PATH
state.load_pretrained_decoder_from = None
Hi @sravyapopuri388, thanks for your prompt response!
To confirm, the w2v_path should be set to the local path? How does one then load the pre-trained model? I've tried:
import torch
import fairseq
DATA_PATH = "/Users/sanchitgandhi/Downloads"
MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR.pt"
NEW_MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR_new.pt"
model = torch.load(MODEL_PATH)
model["cfg"]["model"].w2v_path = MODEL_PATH
model["cfg"]["model"].load_pretrained_decoder_from = None
torch.save(model, NEW_MODEL_PATH)
fairseq_model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([NEW_MODEL_PATH], arg_overrides={"data": DATA_PATH, "task": "speech_to_text"})
But I'm encountering the same error as before:
FileNotFoundError: [Errno 2] No such file or directory: '/checkpoint/pipibjc/u2u/umbart/interspeech/vp_ll.tr.lgtkn.en_es.mbart_large.alp0.7.bmeos.tps1020.mt1024.uf9.sh_tr.sh.ms128.enb.dnb.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.eps1e-06.cl0.1.lr0.0003.wrm10000.fp16.lam10.mask0.3.rl1.rot0.mrnd0.1.ins0.pers1.0.seed2.ngpu64/checkpoint_42_500000.pt'
Thank you for your help, it's very much appreciated!
@sanchit-gandhi
You have to point w2v_path at a model that has load_pretrained_decoder_from=None, i.e. the NEW_MODEL_PATH (even though that file does not yet exist when it is assigned).
I believe this is because w2v_path=(the bugged w2v2_mbart_LND_w_ASR.pt) causes the program to load the bugged w2v checkpoint again, whose load_pretrained_decoder_from refers to a non-existent checkpoint.
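In code, a minimal sketch of this suggestion, reusing the paths from the snippet above (the key difference is that w2v_path is set to the patched copy rather than the original checkpoint):

import torch

DATA_PATH = "/Users/sanchitgandhi/Downloads"
MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR.pt"
NEW_MODEL_PATH = DATA_PATH + "/w2v2_mbart_LND_w_ASR_new.pt"

model = torch.load(MODEL_PATH)
# Point w2v_path at the patched copy, not the original checkpoint,
# so that the bad load_pretrained_decoder_from path is never followed.
model["cfg"]["model"].w2v_path = NEW_MODEL_PATH
model["cfg"]["model"].load_pretrained_decoder_from = None
torch.save(model, NEW_MODEL_PATH)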
Thanks @gmryu! I've made that change as suggested, and also updated the config_yaml path, as this was hard-coded to a local path too.
Now I'm not getting any errors when loading the model, but the Python process is killed during execution (after ~30s). Were you able to load the pre-trained checkpoint? If so, would you mind sharing your code, even as a direct copy and paste with your absolute paths :)
Hi Sanchit, thanks for your patience. Could you try updating the w2v_path in the w2v2_mbart_LND_w_ASR.pt state with the original w2v model args instead? Something like:
import torch
from omegaconf import OmegaConf
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

# Placeholder paths -- replace with your local files
MODEL_PATH = "path/to/checkpoint"
W2V_PATH = "path/to/w2v2/model"
NEW_MODEL_PATH = "path/to/updated/model"

model = torch.load(MODEL_PATH)
w2v = torch.load(W2V_PATH)
w2v_args = OmegaConf.create(w2v["cfg"])

state = model["cfg"]["model"]
state.w2v_path = None
state.load_pretrained_decoder_from = None
state.w2v_args = w2v_args

torch.save(model, NEW_MODEL_PATH)
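Once the patched checkpoint is saved, it can be loaded as before; a sketch, assuming DATA_PATH is the download folder holding the config/vocab files from the earlier snippets:

import fairseq

models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [NEW_MODEL_PATH],
    arg_overrides={"data": DATA_PATH, "task": "speech_to_text"},
)
model = models[0]  # load_model_ensemble_and_task returns a list of models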
Please let me know if this doesn't work in which case I will update the checkpoints on my end before sharing
Hey @sravyapopuri388, thanks for getting back to me! I actually went in and hard-coded all the paths in the fairseq repo just to cover all bases. The erroneous path issue seems to be resolved. I'm now attempting to load the checkpoint; given its size (~10 GB), I've been attempting this on a big CPU machine. Will let you know if there are any further issues!
Hi @sravyapopuri388! Thanks for the help. Overriding the w2v2 paths fixed the issue with loading the model. Just to confirm: to run the model correctly, no other changes are needed when instantiating the S2UT model, i.e. the decoder path can be left unchanged and the output feature dimension defaults to 1007?
Summary of the loaded model:
XMTransformerModel(
(encoder): Wav2VecEncoderWithAdaptor(
(w2v_encoder): Wav2VecEncoder(
(w2v_model): Wav2Vec2Model(
(feature_extractor): ConvFeatureExtractionModel(
(conv_layers): ModuleList(
(0): Sequential(
(0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
(1): Dropout(p=0.0, inplace=False)
(2): Sequential(
(0): TransposeLast()
(1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(2): TransposeLast()
)
(3): GELU(approximate=none)
)
...
(6): Sequential(
(0): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
(1): Dropout(p=0.0, inplace=False)
(2): Sequential(
(0): TransposeLast()
(1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(2): TransposeLast()
)
(3): GELU(approximate=none)
)
)
)
(post_extract_proj): Linear(in_features=512, out_features=1024, bias=True)
(dropout_input): Dropout(p=0.1, inplace=False)
(dropout_features): Dropout(p=0.1, inplace=False)
(quantizer): None
(project_q): None
(encoder): ConformerEncoder(
(pos_conv): Sequential(
(0): Conv1d(1024, 1024, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
(1): SamePad()
(2): GELU(approximate=none)
)
(layers): ModuleList(
(0): ConformerWav2Vec2EncoderLayer(
(ffn1): FeedForwardModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(dropout1): Dropout(p=0.0, inplace=False)
(dropout2): Dropout(p=0.0, inplace=False)
(activation): SiLU(inplace=True)
)
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(self_attn_dropout): Dropout(p=0.0, inplace=False)
(self_attn): RelPositionMultiHeadedAttention(
(linear_q): Linear(in_features=1024, out_features=1024, bias=True)
(linear_k): Linear(in_features=1024, out_features=1024, bias=True)
(linear_v): Linear(in_features=1024, out_features=1024, bias=True)
(linear_out): Linear(in_features=1024, out_features=1024, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(linear_pos): Linear(in_features=1024, out_features=1024, bias=False)
)
(conv_module): ConvolutionModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(pointwise_conv1): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
(glu): GLU(dim=1)
(depthwise_conv): Conv1d(1024, 1024, kernel_size=(31,), stride=(1,), padding=(15,), groups=1024, bias=False)
(batch_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): SiLU(inplace=True)
(pointwise_conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)
(dropout): Dropout(p=0.0, inplace=False)
)
(ffn2): FeedForwardModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(dropout1): Dropout(p=0.0, inplace=False)
(dropout2): Dropout(p=0.0, inplace=False)
(activation): SiLU(inplace=True)
)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
...
(23): ConformerWav2Vec2EncoderLayer(
(ffn1): FeedForwardModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(dropout1): Dropout(p=0.0, inplace=False)
(dropout2): Dropout(p=0.0, inplace=False)
(activation): SiLU(inplace=True)
)
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(self_attn_dropout): Dropout(p=0.0, inplace=False)
(self_attn): RelPositionMultiHeadedAttention(
(linear_q): Linear(in_features=1024, out_features=1024, bias=True)
(linear_k): Linear(in_features=1024, out_features=1024, bias=True)
(linear_v): Linear(in_features=1024, out_features=1024, bias=True)
(linear_out): Linear(in_features=1024, out_features=1024, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
(linear_pos): Linear(in_features=1024, out_features=1024, bias=False)
)
(conv_module): ConvolutionModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(pointwise_conv1): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
(glu): GLU(dim=1)
(depthwise_conv): Conv1d(1024, 1024, kernel_size=(31,), stride=(1,), padding=(15,), groups=1024, bias=False)
(batch_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): SiLU(inplace=True)
(pointwise_conv2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)
(dropout): Dropout(p=0.0, inplace=False)
)
(ffn2): FeedForwardModule(
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(w_1): Linear(in_features=1024, out_features=4096, bias=True)
(w_2): Linear(in_features=4096, out_features=1024, bias=True)
(dropout1): Dropout(p=0.0, inplace=False)
(dropout2): Dropout(p=0.0, inplace=False)
(activation): SiLU(inplace=True)
)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
)
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(embed_positions): RelPositionalEncoding()
)
(layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(final_proj): None
)
(final_dropout): Dropout(p=0, inplace=False)
)
(adaptor): Conv1dAdaptor(
(layers): ModuleList(
(0): Conv1d(1024, 2048, kernel_size=(3,), stride=(2,), padding=(1,))
)
)
)
(decoder): TransformerDecoder(
(dropout_module): FairseqDropout()
(embed_tokens): Embedding(1007, 1024, padding_idx=1)
(embed_positions): SinusoidalPositionalEmbedding()
(layers): LayerDropModuleList(
(0): TransformerDecoderLayerBase(
(dropout_module): FairseqDropout()
(self_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(activation_dropout_module): FairseqDropout()
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(encoder_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=1024, out_features=4096, bias=True)
(fc2): Linear(in_features=4096, out_features=1024, bias=True)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(1): TransformerDecoderLayerBase(
(dropout_module): FairseqDropout()
(self_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(activation_dropout_module): FairseqDropout()
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(encoder_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=1024, out_features=4096, bias=True)
(fc2): Linear(in_features=4096, out_features=1024, bias=True)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
...
(11): TransformerDecoderLayerBase(
(dropout_module): FairseqDropout()
(self_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(activation_dropout_module): FairseqDropout()
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(encoder_attn): MultiheadAttention(
(dropout_module): FairseqDropout()
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=1024, out_features=4096, bias=True)
(fc2): Linear(in_features=4096, out_features=1024, bias=True)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
)
(layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(output_projection): Linear(in_features=1024, out_features=1007, bias=False)
)
)
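As a quick sanity check, the relevant dimensions can also be read off the loaded model directly (a sketch; fairseq_model is the list returned by load_model_ensemble_and_task above):

model = fairseq_model[0]
print(model.decoder.embed_tokens)       # Embedding(1007, 1024, padding_idx=1)
print(model.decoder.output_projection)  # Linear(in_features=1024, out_features=1007, bias=False)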
Hi @sanchit-gandhi, yes, the output feature dimension looks correct to me. Just overriding the w2v_args argument with the args from the original w2v2 model and setting the load_pretrained_decoder_from path to None to avoid errors should work; no other changes are required. Please let me know if you encounter any other issues. Thanks!
Okay brilliant, thanks @sravyapopuri388! Closing this as the model can be correctly loaded with the aforementioned w2v_args + load_pretrained_decoder_from tricks.