
[Potential Bug] in wav2vec (1.0?): targets always zero, loss computation potentially incorrect?

Open InfProbSciX opened this issue 1 year ago • 3 comments

🐛 Bug

wav2vec (the original model) from the wav2vec README, when called, should, as I understand it, return the contrastive predictive logits from the paper that are used to compute the CPC objective.

Consider the vanilla example:

import fairseq, torch
print(fairseq.__version__)  # 0.12.2

torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(['fairseq/wav2vec_large.pt'])
model = model[0].train().to(device)

torch.manual_seed(42)

source = torch.rand(1, 32000).to(device) - 0.5

out = model(source)

print(out['cpc_targets'].min())  # tensor(0, device='cuda:0')
print(out['cpc_targets'].max())  # tensor(0, device='cuda:0')

This is problematic (right?), as the targets should indicate which of the 11 outputs per step is the positive sample for CPC and which 10 are the negatives.
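For reference, here is a minimal plain-Python sketch (no fairseq, toy scores invented for illustration) of how an InfoNCE-style objective consumes such a target, under the assumption that the positive logit is placed at index 0; the target then just encodes where among the candidates the positive sits:

```python
import math

def infonce_loss(pos_score, neg_scores):
    """Cross-entropy over one positive and N negatives.
    The target index is 0 because the positive logit is placed first."""
    logits = [pos_score] + list(neg_scores)
    log_denom = math.log(sum(math.exp(s) for s in logits))
    return log_denom - logits[0]  # -log softmax(logits)[0]

# One positive and 10 negatives, matching the 11 outputs per step above.
loss = infonce_loss(2.0, [0.1] * 10)

# Sanity check: with uniform scores the loss is log(11).
uniform = infonce_loss(0.0, [0.0] * 10)
```

With this arrangement an all-zero target per step would be consistent, but only if the code guarantees the positive is always the first candidate.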

I tracked down the loss that is used (the wav2vec criterion), and it doesn't seem to do anything special with these targets.

I should mention that model.wav2vec_predictions.infonce is True. I believe that, due to this condition, labels is created in model.wav2vec_predictions.forward as:

labels = predictions.new_full((predictions.shape[0] // copies,), 0, dtype=torch.long)

but is never filled in afterwards.
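To make the shapes concrete, here is a hypothetical toy reproduction of that line (the counts are invented for illustration): with copies candidates per step (1 positive plus the negatives), new_full yields one label per step, all zero, so a cross-entropy over the candidates would always pick index 0 as the positive:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
copies, steps = 11, 4                      # hypothetical: 1 positive + 10 negatives, 4 steps
predictions = torch.randn(copies * steps)  # flat logits, one block of `steps` per copy

# The line from model.wav2vec_predictions.forward:
labels = predictions.new_full((predictions.shape[0] // copies,), 0, dtype=torch.long)

# If the logits are grouped as (copies, steps) with the positive copy first,
# the all-zero labels select copy 0 as the target for every step.
loss = F.cross_entropy(predictions.view(copies, steps).t(), labels)
```

So whether the zeros are a bug depends on whether the logits really are laid out with the positive copy first; the targets themselves carry no information.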

The forward computation of the model can also be reproduced by:

w2v_z = model.feature_extractor(source)    # latent features z
w2v_c = model.feature_aggregator(w2v_z)    # context representations c
cpc_logits, cpc_targets = model.wav2vec_predictions(w2v_c, w2v_z)

and it is model.wav2vec_predictions that does not seem to return the right targets.

InfProbSciX avatar Dec 13 '23 03:12 InfProbSciX

It seems like this wasn't the case in the oldest commit of the model: https://github.com/facebookresearch/fairseq/blob/392fce8a9873e54eca71cfca9d98f2685fdf6238/fairseq/models/wav2vec.py#L405 and was introduced around the time of vq-wav2vec?

InfProbSciX avatar Dec 13 '23 03:12 InfProbSciX

Also, shouldn't wav2vec_predictions.infonce be set to False for wav2vec large?
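For comparison, with infonce set to False the criterion would presumably follow the binary cross-entropy formulation of the original paper, where the target tensor really does have to mark which samples are positives and which are negatives. A plain-Python sketch of that objective (scores and counts invented for illustration):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def binary_cpc_loss(pos_score, neg_scores):
    # Binary cross-entropy form: push sigmoid(positive score) toward 1
    # and sigmoid(each negative score) toward 0.
    loss = -math.log(sigmoid(pos_score))
    loss -= sum(math.log(sigmoid(-s)) for s in neg_scores)
    return loss

# One positive and 10 negatives per step.
loss = binary_cpc_loss(2.0, [-2.0] * 10)

# Sanity check: one positive and one negative, both at score 0, gives 2*log(2).
baseline = binary_cpc_loss(0.0, [0.0])
```

Under this formulation, all-zero targets would clearly be wrong, which is why the infonce setting matters for interpreting the output.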

InfProbSciX avatar Dec 13 '23 06:12 InfProbSciX

I have the same error: the data suddenly becomes all zeros inside the model. It may be a problem in the speech module.

krgy12138 avatar Dec 20 '23 15:12 krgy12138