[Fastpitch] Multi-speaker model changes output speaker identity for different texts

Open adrianastan opened this issue 2 years ago • 28 comments

Hi,

We are trying to train a multi-speaker model starting from the LibriTTS data and using the latest FastPitch commit. We selected the 50 speakers which have the most utterances in the dataset, and removed the single-word ones (resulting in around 8400 samples in the training set). The model was trained for 1500 epochs and the output quality is quite alright.

However, when using the same speaker ID and synthesising multiple texts, we get slightly different output identities. Did anyone else encounter this problem? Is there anything we can tweak in the model to make sure this does not happen? Or what else can we do to enforce a single output speaker ID across synthesised utterances?

Thanks, Adriana

adrianastan • Feb 26 '22 10:02

Hi @adrianastan

Sorry for the late reply. I don't have much experience with that many speakers, but I'd try adding capacity to the model and looking at class imbalance. If the speakers are indeed imbalanced, you can try weighting them somehow, or take the easy route and just repeat the under-represented speakers' utterances in the filelist.
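A minimal sketch of the "easy route" of repeating utterances. It assumes a hypothetical pipe-separated filelist in which the speaker ID is the last field; adjust the parsing to whatever layout your filelists actually use. Each speaker's lines are repeated until every speaker has as many entries as the largest one.

```python
import math
from collections import Counter, defaultdict


def balance_filelist(lines, sep="|"):
    """Oversample under-represented speakers by repeating their filelist lines.

    Assumption (hypothetical format): each line is pipe-separated and the
    speaker ID is the last field, e.g. "wavs/a.wav|some text|3".
    """
    by_speaker = defaultdict(list)
    for line in lines:
        speaker = line.rstrip("\n").split(sep)[-1]
        by_speaker[speaker].append(line)

    # Bring every speaker up to the size of the largest one.
    target = max(len(spk_lines) for spk_lines in by_speaker.values())
    balanced = []
    for speaker in sorted(by_speaker):
        spk_lines = by_speaker[speaker]
        reps = math.ceil(target / len(spk_lines))
        balanced.extend((spk_lines * reps)[:target])
    return balanced


if __name__ == "__main__":
    demo = ["a1.wav|hello|0", "a2.wav|world|0", "a3.wav|again|0", "b1.wav|hi|1"]
    print(Counter(line.split("|")[-1] for line in balance_filelist(demo)))
```

Repeating lines only changes how often each speaker is sampled; it adds no new acoustic variety, so it can help with imbalance but is unlikely to fix a capacity problem on its own.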

alancucki • Mar 11 '22 13:03

Hi, thank you for your reply.

We also trained a model with the exact same number of utterances and the same text from 37 different speakers, and the results are the same. I assume that the simple summation of the speaker embedding to the text encoding is not strong enough to preserve the speaker identity (https://github.com/NVIDIA/DeepLearningExamples/blob/de507d9fecfbdd50ad001bdb15e89f8eae46871e/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207)

So I am wondering if anybody else tried the multispeaker model and found alternative ways of performing this conditioning.

Thanks!
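For reference, the conditioning being discussed amounts roughly to the sum below: the symbol embeddings, the positional encodings and the speaker vector are added together, as described later in this thread. This is a NumPy sketch with illustrative shapes; the repository does this with PyTorch tensors inside `FFTransformer.forward`. It shows why the speaker information might get diluted: it shares the same channels as the text and position information.

```python
import numpy as np


def additive_conditioning(inp, pos_emb, spk_emb):
    """Sum symbol embeddings, positional encodings and a speaker vector.

    inp: (B, T, D); pos_emb: (1, T, D) or (B, T, D); spk_emb: (B, 1, D).
    Broadcasting copies the single speaker vector to every symbol position,
    so text, position and speaker all share the same D channels.
    """
    return inp + pos_emb + spk_emb


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, n_symbols, d_model = 2, 7, 384  # illustrative sizes only
    out = additive_conditioning(
        rng.standard_normal((batch, n_symbols, d_model)),
        rng.standard_normal((1, n_symbols, d_model)),
        rng.standard_normal((batch, 1, d_model)),
    )
    print(out.shape)
```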

adrianastan • Mar 12 '22 11:03

I assume that the simple summation of the speaker embedding to the text encoding is not strong enough to preserve the speaker identity

That might be the case. Positional embeddings are sometimes also added between the layers to keep the positional information from fading out. Doing the same with the speaker embedding might be worth a shot.

How many utterances do you roughly have per speaker? Are there both male and female speakers?

alancucki • Mar 14 '22 17:03

I have now also added the speaker embedding as conditioning for the decoder (here: https://github.com/NVIDIA/DeepLearningExamples/blob/de507d9fecfbdd50ad001bdb15e89f8eae46871e/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py#L314), but the results aren't any better.

We use both male and female speakers, with 137 to 321 utterances per speaker. For the model trained on the same data with the same number of utterances per speaker, there were 638 utterances from each speaker (37 speakers in total).
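To make the decoder-side change concrete: the speaker vector has to be added after the length regulator has upsampled the encoder output to mel-frame rate, so it is broadcast over every frame rather than every symbol. The NumPy sketch below handles a single utterance; the function name and shapes are assumptions for illustration, and the repository's `regulate_len` works on padded batches instead.

```python
import numpy as np


def regulate_and_condition(enc_out, durations, spk_emb):
    """Upsample one utterance to frame rate, then add the speaker vector.

    enc_out: (T_text, D) encoder outputs; durations: (T_text,) integer frame
    counts; spk_emb: (D,) speaker embedding. This is a single-utterance
    simplification of a batched, padded length regulator.
    """
    frames = np.repeat(enc_out, durations, axis=0)  # (sum(durations), D)
    return frames + spk_emb  # broadcast over every mel frame
```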

adrianastan • Mar 14 '22 17:03

@alancucki: I was trying out your suggestion of adding the speaker embedding at every layer, as below.

--- a/fastpitch/transformer.py
+++ b/fastpitch/transformer.py
@@ -211,6 +211,7 @@ class FFTransformer(nn.Module):
 
         for layer in self.layers:
             out = layer(out, mask=mask)
+            out += conditioning
 
         # out = self.drop(out)
         return out, mask

Is that what you had in mind?

dsplog • Mar 19 '22 06:03

Exactly! Did the quality improve?

alancucki • May 06 '22 14:05
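The per-layer variant confirmed above can be sketched framework-free as follows. The `layers` callables are hypothetical stand-ins for the FFT blocks; the sketch adds a guard for `conditioning is None` (single-speaker models) and uses out-of-place addition, which is the conservative choice in PyTorch, where in-place updates can occasionally conflict with autograd.

```python
import numpy as np


def fft_stack(x, layers, conditioning=None, reinject=True):
    """Run a stack of layers, optionally re-adding the speaker conditioning.

    `layers` is a list of callables mapping (B, T, D) -> (B, T, D) that stand
    in for the FFT blocks. `conditioning` is (B, 1, D) or None.
    """
    out = x if conditioning is None else x + conditioning
    for layer in layers:
        out = layer(out)
        # Re-inject after every block (the change in the diff above); skip
        # when there is no speaker conditioning at all.
        if reinject and conditioning is not None:
            out = out + conditioning
    return out
```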

Broadly, yes. For a few speakers it gives a noticeable improvement.

However, some speakers are still not captured well. That may be due to insufficient data for those speakers; still checking... :-)

dsplog • May 09 '22 02:05

Hi,

I trained a multi-speaker FastPitch model for 2200 epochs (2 speakers, one male and one female, each with 15,000 sentences; 50 hours of speech in total), but the output is just noise and the model isn't learning. Training a single-speaker model from scratch, on the other hand, worked and the output was quite alright.

Could anyone help me out here? Is there something important to configure when training multi-speaker FastPitch that I missed?

Thank you

rygopu • May 09 '22 14:05

We also used external speaker embeddings (derived from SpeechBrain's model: https://speechbrain.github.io/) as opposed to having FastPitch learn them. This helped a bit, but it still fails at generating short utterances and sometimes even longer ones in the desired speaker's identity.
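One common way to turn utterance-level embeddings from a pretrained speaker-verification model into a single conditioning vector per speaker is to average them and L2-normalise the result. The sketch below assumes the embeddings have already been extracted (how they were computed with SpeechBrain is not specified here) and is not necessarily what was done in this thread.

```python
import numpy as np


def speaker_centroids(utt_embeddings, speaker_ids):
    """Average utterance-level embeddings into one unit-length vector per speaker.

    utt_embeddings: (N, D) array-like; speaker_ids: length-N sequence.
    """
    utt = np.asarray(utt_embeddings, dtype=np.float64)
    ids = np.asarray(speaker_ids)
    centroids = {}
    for spk in np.unique(ids):
        mean = utt[ids == spk].mean(axis=0)
        centroids[spk.item()] = mean / np.linalg.norm(mean)
    return centroids
```

The resulting vectors could then be kept frozen as the speaker lookup table instead of a learned embedding.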

adrianastan • May 09 '22 14:05

Thank you! I'll check it out. Just to clarify: did you use external speaker embeddings because the FastPitch output was pure noise, or to improve the pronunciation of the utterances?

rygopu • May 09 '22 14:05

Just to improve the speaker control.

adrianastan • May 09 '22 17:05

@rygopu The model should not fail completely; even after ~100 epochs you should be able to synthesize a fairly intelligible output. I'd suggest a bit of debugging before drawing conclusions.

In particular please check the data pre-processing pipeline, and make sure that your speaker IDs do get loaded and used by the model. Also, the alignment module might fail to converge for whatever reason.
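As a first step in that check, a hypothetical helper that counts utterances per speaker ID in a filelist and flags IDs outside the range the speaker embedding table can hold. It assumes the speaker ID is an integer in the last pipe-separated field; a filelist where every line carries the same ID would also show up immediately in the counts.

```python
from collections import Counter


def check_speaker_ids(lines, n_speakers, sep="|"):
    """Count utterances per speaker ID and flag IDs outside [0, n_speakers).

    Hypothetical format: the speaker ID is the last pipe-separated field.
    """
    counts = Counter(int(line.rstrip("\n").split(sep)[-1]) for line in lines)
    bad = sorted(spk for spk in counts if not 0 <= spk < n_speakers)
    return counts, bad
```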

alancucki • May 10 '22 18:05

Thanks @alancucki & @adrianastan. I retrained FastPitch (2 speakers, 13,000 sentences per speaker, 4500 epochs) and the output is still pure noise. While debugging I loaded different checkpoints: the right speaker embeddings seem to be loaded by the model, and I can see the embeddings being learnt as training progresses, but the output is still the same as in previous runs. Do you have any leads I could try?

rygopu • May 16 '22 15:05

Is it noise-noise or speech-like noise? Is the symbol list you use at training the same as the one used at inference? Is the transcription correct and aligned with the audio?
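Related to the symbol-list question, a quick check is whether the transcripts contain characters the symbol set does not cover, since those may be dropped or mishandled during text processing. This sketch works on raw characters; if your pipeline cleans the text first (lowercasing, expanding numbers, etc.), run it on the cleaned text. The function name and example symbol set are hypothetical.

```python
def out_of_vocabulary(transcripts, symbols):
    """Return the sorted characters that appear in transcripts but not in symbols."""
    vocab = set(symbols)
    seen = set()
    for text in transcripts:
        seen.update(text)
    return sorted(seen - vocab)


if __name__ == "__main__":
    print(out_of_vocabulary(["héllo!", "ok"], "abcdefghijklmnopqrstuvwxyz !"))
```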

adrianastan • May 17 '22 06:05

  • @adrianastan It's just noise-noise. Yes, the symbol list is the same at training and inference, and the transcriptions are aligned as well.

  • What audio pre-processing parameters did you use? Also, did you downsample the audio to a 22050 Hz sampling rate?

  • Would you be able to share the dataset you used, or your audio pre-processing parameters?

rygopu • May 17 '22 13:05

I am afraid I cannot share the dataset, but I did downsample the audio, trimmed the silence and normalised the volume.

Maybe you can try using one of the single-speaker models as starting point for your multi-speaker one.

adrianastan • May 17 '22 14:05

Hi @adrianastan, thank you. I trained multi-speaker FastPitch on the same dataset and it now works: the issue was in pre-processing, and downsampling with ffmpeg instead of librosa solved it. Also, when I run inference on multiple texts for the same speaker, the output quality and identity seem alright (my training set has 15,000 sentences per speaker, for 3 speakers in total).
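For anyone hitting the same problem, here is a sketch of the three pre-processing steps mentioned above (downsampling, silence trimming, volume normalisation) done with ffmpeg and driven from Python. The filter chain (`silenceremove` applied to the signal and to its reverse to trim both ends, then `loudnorm`) and the -50 dB threshold are assumptions, not the settings used by the posters.

```python
import shlex


def ffmpeg_preprocess_cmd(src, dst, sample_rate=22050, threshold_db=-50):
    """Build an ffmpeg command that resamples to mono, trims silence and normalises loudness.

    Parameter values are illustrative assumptions, not settings from this thread.
    """
    trim = f"silenceremove=start_periods=1:start_threshold={threshold_db}dB"
    # Trim leading silence, reverse, trim again (i.e. the trailing silence),
    # reverse back, then apply EBU R128 loudness normalisation.
    audio_filter = ",".join([trim, "areverse", trim, "areverse", "loudnorm"])
    return ["ffmpeg", "-y", "-i", src, "-ac", "1", "-ar", str(sample_rate),
            "-af", audio_filter, dst]


if __name__ == "__main__":
    print(shlex.join(ffmpeg_preprocess_cmd("in.wav", "out.wav")))
```

The command can be executed with `subprocess.run(ffmpeg_preprocess_cmd(src, dst), check=True)`; listening to a few outputs before training is a cheap sanity check.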

rygopu • May 25 '22 16:05

Thanks for your reply! We are using 50 speakers with 200 utts/speaker, and it still changes the identity. We are now retraining using the ideas here: https://github.com/NVIDIA/DeepLearningExamples/issues/707#issuecomment-727021066

adrianastan • May 26 '22 06:05

Hi, @alancucki,

So we tried all the methods mentioned so far:

  • balancing the data
  • adding the speaker conditioning on the decoder side, as well
  • using two attention heads
  • doubling the parameters as here https://github.com/NVIDIA/DeepLearningExamples/issues/707#issuecomment-727021066

But there are still no improvements. The identity still changes for different text inputs. Are there any other ideas you think are worth trying out?

Thanks, Adriana

adrianastan • Jul 20 '22 08:07

@adrianastan, have you tried concatenating the speaker embeddings to the text encoding (by repeating it for each symbol)?

martinvk1 • Feb 16 '23 11:02

But this is what happens now in FastPitch: https://github.com/NVIDIA/DeepLearningExamples/blob/afea561ecff80b82f17316a0290f6f34c486c9a5/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207

adrianastan • Feb 16 '23 11:02

@adrianastan No, the speaker embedding is summed with the input and positional encoding, not concatenated. This kind of summation should be acceptable for positional encoding, but it is not suited for adding a speaker embedding. I think it would work much better to concatenate, similar to what has been done in multispeaker Tacotron in the past (for example here: https://github.com/CorentinJ/Real-Time-Voice-Cloning/blob/98d0ca4d4d140a4bb6bc7d54c84b1915a79041d5/synthesizer/models/tacotron.py#L62)

After concatenation the dimensionality changes, so other parameters need to be adjusted accordingly. Alternatively, a Linear layer could project back down to the original dimension. Not sure how well that would work.
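A NumPy sketch of the concatenation idea: the per-utterance speaker vector is tiled over every time step and appended to the channel dimension, with an optional matrix standing in for the Linear layer that projects back down. The function name and the 512/256 sizes are illustrative assumptions; in PyTorch the same thing would be done with `torch.cat` on the last dimension.

```python
import numpy as np


def concat_speaker(enc_out, spk_emb, proj=None):
    """Concatenate a per-utterance speaker vector to every time step.

    enc_out: (B, T, D_text); spk_emb: (B, 1, D_spk).
    proj: optional (D_text + D_spk, D_out) matrix standing in for a Linear
    layer that maps back to the original width.
    """
    batch, steps, _ = enc_out.shape
    tiled = np.broadcast_to(spk_emb, (batch, steps, spk_emb.shape[-1]))
    out = np.concatenate([enc_out, tiled], axis=-1)  # (B, T, D_text + D_spk)
    return out if proj is None else out @ proj
```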

martinvk1 • Feb 16 '23 12:02

Ok, got it, still this only means that instead of having equal weights in the summation, the network learns the individual summation weights.

So did you try this and got better results? Thanks!
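The learned-weights argument can be made precise for the first linear map that reads the concatenated features. Splitting its weight matrix into a text block W_x and a speaker block W_s, with x_t the encoding at step t and s the speaker embedding (both row vectors):

```latex
W = \begin{bmatrix} W_x \\ W_s \end{bmatrix}, \qquad
\begin{bmatrix} x_t & s \end{bmatrix} W = x_t W_x + s\, W_s,
\qquad\text{versus}\qquad (x_t + s)\, W = x_t W + s\, W
```

So at that first layer, concatenation is equivalent to summing separately projected text and speaker vectors, whereas the additive scheme passes both through one shared W. Whether that difference matters in practice is what the experiments in this thread are testing.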

adrianastan • Feb 20 '23 08:02

@adrianastan In my experiment I increased the dimensionality of the encoder to fit the embedded symbols with the speaker information concatenated to them. That way, the encoder receives intact, clearly separated features and can decide how to deal with them. Depending on the goal, it may be better to add the speaker information after symbol encoding, or even later. So far I haven't had any problems with speaker similarity, but I ran into other challenges with multi-speaker FastPitch, such as needing mean/std pitch values for every speaker in the dataset, and making sure that each speaker has short, medium, and long utterances. I find that transformers generalize poorly to input lengths they have not explicitly been trained on for a given speaker.
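A sketch of per-speaker pitch statistics, assuming that pitch contours are stored as per-frame F0 arrays with 0 marking unvoiced frames. Voiced frames are z-normalised with their own speaker's mean and std instead of a single dataset-wide pair; the function names are hypothetical.

```python
import numpy as np


def per_speaker_pitch_stats(pitch_by_utt, speaker_ids):
    """Mean/std of voiced F0 per speaker (frames with pitch 0 treated as unvoiced)."""
    stats = {}
    for spk in sorted(set(speaker_ids)):
        voiced = np.concatenate(
            [p[p > 0] for p, s in zip(pitch_by_utt, speaker_ids) if s == spk])
        stats[spk] = (voiced.mean(), voiced.std())
    return stats


def normalize_pitch(pitch, mean, std):
    """Z-normalise voiced frames and keep unvoiced frames at 0."""
    out = np.zeros_like(pitch, dtype=np.float64)
    voiced = pitch > 0
    out[voiced] = (pitch[voiced] - mean) / std
    return out
```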

martinvk1 • Feb 20 '23 09:02

Ok, great, I will give it a try. Thanks!

adrianastan • Feb 21 '23 08:02

@adrianastan Hope it works out for you. Just wanted to add that I got great results by concatenating the speaker embedding directly to the input of the 1) pitch predictor 2) duration predictor 3) energy predictor and 4) decoder, like below. I am using a speaker_embedding_dim of 256, symbol_embedding_dim of 512 which means that these layers are now 768.

--- a/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py
+++ b/PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py
@@ -262,14 +262,17 @@ class FastPitch(nn.Module):
             spk_emb.mul_(self.speaker_emb_weight)
 
         # Input FFT
-        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
+        enc_out, enc_mask = self.encoder(inputs, conditioning=0) # Do not condition here
+
+        spk_emb_repeated = spk_emb.repeat(1, enc_out.shape[1], 1)
+        enc_out_spk = torch.cat([enc_out, spk_emb_repeated], dim=2)
 
         # Predict durations
-        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
+        log_dur_pred = self.duration_predictor(enc_out_spk, enc_mask).squeeze(-1) # Condition here
         dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
 
         # Predict pitch
-        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
+        pitch_pred = self.pitch_predictor(enc_out_spk, enc_mask).permute(0, 2, 1) # Condition here
 
         # Alignment
         text_emb = self.encoder.word_emb(inputs)
@@ -301,7 +304,7 @@ class FastPitch(nn.Module):
 
         # Predict energy
         if self.energy_conditioning:
-            energy_pred = self.energy_predictor(enc_out, enc_mask).squeeze(-1)
+            energy_pred = self.energy_predictor(enc_out_spk, enc_mask).squeeze(-1) # Condition here
 
             # Average energy over characters
             energy_tgt = average_pitch(energy_dense.unsqueeze(1), dur_tgt)
@@ -317,8 +320,11 @@ class FastPitch(nn.Module):
         len_regulated, dec_lens = regulate_len(
             dur_tgt, enc_out, pace, mel_max_len)
 
+        spk_emb_repeated = spk_emb.repeat(1, len_regulated.shape[1], 1)
+        len_regulated_spk = torch.cat([len_regulated, spk_emb_repeated], dim=2)
+
         # Output FFT
-        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
+        dec_out, dec_mask = self.decoder(len_regulated_spk, dec_lens) # Condition here
         mel_out = self.proj(dec_out)
         return (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred,
                 pitch_tgt, energy_pred, energy_tgt, attn_soft, attn_hard,

martinvk1 • Feb 21 '23 11:02

Ah, thanks for sharing the details. Will try this out.

dsplog • Feb 23 '23 02:02

Hi @martinvk1, I'm trying to change the model as you suggested, but I ran into size mismatches. I set --in-fft-output-size to twice its value, but I still get a size error here: (https://github.com/NVIDIA/DeepLearningExamples/blob/afea561ecff80b82f17316a0290f6f34c486c9a5/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py#L207). Could you describe in more detail how you changed the encoder?
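A debugging aid for this kind of error, offered as a sketch: the referenced line sums the input, the positional embedding and (optionally) the conditioning inside `FFTransformer.forward`, and that class is used for both the encoder and the decoder, so the mismatch may come from either call. The hypothetical helper below reports the last-dimension width of each operand before they are summed; it only reads shapes, so it should also accept PyTorch tensors.

```python
import numpy as np


def check_addends(**tensors):
    """Raise a readable error if tensors that will be summed disagree on width.

    Hypothetical debugging helper: call it with the operands of the
    `inp + pos_emb (+ conditioning)` sum right before the addition.
    Scalars (e.g. conditioning=0) are ignored.
    """
    widths = {name: np.shape(t)[-1] for name, t in tensors.items() if np.ndim(t) > 0}
    if len(set(widths.values())) > 1:
        raise ValueError(f"last-dimension mismatch: {widths}")
    return widths
```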

Slava715 • May 08 '23 22:05