
Faster inference: Implemented EOT for causal sampling stopping

Open SinanAkkoyun opened this issue 1 year ago • 2 comments

By printing each generated token id tensor, I found that token 2048 is the EOT token.

Previously, the inference code sampled the full token budget even after an EOT was emitted. I therefore implemented early stopping in causal.py.
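The idea can be shown with a minimal sketch. This is not the actual metavoice-src code; `next_token_fn` is a hypothetical stand-in for the model's per-step sampling, and the EOT id 2048 is the value observed in this thread:

```python
EOT_TOKEN_ID = 2048  # EOT token id found empirically by printing generated ids

def sample_with_early_stop(next_token_fn, prompt_ids, max_new_tokens=850):
    """Autoregressive sampling loop that exits as soon as EOT is emitted,
    instead of always generating the full max_new_tokens budget.

    `next_token_fn` (hypothetical) takes the token list so far and returns
    the next sampled token id.
    """
    tokens = list(prompt_ids)
    for _ in range(max_new_tokens):
        tok = next_token_fn(tokens)
        tokens.append(tok)
        if tok == EOT_TOKEN_ID:
            break  # early exit: the sentence has ended
    return tokens
```

Short sentences hit the `break` early, which is where the speedup comes from; sentences long enough to fill the budget still pay the full cost.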

SinanAkkoyun avatar Feb 09 '24 02:02 SinanAkkoyun

@SinanAkkoyun Do you know how much faster this would be?

jay2jp avatar Feb 09 '24 15:02 jay2jp

@jay2jp Yup! With a 4090 and flash decoding, the previous fixed budget of around 850 tokens (generated regardless of how short the sentence was) takes 8 to 9 seconds. With this PR, a short sentence such as "I would love to see a speedup like that." takes around 2 to 3 seconds.

But a long sentence will still take 8 seconds, as I just stop generation when the sentence ends.

SinanAkkoyun avatar Feb 09 '24 15:02 SinanAkkoyun

Thanks Sinan, I had a brief look and I think this won't work with batching. I have this + a few other changes implemented in https://github.com/metavoiceio/metavoice-src/pull/46 so I'll close your PR :)
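For context on the batching concern: a single `break` on the first EOT would cut off the other sequences in the batch. A common fix is to track a per-sequence finished flag and stop only when every row has emitted EOT. A minimal sketch (hypothetical helper names, not the code from PR #46):

```python
EOT_TOKEN_ID = 2048  # empirically observed EOT id from this thread

def batched_sample_with_early_stop(next_tokens_fn, prompts, max_new_tokens=850):
    """Batch-safe early stopping: keep a finished flag per sequence and
    stop the loop only once all sequences have emitted EOT.

    `next_tokens_fn` (hypothetical) takes the list of sequences and returns
    one next token id per sequence.
    """
    seqs = [list(p) for p in prompts]
    finished = [False] * len(seqs)
    for _ in range(max_new_tokens):
        toks = next_tokens_fn(seqs)
        for i, tok in enumerate(toks):
            if finished[i]:
                continue  # this row already ended; ignore further tokens
            seqs[i].append(tok)
            if tok == EOT_TOKEN_ID:
                finished[i] = True
        if all(finished):
            break  # every sequence is done, safe to stop the batch
    return seqs
```

The trade-off is that the batch still runs until its slowest (longest) sequence finishes, so the per-batch speedup depends on the length spread within the batch.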

pyetras avatar Feb 12 '24 18:02 pyetras