Using MMS model with `star` token for batch size > 1
The current implementation assumes a batch size of one when appending the star dimension:
https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/pipelines/_wav2vec2/utils.py#L41
However, the underlying Wav2Vec2 model supports batch sizes greater than one, so this line should take the batch size from `output` instead:
```python
star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)
```
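Here is a minimal sketch that isolates the star-appending step on dummy emissions of shape `(batch, frames, num_labels)`. It is not torchaudio's actual code. The function names and tensor sizes are hypothetical, and the batch-size-one variant assumes that the current line hard-codes a leading dimension of 1, as described above. `torch.cat` requires every non-concatenated dimension to match, so a `(1, frames, 1)` column cannot be joined to a batch of two.

```python
import torch


def append_star_dim(output: torch.Tensor) -> torch.Tensor:
    """Append a zero-valued "star" column to emissions of shape (batch, frames, num_labels)."""
    # Take the batch size from output itself so that the batch and frame
    # dimensions match those of output, which torch.cat requires.
    star_dim = torch.zeros(
        (output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device
    )
    return torch.cat((output, star_dim), dim=-1)


def append_star_dim_batch_one(output: torch.Tensor) -> torch.Tensor:
    """Hypothetical reproduction of the batch-size-one assumption."""
    star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
    return torch.cat((output, star_dim), dim=-1)


if __name__ == "__main__":
    # Dummy emissions: batch of 2, 50 frames, 31 labels (arbitrary sizes).
    emissions = torch.randn(2, 50, 31)
    print(append_star_dim(emissions).shape)  # torch.Size([2, 50, 32])
    try:
        append_star_dim_batch_one(emissions)
    except RuntimeError as err:
        # The batch dimensions differ (2 vs 1), so torch.cat raises.
        print("batch-size-one version fails:", err)
```

With a batch of one, the two functions give the same result. The fix therefore changes behavior only for larger batches.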