# Add IndexTTS

## Context

Adds IndexTTS.

## Description
Implements Conformer and Perceiver. The main model is based on the GPT-2 architecture; it will leverage mlx-lm's GPT-2 implementation.
## Additional information
This is not done yet. For now, it loads the weights correctly. I need to write a proper forward pass later.
## Checklist
- [X] Loads weights
- [X] Does the forward pass correctly
- [X] Decodes the generated mel via BigVGAN
- [X] Tests added/updated
@Blaizzy I just noticed IndexTTS uses a non-standard BigVGAN, which incorporates conditioning inside BigVGAN using an ECAPA-TDNN model, like so:
```diff
 for i in range(self.num_upsamples):
     # Upsampling
     for i_up in range(len(self.ups[i])):
         x = self.ups[i][i_up](x)
+        if self.cond_in_each_up_layer:
+            x = x + self.conds[i](speaker_embedding)
```
Since it's not the standard BigVGAN implementation, I'm thinking of just implementing it in tts/model/indextts/bigvgan.py, or perhaps I should implement it somewhere like codec/model/bigvgan_indextts since it's still a codec?
What do you think would be more reasonable?
My 2¢: if you can make the conditioning optional, it would be nice to have the implementation with the codecs since BigVGAN is pretty widely used. If that's too tricky because IndexTTS has invasive modifications, it's probably fine to just have it live alongside the model.
Awesome work @senstella!
> @Blaizzy I just noticed IndexTTS uses a non-standard BigVGAN, which incorporates conditioning inside BigVGAN using an ECAPA-TDNN model,
I agree with @lucasnewman 100% here. It's the approach we usually use.
Rule of thumb: start by keeping the custom logic with the model (i.e., IndexTTS), and if we see a similar pattern in the next model we port that uses the BigVGAN codec, then we migrate it there.
A great example that demonstrates this is SparkTTS: it uses ECAPA-TDNN and some custom modules with the vocos Backbone and DAC layers.
Read more here: https://github.com/Blaizzy/mlx-audio/tree/main/mlx_audio/tts/models/spark/modules/encoder_decoder
Thank you! I just subclassed BigVGAN in indextts/bigvgan.py as BigVGANConditioning and applied conditioning there like:
```python
class BigVGANConditioning(BigVGAN):
    def __init__(self, config: BigVGANConditioningConfig):
        super().__init__(config)
```
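Building on the snippet above, here's a rough sketch of how the upsampling override could look under that subclass; the method name `upsample` and the exact structure are assumptions for illustration, not the actual code in this PR:

```python
# Illustrative sketch only; the real indextts/bigvgan.py may be structured differently.
class BigVGANConditioning(BigVGAN):
    def __init__(self, config: BigVGANConditioningConfig):
        super().__init__(config)

    def upsample(self, x, speaker_embedding):
        for i in range(self.num_upsamples):
            for up in self.ups[i]:
                x = up(x)
            # The only change vs. stock BigVGAN: add the ECAPA-TDNN speaker
            # conditioning after each upsampling stage.
            if self.cond_in_each_up_layer:
                x = x + self.conds[i](speaker_embedding)
        return x
```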
Also, I've implemented another instance of ECAPA-TDNN in indextts/ecapa_tdnn. Since SparkTTS already uses ECAPA-TDNN, it might be worth considering merging it somewhere like mlx_audio/embedding with a different sanitize method for each TTS model implementation, though I'm not entirely sure that's the best approach. I'll just let it live alongside the IndexTTS implementation for now!
I think it's ready for review!
https://huggingface.co/mlx-community/IndexTTS-1.5 https://huggingface.co/mlx-community/IndexTTS
```bash
python -m mlx_audio.tts.generate --model mlx-community/IndexTTS-1.5 --text "Describe this image." --ref_audio test.wav
```
(The model depends on the given reference audio, so you must provide a reference audio!)
Note that the test failure is related to #191, which is caused by mlx-lm's API refactoring.
Thank you very much @senstella!
I will review and merge it by Sunday.
And will also check this breaking change.
I noticed the WER in English seems quite a bit higher than they advertise in their results -- is that the case with the torch model as well? It does sound great when it gets the words right and the voice matching is impressive.
When you generate via the Python script, isn't ref_audio supposed to accept a file name, not an mlx array? I think other models accept file names, if I'm not mistaken?
Oh, interesting, it looks like most models accept mlx.array, but Spark accepts Path, and Outetts accepts str?
> I noticed the WER in English seems quite a bit higher than they advertise in their results -- is that the case with the torch model as well? It does sound great when it gets the words right and the voice matching is impressive.
I think the difference is they're using beam search for GPT-2 decoding, whereas I just used a standard top_k sampler here. Should I add beam search before merging?
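For reference, a generic top-k sampler in MLX looks roughly like the sketch below; it illustrates the sampling strategy mentioned above and is not necessarily the exact sampler used in this PR:

```python
import mlx.core as mx

def top_k_sample(logits: mx.array, k: int = 50, temperature: float = 1.0) -> mx.array:
    # Keep only the k largest logits, mask the rest, and sample categorically.
    logits = logits / temperature
    kth_largest = mx.sort(logits, axis=-1)[..., -k][..., None]
    masked = mx.where(logits < kth_largest, float("-inf"), logits)
    return mx.random.categorical(masked)
```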
> When you generate via the Python script, isn't ref_audio supposed to accept a file name, not an mlx array? I think other models accept file names, if I'm not mistaken?
I followed the common approach used with other models' implementations, but I can add the Path/str handling if it's required!
I was wrong, the generate function from other models also accepts mx.array. Sorry about that.
I've added a conv.py file with the WNConv1d and WNConvTranspose1d classes from descript/nn/layers.py, and I've updated the weights on HF. Should be all set now!
> Note that the test failure is related to https://github.com/Blaizzy/mlx-audio/issues/191, which is caused by mlx-lm's API refactoring.
@senstella fixed in #194
@senstella the bigvgan tests are failing, could you look at them?
The tests should pass now. The problem was that WNConvTranspose1d called mx.conv_transpose1d with positional parameters:
```python
y = mx.conv_transpose1d(
    x, weight, self.stride, self.padding, self.dilation, self.groups
)
```
while mx.conv_transpose1d accepts (input, weight, stride, padding, dilation, output_padding, groups).
Therefore, the default value of groups (1) slid into output_padding, which messed up the output shape.
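A simple way to guard against this class of bug is to pass the optional arguments by keyword. Here's a minimal sketch with toy shapes (assuming MLX's channels-last convention and an (out_channels, kernel_size, in_channels) weight layout for transposed convolutions):

```python
import mlx.core as mx

x = mx.zeros((1, 16, 8))  # (batch, length, in_channels) toy input
w = mx.zeros((4, 3, 8))   # (out_channels, kernel_size, in_channels) toy weight

# Keyword arguments make it impossible for `groups` to slide into
# `output_padding` if the positional order changes.
y = mx.conv_transpose1d(x, w, stride=2, padding=1, dilation=1, groups=1)
print(y.shape)
```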
Awesome, thanks for the fix!
One more thing I just noticed. The converted model has the codec weights embedded within it.
I typically keep codec weights separate from the main model, which allows us to quantize only the model weights while leaving the codec unquantized. This approach helps preserve output quality since codec degradation can significantly impact audio fidelity.
There's limited data behind my approach, just intuition, so I could be off base here. Would you mind testing the quantized versions (i.e., 4-bit or 6-bit) to confirm the audio output still sounds acceptable?
I think MLX only quantizes modules that implement to_quantized(group_size, bits). However, the BigVGAN codec doesn't contain any Linear or Embedding layers, and Conv1d layers don't have a to_quantized method, so in theory the codec wouldn't be quantized.
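For what it's worth, MLX's nn.quantize also takes a class_predicate callback (receiving the module path and the module, as I understand the API), so the codec could be skipped explicitly if that's ever needed. A hypothetical sketch, where the "bigvgan" path prefix is just a placeholder for wherever the codec is attached:

```python
import mlx.nn as nn

def quantize_model_only(model, group_size: int = 64, bits: int = 4):
    # Quantize only modules that support quantization and that don't live
    # under the codec; "bigvgan" is an illustrative attribute path.
    nn.quantize(
        model,
        group_size=group_size,
        bits=bits,
        class_predicate=lambda path, m: hasattr(m, "to_quantized")
        and not path.startswith("bigvgan"),
    )
```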
Here are some quantized generations and the full-precision generation (group_size: 64): 4 bits / 6 bits / fp16
I also just realized I forgot to implement the text normalizer. English can mostly be mimicked at the Python level fairly easily: just replace some characters (e.g. ':' -> ',') and verbalize numbers (4 -> four). However, I'm not quite sure about Chinese. The package uses wetext, which downloads the normalizer from ModelScope (China's version of HuggingFace, I believe). Should I include the wetext package as a dependency, or just do the essential normalization in Python? I personally think minimal dependencies would be preferable, but I'm not sure basic Python normalization would be sufficient for Chinese text.
I just included the original wetext normalizer, but please let me know if you prefer otherwise!
> I think MLX only quantizes modules that implement to_quantized(group_size, bits). However, the BigVGAN codec doesn't contain any Linear or Embedding layers, and Conv1d layers don't have a to_quantized method, so in theory the codec wouldn't be quantized.
Awesome,
I heard the audio, and it sounds good to me.
Thanks for clarifying and testing, because this is true for vocos, snac, and mimi, as they have nn.Linear layers.
> I just included the original wetext normalizer, but please let me know if you prefer otherwise!
Great, I like the normalizer!
However, I wonder if it's possible to normalize the text without adding extra dependencies? Because this will keep us light and make the swift port easier as well.
> However, I wonder if it's possible to normalize the text without adding extra dependencies? Because this will keep us light and make the swift port easier as well.
I think it's quite doable for English: just handling numbers and special characters should work fine in most cases. The problematic part is Chinese. I'm not familiar with the language, so I'm pretty uncertain which elements would require normalization, and I can't verify whether the output is spoken correctly.
No worries. For now, @senstella, we can remove wetext, since we're not using it anywhere as far as I can see, and merge this PR.
In the meantime, @yarshure and @thuongvovan - we'd appreciate your expertise and approval on the Chinese language implementation. Please feel free to submit a PR or open an issue to either restore wetext or enhance IndexTTS's text normalizer. When you do, please tag me and @senstella so we can review promptly.
Great! I removed the wetext dependency and replaced it with a simple normalizer written in Python that should mimic the original normalizer's behavior. I didn't perform any normalization specifically for Chinese text, so further issues or PRs related to this would be really appreciated!
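To give a sense of what that kind of dependency-free normalization can look like, here's a hypothetical minimal English-only sketch (character replacement plus small-number verbalization); it's illustrative only, not the actual normalizer added in this PR:

```python
import re

_ONES = [
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
    "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
    "seventeen", "eighteen", "nineteen",
]
_TENS = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

def _spell_number(n: int) -> str:
    # Covers 0-99 for illustration; a real normalizer would handle much more.
    if n < 20:
        return _ONES[n]
    tens, ones = divmod(n, 10)
    return _TENS[tens] + (f"-{_ONES[ones]}" if ones else "")

def normalize_english(text: str) -> str:
    # Replace punctuation the TTS frontend tends to stumble on, then spell
    # out small standalone numbers.
    text = text.replace(":", ",").replace(";", ",")
    return re.sub(r"\b\d{1,2}\b", lambda m: _spell_number(int(m.group())), text)

print(normalize_english("It starts at 4: bring 21 chairs."))
# -> "It starts at four, bring twenty-one chairs."
```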
@senstella Great to see mlx-audio supporting IndexTTS! Thank you so much for your efforts. Since IndexTTS 2.0 has been released, I'm wondering if you have any plans to support IndexTTS 2.0 as well.
> @senstella Great to see mlx-audio supporting IndexTTS! Thank you so much for your efforts. Since IndexTTS 2.0 has been released, I'm wondering if you have any plans to support IndexTTS 2.0 as well.
Thanks for letting me know! I'll take a look at the paper and try implementing it when I get some time. I've been a little busy these days, sorry!
> @senstella Great to see mlx-audio supporting IndexTTS! Thank you so much for your efforts. Since IndexTTS 2.0 has been released, I'm wondering if you have any plans to support IndexTTS 2.0 as well.
>
> Thanks for letting me know! I'll take a look at the paper and try implementing it when I get some time. I've been a little busy these days, sorry!
I am glad to receive your reply. Thanks to your help, I have successfully run IndexTTS-1.5 on my M1 laptop, and I found a problem while using it: IndexTTS does not need the ref_text parameter, but SparkTTS does. As a result, running IndexTTS without the ref_text parameter set automatically downloads "mlx-community/whisper-large-v3-turbo" to convert the audio to text:
```text
Ref_text not found. Transcribing ref_audio...
Fetching 4 files: 100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 44034.69it/s]
Ref_text So, just to clarify, we've had 19 to 20 year olds on our podcast so far, and if you don't mind me asking, how old are you?
```
```python
if not ref_text:
    print("Ref_text not found. Transcribing ref_audio...")
    from mlx_audio.stt.models.whisper import Model as Whisper

    stt_model = Whisper.from_pretrained(path_or_hf_repo=stt_model)
    ref_text = stt_model.generate(ref_audio).text
    print("Ref_text", ref_text)

    # clear memory
    del stt_model
    mx.clear_cache()
```
Is it possible to avoid the above issue with the following code? I have tested it locally and it works fine. Please help confirm this issue when you have time. Thanks, and looking forward to your reply.
```python
if not ref_text:
    import inspect

    if "ref_text" in inspect.signature(model.generate).parameters:
        print("Ref_text not found. Transcribing ref_audio...")
        from mlx_audio.stt.models.whisper import Model as Whisper

        stt_model = Whisper.from_pretrained(path_or_hf_repo=stt_model)
        ref_text = stt_model.generate(ref_audio).text
        print("Ref_text", ref_text)

        # clear memory
        del stt_model
        mx.clear_cache()
```
Yeah, the current implementation transcribes ref_text regardless of what the model requires. Using inspect to check the signature looks good to me! You might want to open a pull request to the repository and get it merged!