A few questions about scaling, non-verbal cues, etc.
Hi! Thanks for the awesome work. I've been following and waiting for the paper for a while and finally had the pleasure of reading it.
I have a few questions; I'd appreciate it if you could answer them, or at least share your thoughts:
1. Model Scaling
Sounds like the model is around 200M parameters. That's relatively tiny for an AR model these days. Does this model scale well to larger sizes, for example up to 3B parameters like Parakeet?
2. Acoustic Quality
I've tried zero-shot audio prompting + unconditional generation. When it works, it sounds expressive and pretty good to me, but the acoustic quality leaves a lot to be desired. I know LibriSpeech isn't exactly a good dataset to begin with, but is there any remedy, or anything I should pay attention to if I'm going to train from scratch?
3. Non-verbal Cues and Speaker Switching, Description prompting
One of the high points of Parakeet/Dia was the fact that you could add non-verbal cues (laughter, sniff, cough tokens, etc.) or switch speakers within one generation. Since we don't rely on monotonic search or forced aligners, I assume the same can be done here as well.
But should I define dedicated tokens for these sounds, or can the model learn the correct mappings even if we just use something like "(laugh)" or "haha" without defining the whole word as a single token? Similarly, would a [Speaker N] token change the speaker identity, or the dialect/language, etc.?
I'm asking this because classic TTS models usually can't handle it properly this way.
4. Codec Replacement
I also want to change the codec to something that supports 44.1 kHz; hopefully that will improve the acoustic quality. DAC, NVIDIA's audio codec, or even EnCodec 48 kHz seem appropriate, though Hugging Face's "processor" logic must be dealt with. Overall, is there anything in particular I need to pay attention to (such as hardcoded EnCodec-related configs)?
5. Global Style Encoder
I know it's very unconventional, but I'm thinking about adding a global style encoder that takes the same codec latents as input (as opposed to mel) and adding it as a third conditioning signal, since I assume the latents already contain the features needed to reconstruct audio. It's particularly useful for long-form generation, and it would let us skip prompt prefixing entirely.
I assume this line is where the conditioning happens? More importantly, is the loss objective even compatible with this idea? So far I've only seen style encoders trained with a reconstruction loss.
6. Training Logs
It would be great if you could share your training logs.
Anyway, I appreciate your time. This is one of the more interesting projects I've seen in a while and I really want to give it a shot! Thank you.
@Respaired Hi, thanks for your interest. I am glad to share the following information.
1. Model Scaling
We have attempted to train a 0.5B model using Emilia's Chinese data (initialized with a Qwen2.5-0.5B model) and achieved good generation quality; further scaling to 3B requires additional validation.
2. Acoustic Quality
The model we provide was trained on Libriheavy to align with previous research; however, like LibriSpeech, Libriheavy was not designed for TTS, and its audio quality is not high, which has affected the generation quality.
You might consider further fine-tuning on some high-quality datasets (perhaps LibriTTS or in-house data?). Some research has shown that pretraining on large-scale, general-quality data followed by fine-tuning on high-quality data can improve speech quality. Additionally, Encodec actually takes 24kHz audio as input, while Libriheavy consists of 16kHz data, so we performed upsampling during training. Therefore, I believe that further fine-tuning with high-quality 24kHz data would help improve the quality.
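A minimal sketch of the 16 kHz to 24 kHz upsampling step, assuming plain NumPy and FFT-based band-limited resampling (the authors don't say which resampler they used; in practice something like torchaudio.functional.resample is the usual choice). Upsampling cannot recover content above the original 8 kHz Nyquist limit, which is part of why fine-tuning on native 24 kHz data should help. FFT resampling treats the clip as periodic, so expect small edge artifacts on real recordings.

```python
import numpy as np

def fft_resample(x, orig_sr, new_sr):
    """Upsample a 1-D signal by zero-padding its spectrum (band-limited interpolation)."""
    n_in = len(x)
    n_out = int(round(n_in * new_sr / orig_sr))
    if n_out < n_in:
        raise ValueError("this sketch only handles upsampling")
    spec = np.fft.rfft(x)
    out_spec = np.zeros(n_out // 2 + 1, dtype=complex)
    out_spec[: len(spec)] = spec            # frequencies above the old Nyquist stay zero
    if n_out > n_in and n_in % 2 == 0:
        out_spec[n_in // 2] *= 0.5          # split the old Nyquist bin when upsampling
    # irfft normalizes by 1/n_out, so rescale to keep the original amplitude
    return np.fft.irfft(out_spec, n=n_out) * (n_out / n_in)

# Usage: 1 s of a 440 Hz tone recorded at 16 kHz, resampled for a 24 kHz codec.
sr_in, sr_out = 16000, 24000
tone_16k = np.sin(2 * np.pi * 440 * np.arange(sr_in) / sr_in)
wave_24k = fft_resample(tone_16k, sr_in, sr_out)
reference = np.sin(2 * np.pi * 440 * np.arange(sr_out) / sr_out)
print(len(wave_24k), np.abs(wave_24k - reference).max())
```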
3. Non-verbal Cues and Speaker Switching, Description prompting
I believe that language modeling-based TTS models have strong potential for such tasks, as scaling the model and data enables compositional generalization across speech and text generation capabilities. Whether to use dedicated tokens or let the model learn directly from text depends on the data. In my view, both approaches are feasible for language modeling-based TTS models. Dedicated tokens are more suitable when there is less data, but they require more labeling effort.
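To make the two options concrete, here is a toy character-level tokenizer, assuming a hypothetical tag inventory (SPECIAL_TAGS and tokenize are illustrative names, not part of SLED-TTS). With dedicated_tags=True each registered tag becomes a single ID, which is easy to learn from a modest amount of consistently labeled data; with dedicated_tags=False the tag is just more characters, and the model has to learn the mapping from text on its own.

```python
import re

# Hypothetical tag inventory; the thread does not define one.
SPECIAL_TAGS = ["(laugh)", "(cough)", "[S1]", "[S2]"]
_TAG_PATTERN = "(" + "|".join(re.escape(t) for t in SPECIAL_TAGS) + ")"

def tokenize(text, vocab, dedicated_tags=True):
    """Map text to IDs, growing `vocab` as needed.

    dedicated_tags=True: each registered tag is one token.
    dedicated_tags=False: tags are spelled out character by character.
    """
    ids = []
    # The capturing group keeps the tags as separate pieces in the split result.
    for piece in re.split(_TAG_PATTERN, text):
        if dedicated_tags and piece in SPECIAL_TAGS:
            ids.append(vocab.setdefault(piece, len(vocab)))
        else:
            ids.extend(vocab.setdefault(ch, len(vocab)) for ch in piece)
    return ids

vocab = {}
print(tokenize("[S1] That's funny (laugh)", vocab))
print(tokenize("[S1] That's funny (laugh)", {}, dedicated_tags=False))
```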
4. Codec Replacement
When you choose to replace the latent encoder, I think the following aspects are particularly worth paying attention to:
Latent Dimensionality
Latent dimensionality is a very important parameter. Since we use a lightweight MLP to model the continuous latent distribution at each step, we prefer the latent dimensionality to be small. (In fact, one of the motivations behind latent generative modeling (e.g. Latent Diffusion) was to reduce the dimensionality of the continuous space.) In our experiments, we used EnCodec, which has a 128-dimensional latent space, but we have not verified whether higher-dimensional latent encoders would work (e.g., DAC uses a 1024-dimensional latent space).
How the Latent Space is Regularized
As we pointed out in Sec. 5.2 of the paper, the latent space should be regularized for downstream generative modeling. (A latent space can have minimal information loss and still be unsuitable for generative modeling.) Basically, there are two types of regularization: KL regularization and VQ regularization. (Refer to App. G of this paper.)
If you're using an audio codec as the latent encoder, you're effectively adopting the VQ-regularized approach. Most audio codecs use residual vector quantization, meaning the degree of regularization is controlled by the number of codebooks used to construct the latent representation. Using the pre-VQ representation or too many codebooks may preserve more information, but it tends to produce an overly smooth latent space, which can hurt the performance of the downstream generative model. Conversely, using fewer codebooks increases information loss but results in a better-regularized latent space.
This presents a trade-off between information preservation and regularization, and it needs to be carefully tuned for optimal performance.
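A toy NumPy illustration of the codebook-count knob, assuming random (untrained) codebooks with shrinking scale and an all-zero codeword in each stage so that no stage can increase the residual (both are simplifications for the demo, not how a trained codec behaves). Decoding with only the first k codebooks yields a coarser continuous latent; adding codebooks lowers the reconstruction error, i.e. preserves more information. The regularization side of the trade-off is not something a toy like this can show.

```python
import numpy as np

def rvq_encode(x, codebooks):
    """Greedy residual VQ: each stage quantizes what the previous stages missed.
    x: (N, D); codebooks: list of (K, D) arrays. Returns codes of shape (N, n_q)."""
    residual = x.copy()
    codes = []
    for cb in codebooks:
        dist = ((residual[:, None, :] - cb[None, :, :]) ** 2).sum(-1)  # (N, K)
        idx = dist.argmin(axis=1)
        codes.append(idx)
        residual = residual - cb[idx]
    return np.stack(codes, axis=1)

def rvq_decode(codes, codebooks, n_q):
    """Continuous latent built from the first n_q codebooks only."""
    return sum(codebooks[q][codes[:, q]] for q in range(n_q))

rng = np.random.default_rng(0)
D, K, NQ = 16, 64, 8
x = rng.normal(size=(256, D))
codebooks = []
for q in range(NQ):
    cb = rng.normal(scale=0.5 ** q, size=(K, D))  # hypothetical, untrained codebook
    cb[0] = 0.0  # zero codeword: a stage can never make the residual worse
    codebooks.append(cb)

codes = rvq_encode(x, codebooks)
errors = [float(((x - rvq_decode(codes, codebooks, k)) ** 2).mean()) for k in range(1, NQ + 1)]
for k, err in enumerate(errors, start=1):
    print(f"{k} codebook(s): MSE {err:.4f}")
```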
5. Global Style Encoder
This line is used to construct the input of the AR model. If you add more conditions to the AR model (e.g., at the beginning), you don't need to compute the loss at those steps. This is similar to how we skip loss calculation for time steps corresponding to input text tokens.
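A minimal NumPy sketch of that idea, assuming a hypothetical style encoder that simply mean-pools reference codec latents and projects them to the model width (global_style, build_inputs and masked_mean are illustrative names, not from the SLED-TTS code). The style vector is prepended before the text embeddings, and a boolean mask restricts the loss to speech-latent steps, mirroring how text positions are skipped. For brevity, targets are treated as aligned with input positions; a real AR setup shifts them by one step.

```python
import numpy as np

def global_style(ref_latents, w_style):
    """Hypothetical style encoder: mean-pool codec latents over time, then project.
    ref_latents: (L_ref, D); w_style: (D, H). Returns a (1, H) conditioning vector."""
    return ref_latents.mean(axis=0, keepdims=True) @ w_style

def build_inputs(style_emb, text_emb, speech_emb):
    """Prefix layout [style | text | speech] plus a mask of loss-bearing steps.
    speech_emb stands for codec latents after the z_proj-style input projection."""
    inputs = np.concatenate([style_emb, text_emb, speech_emb], axis=0)
    n_prefix = len(style_emb) + len(text_emb)
    loss_mask = np.arange(len(inputs)) >= n_prefix  # no loss on style or text steps
    return inputs, loss_mask

def masked_mean(per_step_loss, loss_mask):
    return float(per_step_loss[loss_mask].mean())

rng = np.random.default_rng(0)
D, H = 128, 16
style = global_style(rng.normal(size=(40, D)), rng.normal(size=(D, H)))
inputs, mask = build_inputs(style, rng.normal(size=(5, H)), rng.normal(size=(7, H)))
print(inputs.shape, int(mask.sum()))
print(masked_mean(np.arange(13.0), mask))
```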
6. Training Logs
I'll check whether we have permission to update the logs and get back to you shortly.
Feel free to reach me if you have further questions.
@Paulmzr
Thank you very much for taking your time to answer me; I really appreciate it.
I agree 100% that the Libri datasets in general are not high quality. Since my domain/language for this next project is wholly different from English, I will have to train from scratch, but that's fine since I already have a preprocessed dataset (700 to 10K hours). I plan on training for at least 250K steps with a smaller batch size than the one mentioned in the paper (I only have 4x A6000).
Hopefully it'll converge. Models that don't rely on external forced aligners usually need about this much training before they produce something usable.
About latent space regularization and the choice of codec: I feel like EnCodec compresses way too much (artifacts can be heard even at the highest bitrate), especially at the default 6.0 kbps, and it honestly doesn't sound as good as DAC across different sampling rates.
As you know, DAC returns something like this:
x = model.preprocess(signal.audio_data, signal.sample_rate)
z, codes, latents, _, _ = model.encode(x)
I assume z is what we want since it is the vector-quantized representation (or should it be the pre-projection latents?).
If we are going to use z, I wonder whether we even need to project it with the layer at:
https://github.com/ictnlp/SLED-TTS/blob/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/sled/sled.py#L60 since it is already mapped to 1024 dims.
First of all, DAC has 9 codebooks and each codebook has a dimensionality of 8 (as opposed to 128 for EnCodec); isn't this still heavily compressed or regularized? The latents have a dimension of 72, which is even less than EnCodec.
If we need the post-VQ tensor as proposed, why couldn't we downsample or project it to a lower-dimensional space, then upsample it for decoding? It adds additional learnable parameters, but it wouldn't cost much and we could still use a better codec.
Or maybe we could make the MLP heavier, though I'm not sure how much that goes against the motivation of this research:
# Score Arguments
vae_embed_dim: int = 1024 # from 128
diffloss_d: int = 4 # from 3
diffloss_w: int = 1536 # from 1024
training_cfg: float = 0.0
noise_channels: int = 512 # from 128
Again, I understand these are a lot of questions so I can't thank you enough for this work and your time!
Hi, @Respaired
https://github.com/ictnlp/SLED-TTS/blob/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/sled/sled.py#L60 (line 60: self.z_proj = nn.Linear(self.token_embed_dim, self.hidden_size, bias=True))
I think this line is unnecessary if the dimensions are already 1024.
First of all, DAC has 9 codebooks and each codebook has a dimensionality of 8 (as opposed to 128 for EnCodec); isn't this still heavily compressed or regularized? The latents have a dimension of 72, which is even less than EnCodec.
I don’t quite understand why you think the latents have a dimension of 72. If I remember correctly, in DAC the latent before RVQ is 1024-dimensional, and the residual is projected down to 8 dimensions before each quantization operation. Therefore, the latent representation should be 1024-dimensional if you use DAC.
Why couldn't we downsample or project it to a lower-dimensional space, then upsample it for decoding? It adds additional learnable parameters, but it wouldn't cost much and we could still use a better codec.
Doing this would mean you need to add extra downsample and upsample modules to the codec, and continue training with reconstruction loss. Perhaps you can consider tuning only these newly added modules while keeping the pretrained parameters of the codec frozen.
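A minimal sketch of the frozen-codec adapter idea, simplified to a linear down/up projection trained with a latent-space MSE, for which the optimum has a closed form (PCA via SVD). This is an assumption-heavy stand-in: a real version would more likely train small nonlinear modules with the codec's own reconstruction loss through the frozen decoder. The data here is synthetic (256-dim frames with about 24 dominant directions), not real DAC latents, and the function names are hypothetical.

```python
import numpy as np

def fit_linear_bottleneck(latents, k):
    """Closed-form linear down/up projection minimizing latent reconstruction MSE (PCA).
    latents: (N, D) frames from a frozen codec encoder; the codec itself is untouched."""
    mean = latents.mean(axis=0)
    _, _, vt = np.linalg.svd(latents - mean, full_matrices=False)
    down = vt[:k].T  # (D, k): low-dimensional latent the AR model would predict
    up = vt[:k]      # (k, D): maps back to the frozen decoder's input space
    return mean, down, up

def relative_error(latents, mean, down, up):
    """Reconstruction error as a fraction of the (centered) latent energy."""
    recon = (latents - mean) @ down @ up + mean
    return float(((latents - recon) ** 2).sum() / ((latents - mean) ** 2).sum())

rng = np.random.default_rng(0)
# Synthetic stand-in for codec latents: low-rank structure plus a little noise.
fake = rng.normal(size=(2000, 24)) @ rng.normal(size=(24, 256))
fake += 0.01 * rng.normal(size=(2000, 256))
for k in (8, 32):
    print(k, relative_error(fake, *fit_linear_bottleneck(fake, k)))
```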
Or maybe we could make the MLP heavier, though I'm not sure how much that goes against the motivation of this research:
This is feasible, but it contradicts the original motivation of using a latent encoder. My point is that the latent dimension shouldn't be too large, similar to how a discrete model shouldn't have an excessively large vocabulary size.
Overall, the choice of codec indeed acts as a bottleneck limiting continuous-domain generative models, and designing a specialized latent encoder for them has the potential to significantly enhance their performance.
Maybe the naming convention caused this confusion, but the latents variable is each codebook's dimension (8 for DAC) multiplied by the number of codebooks (9 for DAC), so it is 9 × 8 = 72-dimensional.
So my assumption is that z is already the summed, quantized continuous representation, because it's projected to 1024 dimensions.
I understand your point clearly, but I guess there's simply no other choice: I don't know of any 44.1 kHz codec that satisfies this constraint. Regardless, I believe even that latent is in a compressed state; I will test it and report back. Being limited to EnCodec is a deal-breaker.
One last thing: if you can't provide the logs, may I ask roughly at which optimizer step you got basic alignment and something that resembles speech? I'd rather jump directly into testing DAC and this hypothesis, so any info that could serve as my baseline would help.
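Based on my reading of the descript-audio-codec quantizer (worth double-checking against the source), both views can be reconciled: each of the 9 stages projects the current residual from 1024 down to 8 dims (that per-stage 8-dim projection is what ends up in latents), looks up the nearest of 1024 codewords, projects it back to 1024 dims, and adds it to z. So z is post-quantization and 1024-dimensional, while latents is the 72-dimensional concatenation of per-stage pre-quantization projections. The sketch below reproduces only these shapes with random, hypothetical weights; it omits DAC's L2-normalized codebook lookup, weight normalization and quantizer dropout, and uses time-first arrays instead of DAC's (batch, channels, time).

```python
import numpy as np

rng = np.random.default_rng(0)
D_MODEL, N_Q, D_CODE, K, T = 1024, 9, 8, 1024, 50  # DAC 44 kHz-like sizes (assumed)

# Random stand-ins for trained per-stage projections and codebooks.
in_proj = [rng.normal(scale=D_MODEL ** -0.5, size=(D_MODEL, D_CODE)) for _ in range(N_Q)]
out_proj = [rng.normal(scale=D_CODE ** -0.5, size=(D_CODE, D_MODEL)) for _ in range(N_Q)]
books = [rng.normal(size=(K, D_CODE)) for _ in range(N_Q)]

def factorized_rvq(x):
    """x: (T, D_MODEL) encoder output. Returns (z, codes, latents)."""
    residual = x.copy()
    z = np.zeros_like(x)
    codes, latents = [], []
    for q in range(N_Q):
        z_e = residual @ in_proj[q]                              # (T, 8) low-dim projection
        dist = ((z_e[:, None, :] - books[q][None]) ** 2).sum(-1)  # (T, K)
        idx = dist.argmin(axis=1)
        z_q = books[q][idx] @ out_proj[q]                        # back to (T, 1024)
        residual = residual - z_q
        z += z_q                                                 # z accumulates quantized stages
        codes.append(idx)
        latents.append(z_e)
    return z, np.stack(codes), np.concatenate(latents, axis=1)

z, codes, latents = factorized_rvq(rng.normal(size=(T, D_MODEL)))
print(z.shape, codes.shape, latents.shape)
```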
@Paulmzr
Looks like NVIDIA's 44.1 kHz audio codec has a latent dimension of 32. It's slightly slower than DAC and moderately slower than EnCodec, but it seems like a very good alternative to me.
with torch.no_grad():
    encoded_tokens, encoded_len = nemo_codec_model.encode(audio=audio_tensor, audio_len=audio_len)
    post_quantized_latents = nemo_codec_model.dequantize(tokens=encoded_tokens, tokens_len=encoded_len)
print(f"Shape of encoded_tokens: {encoded_tokens.shape}")
print(f"Shape of post_quantized_latents: {post_quantized_latents.shape}")
Shape of encoded_tokens: torch.Size([1, 8, 436])
Shape of post_quantized_latents: torch.Size([1, 32, 436])
I'll try to test all these in the next few weeks once I have a good baseline.
Hello there! Thanks for the amazing work! I'm a bit new to neural codecs in general, and to RVQ ones even more, so I need some time to build my understanding before I can fully follow your discussion here. However, I just wanted to clarify some questions upfront. RVQ codecs are currently the SOTA, AFAIK. Could you please outline whether it's possible to integrate some RVQ codecs here instead of EnCodec, or is your method not designed for them? My particular interest is in Mimi and SNAC, but I'm also looking at the acoustic/speech tokenizers from the recently published VibeVoice model by Microsoft.
@Paulmzr
Answering my own question based on several discussions here: it looks like it's possible, but we have to consider using pre-VQ latents for conditioning, as well as the recommendations given in this thread.
I think he said pre-VQ latents are too smooth, IIRC. The MLP in this model is very picky about which codec you can use; I had very limited success swapping in another codec. SNAC (or anything related to DAC) is too big for this architecture, and making it work here is not trivial.
@Respaired Oh I see... sad to hear that. Have you tried a heavier projection module instead of the MLP to combat that? And can you recommend any other papers in this direction that look more promising, or that you've had success with? I'm working on TTS for a constrained ARM device and trying to eliminate as many moving parts as I can from the speech model (several decoders, hierarchical codes, etc.). This work seemed promising in that regard.