Using MMS model with `star` token for batch size > 1

Open huangruizhe opened this issue 1 year ago • 1 comments

The current implementation assumes a batch size of one when appending the star dimension: https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/pipelines/_wav2vec2/utils.py#L41

However, the underlying Wav2Vec2 model supports batch sizes greater than one, so this line should instead be:

star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device) 
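
To illustrate why the batch dimension matters, here is a minimal sketch of the proposed change in isolation. The helper name `append_star` and the example shapes are assumptions for illustration only; they are not part of torchaudio. The sketch assumes the emissions have shape `(batch, frames, vocab)` and that the star column is concatenated along the last dimension, which requires the other dimensions to match exactly.

```python
import torch


def append_star(output: torch.Tensor) -> torch.Tensor:
    """Append a zero-valued star column to emissions of shape (batch, frames, vocab).

    Hypothetical helper: it only mirrors the line proposed in this issue.
    """
    # Size the star column from the batch dimension instead of hard-coding 1,
    # because torch.cat does not broadcast across the non-concatenated dimensions.
    star_dim = torch.zeros(
        (output.size(0), output.size(1), 1),
        dtype=output.dtype,
        device=output.device,
    )
    return torch.cat((output, star_dim), dim=-1)


# Example with a batch of 3 utterances, 5 frames, and 8 tokens (made-up sizes).
emission = torch.randn(3, 5, 8)
with_star = append_star(emission)
```

With a star tensor of shape `(1, frames, 1)`, the same `torch.cat` call fails for any batch size other than one, since the sizes must match in every dimension except the one being concatenated.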

huangruizhe, Apr 12 '24 03:04