[Potential Bug] in wav2vec (1.0?): targets always zero, loss computation potentially incorrect?
🐛 Bug
wav2vec (the original one), loaded as in the wav2vec README, should, as I understand it, return the contrastive predictive logits from the paper that are used to compute the CPC objective.
Consider the vanilla example:
import fairseq, torch

print(fairseq.__version__)  # 0.12.2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(['fairseq/wav2vec_large.pt'])
model = models[0].train().to(device)

torch.manual_seed(42)
source = torch.rand(1, 32000).to(device) - 0.5  # dummy waveform in [-0.5, 0.5)
out = model(source)

print(out['cpc_targets'].min())  # tensor(0, device='cuda:0')
print(out['cpc_targets'].max())  # tensor(0, device='cuda:0')
This is problematic (right?), as the targets should indicate which of these 11 outputs is the positive sample for CPC and which 10 are the negatives.
I tracked down the loss that is used, the wav2vec criterion, and it doesn't seem to do anything special with these targets.
I should mention that model.wav2vec_predictions.infonce is True. I believe that, due to this condition, in model.wav2vec_predictions.forward the labels tensor is created as:

labels = predictions.new_full((predictions.shape[0] // copies,), 0, dtype=torch.long)

but is never filled in.
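For what it's worth, an all-zero label tensor is not necessarily wrong under InfoNCE: if the logits for each prediction group are laid out with the positive sample at index 0 followed by the negatives, then a class label of 0 is exactly what cross-entropy expects for every group. A minimal pure-Python sketch, independent of fairseq (the index-0 ordering is my assumption, not something I verified in the fairseq code):

```python
import math

def infonce_loss(logit_groups):
    """Cross-entropy over each group of logits, assuming the
    positive sample's logit sits at index 0 of every group,
    so the correct class label is always 0."""
    total = 0.0
    for logits in logit_groups:
        exps = [math.exp(l) for l in logits]
        # probability the model assigns to the positive (index 0)
        p_pos = exps[0] / sum(exps)
        total += -math.log(p_pos)
    return total / len(logit_groups)

# one positive logit followed by 10 negative logits per group,
# mirroring the 1-positive / 10-negatives setup described above
group = [2.0] + [0.0] * 10
loss = infonce_loss([group])
```

Under this reading, the zero labels would be intentional rather than "not filled in", which might explain why the criterion does nothing special with them.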
The forward computation of the model can also be reproduced by:

w2v_z = model.feature_extractor(source)   # latent features
w2v_c = model.feature_aggregator(w2v_z)   # context features
cpc_logits, cpc_targets = model.wav2vec_predictions(w2v_c, w2v_z)

and it is model.wav2vec_predictions that doesn't seem to return the right targets.
This doesn't seem to have been the case in the oldest commit of the model (https://github.com/facebookresearch/fairseq/blob/392fce8a9873e54eca71cfca9d98f2685fdf6238/fairseq/models/wav2vec.py#L405) and appears to have been introduced around the time of vq-wav2vec?
Also, shouldn't wav2vec_predictions.infonce be set to False for wav2vec large?
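For comparison, my reading of the non-InfoNCE path in the linked older commit is that the targets are a 0/1 mask marking the positive and the negatives, and the loss is a binary cross-entropy over each logit. A rough sketch of that objective (my interpretation, not the actual fairseq code):

```python
import math

def binary_cpc_loss(logits, targets):
    """Binary cross-entropy with logits: each prediction is scored
    independently against a 0/1 target marking it as the positive
    sample (1) or a negative sample (0)."""
    loss = 0.0
    for logit, t in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid
        loss += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return loss / len(logits)

# one positive (target 1) and ten negatives (target 0): here the
# targets carry information, rather than being uniformly zero
logits = [1.5] + [-1.0] * 10
targets = [1] + [0] * 10
loss = binary_cpc_loss(logits, targets)
```

If infonce were set to False, targets of this 0/1 form are what I would expect to come back from wav2vec_predictions.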
I have the same error: the data suddenly changes to all zeros inside the model. It may be a problem in the speech module.