fairseq
AttributeError: 'dict' object has no attribute 'replace' when loading the pretrained vq-wav2vec model as in the instructions
🐛 Bug
I ran the code from the instructions in Google Colab:
import torch
import fairseq
cp = torch.load('/path/to/vq-wav2vec.pt')
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
z = model.feature_extractor(wav_input_16khz)
_, idxs = model.vector_quantizer.forward_idx(z)
print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model
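As a rough check of the "60 timesteps" in the comment above, here is a minimal sketch. It assumes the vq-wav2vec feature extractor uses fairseq's default wav2vec convolution stack with no padding: kernel sizes 10, 8, 4, 4, 4, 1, 1 and strides 5, 4, 2, 2, 2, 1, 1. That layer list is an assumption and is not read from the checkpoint. The output length then follows from applying the unpadded convolution formula floor((n - k) / s) + 1 layer by layer. The helper names are hypothetical.

```python
# Assumed (kernel_size, stride) pairs for the convolutional feature extractor.
# This mirrors fairseq's default wav2vec conv stack; it is an assumption,
# not something read from vq-wav2vec.pt.
ASSUMED_CONV_LAYERS = [(10, 5), (8, 4), (4, 2), (4, 2), (4, 2), (1, 1), (1, 1)]


def conv1d_out_len(n: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded, undilated 1-D convolution."""
    return (n - kernel) // stride + 1


def feature_frames(num_samples: int, layers=ASSUMED_CONV_LAYERS) -> int:
    """Number of feature frames produced from `num_samples` raw audio samples."""
    n = num_samples
    for kernel, stride in layers:
        n = conv1d_out_len(n, kernel, stride)
    return n


if __name__ == "__main__":
    # torch.randn(1, 10000) in the snippet above means 10000 samples
    print(feature_frames(10000))  # 60
```

Under this assumption the product of the strides is 5 * 4 * 2 * 2 * 2 = 160 samples, so each frame advances by 10 ms of 16 kHz audio.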
And I got:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-08d9c7fdb6af> in <module>()
3
4 cp = torch.load('/content/vq-wav2vec.pt')
----> 5 model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
6 model = model[0]
7 model.eval()
/usr/local/lib/python3.7/dist-packages/fairseq/checkpoint_utils.py in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards)
272 for shard_idx in range(num_shards):
273 if num_shards == 1:
--> 274 filename = filename.replace(".pt", suffix + ".pt")
275 else:
276 filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
AttributeError: 'dict' object has no attribute 'replace'
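The failing line calls str.replace on each entry of filenames, so every entry must be a string path. The sketch below is a simplified, hypothetical re-creation of only the filename handling visible in the traceback (lines 272-276 of checkpoint_utils.py). It is not fairseq's actual function, and the name shard_filenames is made up for illustration.

```python
def shard_filenames(filename, suffix="", num_shards=1):
    """Hypothetical re-creation of the shard-filename logic shown in the traceback."""
    orig_filename = filename
    names = []
    for shard_idx in range(num_shards):
        if num_shards == 1:
            # Only str has .replace(); a loaded checkpoint dict fails right here.
            filename = filename.replace(".pt", suffix + ".pt")
        else:
            filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
        names.append(filename)
    return names


print(shard_filenames("/content/vq-wav2vec.pt"))  # ['/content/vq-wav2vec.pt']

try:
    shard_filenames({"args": None, "model": {}})  # stand-in for torch.load(...) output
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'replace'
```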
To Reproduce
Steps to reproduce the behavior (always include the command you ran):
- Open Google Colab
- Run
!pip install --upgrade fairseq
- Run
!pip install --upgrade torch
- Run
!wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt
- Run
import torch
import fairseq
cp = torch.load('/content/vq-wav2vec.pt')
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
- The same AttributeError is raised.
Code sample
!pip install --upgrade fairseq &> /dev/null
!pip install --upgrade torch &> /dev/null
!wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt
import torch
import fairseq
cp = torch.load('/content/vq-wav2vec.pt')
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
Expected behavior
The model should be loaded without a problem.
Environment
- pip version: 19.3.1
- fairseq Version (e.g., 1.0 or master): 0.10.2
- PyTorch Version (e.g., 1.0): 1.8.1+cu101
- OS (e.g., Linux): Ubuntu 18.04.5 LTS (Google Colab)
- How you installed fairseq (pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.7.10
- CUDA/cuDNN version: cuda_11.0_bu.TC445_37.28845127_0
- GPU models and configuration:
Sun Jun 13 01:47:37 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27 Driver Version: 460.32.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |
| N/A 36C P0 26W / 250W | 2MiB / 16280MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
- Any other relevant information:
Additional context
I know what the problem is: the README is out of date. fairseq.checkpoint_utils.load_model_ensemble_and_task should be given a list of checkpoint file paths, not already-loaded dictionaries like [cp].
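A small pre-flight check can make that contract explicit before fairseq is called. This is a hypothetical helper (check_checkpoint_paths is not part of fairseq). It rejects anything that is not a str or os.PathLike, such as the dictionary returned by torch.load, and reports missing files up front.

```python
import os
from typing import Iterable, List


def check_checkpoint_paths(paths: Iterable) -> List[str]:
    """Hypothetical helper: validate checkpoint arguments before loading.

    Returns the paths as plain strings, or raises a descriptive error.
    """
    checked = []
    for p in paths:
        if not isinstance(p, (str, os.PathLike)):
            raise TypeError(
                f"expected a checkpoint file path, got {type(p).__name__}; "
                "pass the path itself, not the result of torch.load()"
            )
        p = os.fspath(p)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Model file not found: {p}")
        checked.append(p)
    return checked
```

In the notebook one would call check_checkpoint_paths(['/content/vq-wav2vec.pt']) and pass the returned list to load_model_ensemble_and_task.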
But when I load it from the file path:
import torch
import fairseq
cp_path = '/content/vq-wav2vec.pt'
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
z = model.feature_extractor(wav_input_16khz)
_, idxs = model.vector_quantizer.forward_idx(z)
print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model
I get a different error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-3-23ce555cb9fe> in <module>()
3
4 cp_path = '/content/vq-wav2vec.pt'
----> 5 model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
6 model = model[0]
7 model.eval()
/usr/local/lib/python3.7/dist-packages/fairseq/checkpoint_utils.py in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards)
277 if not PathManager.exists(filename):
278 raise IOError("Model file not found: {}".format(filename))
--> 279 state = load_checkpoint_to_cpu(filename, arg_overrides)
280 if shard_idx == 0:
281 args = state["args"]
/usr/local/lib/python3.7/dist-packages/fairseq/checkpoint_utils.py in load_checkpoint_to_cpu(path, arg_overrides)
230 for arg_name, arg_val in arg_overrides.items():
231 setattr(args, arg_name, arg_val)
--> 232 state = _upgrade_state_dict(state)
233 return state
234
/usr/local/lib/python3.7/dist-packages/fairseq/checkpoint_utils.py in _upgrade_state_dict(state)
437 choice = getattr(state["args"], registry_name, None)
438 if choice is not None:
--> 439 cls = REGISTRY["registry"][choice]
440 registry.set_defaults(state["args"], cls)
441
KeyError: 'binary_cross_entropy'
So it has a problem either way.
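Per the second traceback, the KeyError comes from a registry lookup in _upgrade_state_dict. The name stored in the checkpoint's args ('binary_cross_entropy') is not a key in the corresponding registry of the installed fairseq version. The sketch below reproduces that lookup pattern with a toy registry. Several details are assumptions made for illustration: the failing registry is taken to be the criterion registry because binary_cross_entropy reads like a criterion name, and the registry contents and the helper name lookup_registered are invented. None of these are fairseq's real registries.

```python
from argparse import Namespace

# Toy stand-in for one fairseq registry (assumed: criteria); contents are invented.
TOY_REGISTRY = {"registry": {"cross_entropy": object, "wav2vec": object}}


def lookup_registered(args, registry_name, registry=TOY_REGISTRY):
    """Mimic `cls = REGISTRY["registry"][choice]`, but with a clearer error."""
    choice = getattr(args, registry_name, None)
    if choice is None:
        return None
    try:
        return registry["registry"][choice]
    except KeyError:
        known = ", ".join(sorted(registry["registry"]))
        raise KeyError(
            f"{registry_name}={choice!r} from the checkpoint is not registered "
            f"in this fairseq install (known: {known})"
        ) from None


old_args = Namespace(criterion="binary_cross_entropy")  # value from the traceback
try:
    lookup_registered(old_args, "criterion")
except KeyError as err:
    print(err)
```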
Have you dealt with this bug? Thank you.
Change this line
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
to
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task({"vq-wav2vec.pt":cp})
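Passing {"vq-wav2vec.pt": cp} avoids the AttributeError because iterating over a dict yields its keys, not its values. This explanation assumes that load_model_ensemble_and_task loops over its argument with for filename in filenames. The traceback shows the per-filename loop but not its header. Under that assumption, the function receives the string "vq-wav2vec.pt" and never uses the already-loaded cp. It then reads the checkpoint from disk again at that relative path, so the file must exist in the current working directory. The helper loop_items below is hypothetical.

```python
def loop_items(filenames):
    """Collect what a `for filename in filenames:` loop would see."""
    return [filename for filename in filenames]


cp = {"args": None, "model": {}}  # stand-in for the dict returned by torch.load(...)

print(loop_items([cp]))                   # [{'args': None, 'model': {}}] -> a dict, .replace fails
print(loop_items({"vq-wav2vec.pt": cp}))  # ['vq-wav2vec.pt'] -> a plain path string
```

So, under that assumption, the workaround is equivalent to passing ["vq-wav2vec.pt"].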
Alternatively, install fairseq directly from GitHub at a specific commit:
pip install --upgrade git+https://github.com/pytorch/fairseq.git@0f078de343d985e0cba6a5c1dc8a6394698c95c7