MiniCPM-V

No training code found for implementing the Speech-to-Speech framework

Open · everwind opened this issue 9 months ago · 1 comment

Speech-to-Speech Framework

https://openbmb.notion.site/MiniCPM-o-2-6-A-GPT-4o-Level-MLLM-for-Vision-Speech-and-Multimodal-Live-Streaming-on-Your-Phone-185ede1b7a558042b5d5e45e6b237da9

We bridge the audio encoder and LLM backbone via encoded audio representations (i.e., no ASR is performed). The LLM backbone and speech decoder are bridged in a hybrid fashion: (1) A dense speech embedding from the LLM hidden representations controls the voice, emotion, accent, and other fine-grained speech-centric features. During training, gradients from the speech decoder are back-propagated to full model parameters, including the LLM backbone, and the audio/visual encoders. Note that the whole model is trained end-to-end without any intermediate loss or supervision. (2) We also feed the text tokens from the LLM backbone to the speech decoder for better semantic control and training data efficiency.
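A toy sketch of the hybrid bridge described above may help. It is not the MiniCPM-o 2.6 code. Every module is a hypothetical stand-in: linear layers for the audio encoder and LLM backbone, a GRU for the speech decoder, and a single audio codebook. The sizes are made up, and taking the speaker embedding from the last hidden state is an assumption. The sketch shows that one CE loss on speech tokens, with no intermediate loss, back-propagates through the dense speech embedding into the backbone and the audio encoder. Meanwhile, the text tokens enter the decoder as discrete ids.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes; the real MiniCPM-o 2.6 dimensions are not given here.
FEAT, LLM_DIM, TTS_DIM, TEXT_VOCAB, AUDIO_VOCAB = 16, 32, 24, 100, 50

audio_encoder = nn.Linear(FEAT, LLM_DIM)                  # stand-in for the audio encoder
llm_backbone = nn.Linear(LLM_DIM, LLM_DIM)                # stand-in for the LLM backbone
spk_proj = nn.Linear(LLM_DIM, TTS_DIM)                    # (1) hidden state -> dense speech embedding
text_emb = nn.Embedding(TEXT_VOCAB, TTS_DIM)              # (2) LLM text tokens -> decoder space
audio_emb = nn.Embedding(AUDIO_VOCAB, TTS_DIM)            # audio tokens -> decoder space
tts_decoder = nn.GRU(TTS_DIM, TTS_DIM, batch_first=True)  # causal stand-in for the speech decoder
audio_head = nn.Linear(TTS_DIM, AUDIO_VOCAB)              # predicts the next audio token

audio_feats = torch.randn(1, 10, FEAT)                # input audio features (no ASR step)
text_tokens = torch.randint(0, TEXT_VOCAB, (1, 5))    # text tokens produced by the LLM
audio_codes = torch.randint(0, AUDIO_VOCAB, (1, 6))   # target speech tokens

hidden = llm_backbone(audio_encoder(audio_feats))     # [1, 10, LLM_DIM]
speaker = spk_proj(hidden[:, -1:, :])                 # [1, 1, TTS_DIM]
seq = torch.cat([speaker, text_emb(text_tokens), audio_emb(audio_codes)], dim=1)  # [1, 12, TTS_DIM]
out, _ = tts_decoder(seq)                             # [1, 12, TTS_DIM]

# Position t predicts token t + 1, so the audio tokens are predicted from
# the last prefix position up to the second-to-last position.
n_prefix = 1 + text_tokens.size(1)
logits = audio_head(out[:, n_prefix - 1 : -1, :])     # [1, 6, AUDIO_VOCAB]
loss = F.cross_entropy(logits.reshape(-1, AUDIO_VOCAB), audio_codes.reshape(-1))
loss.backward()  # the only loss: no intermediate supervision

# The speech-decoder loss reaches the backbone and the audio encoder.
print(llm_backbone.weight.grad.abs().sum() > 0, audio_encoder.weight.grad.abs().sum() > 0)
```

Because the text tokens are discrete ids in this sketch, the gradient path to the backbone runs only through the projected hidden state.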

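I cannot find the code implementing the Speech-to-Speech framework in this project. Specifically, I could not find the end-to-end audio-to-audio training code, i.e., the code that back-propagates gradients from the speech decoder all the way into the LLM backbone ("gradients from the speech decoder are back-propagated to full model parameters").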

everwind · Mar 24 '25 11:03

Yes, we have not yet released the training code for the speech decoder or the end-to-end training code for the full MiniCPM-o-2.6. We will consider open-sourcing the training code in the future. For now, people who are familiar with tokenizer-based speech generation should be able to implement this training in a few days. The core idea is:

1. Pass the last_hidden_states from the LLM through a projection into the TTS decoder latent space to produce a speaker embedding.
2. Embed the text tokens into the TTS decoder latent space.
3. Embed the audio tokens into the TTS decoder latent space.
4. Construct a window that looks like ||speaker embedding| text tokens| audio tokens| audio eos|| and back-propagate the CE loss.

For audio token modeling, you can refer to

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Fragment of the TTS decoder's forward pass; self.num_vq and self.emb_code are module attributes.
    # all_audio_codes: list of per-sample code tensors, where codes[i] holds codebook layer i.
    # Encode the 4 codebook layers of audio codes into embeddings, one layer at a time
    audio_embed_all_layers = []
    for i in range(self.num_vq):
        audio_codes_layer_i = []
        for codes in all_audio_codes:
            audio_codes_layer_i.append(
                codes[i, :].squeeze(0),
            )
        # Pad layer i of every sample to the longest sequence in the batch
        audio_codes_layer_i = pad_sequence(audio_codes_layer_i, batch_first=True)  # [batch_size, seq_len_max]
        # Embed layer i for the whole batch in one call (parallelized)
        audio_embed_layer_i = self.emb_code[i](audio_codes_layer_i)  # [batch_size, seq_len_max, gpt_hidden_dim]
        audio_embed_all_layers.append(audio_embed_layer_i)

    # Sum the embeddings of all codebook layers, following the official ChatTTS implementation:
    # https://github.com/2noise/ChatTTS/blob/51ec0c784c2795b257d7a6b64274e7a36186b731/ChatTTS/model/gpt.py#L451
    audio_embed_all_layers = torch.stack(audio_embed_all_layers, dim=0)  # [num_vq, batch_size, seq_len_max, gpt_hidden_dim]
    audio_embed_all_layers = torch.sum(audio_embed_all_layers, dim=0)  # [batch_size, seq_len_max, gpt_hidden_dim]
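The fragment above relies on module attributes (self.num_vq, self.emb_code). The following self-contained sketch of the same pad, embed, and sum-over-codebooks step runs on toy data. The function name `embed_audio_codes`, the codebook size, and the hidden size are hypothetical. The per-sample layout [num_vq, seq_len] is an assumption.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


def embed_audio_codes(all_audio_codes, emb_code):
    """Sum per-codebook embeddings for a batch of variable-length audio codes.

    all_audio_codes: list of LongTensors, each [num_vq, seq_len_i] (assumed layout)
    emb_code: nn.ModuleList with one nn.Embedding per codebook layer
    """
    layer_embeds = []
    for i in range(len(emb_code)):
        # Gather layer i from every sample and right-pad with 0 to the longest sequence
        codes_i = pad_sequence([codes[i] for codes in all_audio_codes], batch_first=True)  # [B, T_max]
        layer_embeds.append(emb_code[i](codes_i))  # [B, T_max, hidden]
    # Stack to [num_vq, B, T_max, hidden], then sum over the codebook axis
    return torch.stack(layer_embeds, dim=0).sum(dim=0)


torch.manual_seed(0)
num_vq, codebook_size, hidden = 4, 1024, 8  # hypothetical sizes
emb_code = nn.ModuleList([nn.Embedding(codebook_size, hidden) for _ in range(num_vq)])
batch = [torch.randint(0, codebook_size, (num_vq, 7)), torch.randint(0, codebook_size, (num_vq, 3))]
print(embed_audio_codes(batch, emb_code).shape)  # torch.Size([2, 7, 8])
```

Since `pad_sequence` pads with 0, the padded positions still receive the embedding of code 0. In this sketch, they need to be excluded from the loss through the labels.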

For the CE loss, you can refer to

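    import torch
    import torch.nn.functional as F

    # Predict audio codes from the TTS decoder's last hidden state, one head per codebook layer
    logits_all_vq_layers = []
    for num_vq_iter in range(self.num_vq):
        logits_i = self.head_code[num_vq_iter](tts_last_hidden_state)  # [batch_size, seq_len_max, audio_codebook_vocab]
        logits_all_vq_layers.append(logits_i)
    # torch.stack inserts a new leading dimension: [num_vq, batch_size, seq_len_max, audio_codebook_vocab]
    logits_all_vq_layers = torch.stack(logits_all_vq_layers, dim=0)
    logits_all_vq_layers = logits_all_vq_layers.permute(1, 2, 0, 3)  # [batch_size, seq_len_max, num_vq, audio_codebook_vocab]

    # Shift for next-token prediction: position t predicts the codes at position t + 1
    shift_logits = logits_all_vq_layers[:, :-1, :, :].contiguous()  # [batch_size, seq_len_max-1, num_vq, audio_codebook_vocab]
    shift_labels = labels[:, 1:, :].contiguous()  # [batch_size, seq_len_max-1, num_vq]

    # CE over every (position, codebook) pair; assumes non-audio positions in `labels`
    # (speaker embedding, text tokens, padding) are set to -100 so they are ignored
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=-100)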
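To tie the window ||speaker embedding| text tokens| audio tokens| audio eos|| to the shifted CE loss, here is a minimal sketch for one sample. The helper names `build_window` and `multi_codebook_ce` are hypothetical, as are the shared `eos_id` for all codebooks and the -100 ignore label. Random logits stand in for the TTS decoder output. Only positions whose next token is an audio code or the audio EOS contribute to the loss.

```python
import torch
import torch.nn.functional as F

IGNORE = -100  # label for positions excluded from the loss (assumed convention)


def build_window(speaker_emb, text_emb, audio_emb, eos_emb, audio_codes, eos_id):
    """Concatenate one sample as ||speaker embedding| text tokens| audio tokens| audio eos||.

    speaker_emb [1, H], text_emb [T_text, H], audio_emb [T_audio, H], eos_emb [1, H]
    audio_codes [T_audio, num_vq]; eos_id is a hypothetical EOS code shared by all codebooks.
    Returns inputs [L, H] and labels [L, num_vq], with L = 1 + T_text + T_audio + 1.
    """
    inputs = torch.cat([speaker_emb, text_emb, audio_emb, eos_emb], dim=0)
    num_vq = audio_codes.size(1)
    prefix = torch.full((1 + text_emb.size(0), num_vq), IGNORE, dtype=torch.long)
    eos = torch.full((1, num_vq), eos_id, dtype=torch.long)
    return inputs, torch.cat([prefix, audio_codes, eos], dim=0)


def multi_codebook_ce(logits, labels):
    """Shifted CE; logits [B, L, num_vq, V], labels [B, L, num_vq]."""
    shift_logits = logits[:, :-1].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=IGNORE)


torch.manual_seed(0)
H, num_vq, V, T_text, T_audio = 8, 4, 16, 3, 5
eos_id = V - 1  # hypothetical: reserve the last code id as audio EOS
inputs, labels = build_window(
    torch.randn(1, H), torch.randn(T_text, H), torch.randn(T_audio, H), torch.randn(1, H),
    torch.randint(0, V - 1, (T_audio, num_vq)), eos_id,
)
# Stand-in for the TTS decoder output: random logits of shape [B, L, num_vq, V]
logits = torch.randn(1, inputs.size(0), num_vq, V, requires_grad=True)
loss = multi_codebook_ce(logits, labels.unsqueeze(0))
loss.backward()
print(inputs.shape, labels.shape)  # torch.Size([10, 8]) torch.Size([10, 4])
```

In this layout, the last text position predicts the first audio frame and the last audio position predicts EOS. The speaker and earlier text positions get exactly zero gradient from the loss.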

bokesyo · Apr 27 '25 04:04