mlx-audio

Feature Request: Voice Cloning for Orpheus

Open dwohlfahrt opened this issue 10 months ago • 21 comments

As the title says, it would be awesome to have built-in support for Orpheus's voice cloning capability. Thanks!!

dwohlfahrt, Mar 31 '25 19:03

I looked around and I don't see any indication this is doable with their current model -- the GitHub page claims it, but all the examples use the fine-tuned named voices, which effectively just add a "[voice]:" prefix to the beginning of the text prompt.
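For comparison, a minimal sketch of the named-voice convention described above, assuming the fine-tuned checkpoint expects the voice name, a colon, and a space in front of the text. `named_voice_prompt` is a hypothetical helper and `tara` is just an example voice name:

```python
def named_voice_prompt(voice: str, text: str) -> str:
    """Build a text prompt for a fine-tuned Orpheus voice (hypothetical helper).

    The voice is selected purely through this text prefix; no reference
    audio is involved, which is why it is not the same thing as cloning.
    """
    return f"{voice}: {text}"


print(named_voice_prompt("tara", "Hello there!"))  # tara: Hello there!
```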

Maybe there's some other special token that you can use to prefix an audio prompt? Has anyone seen a working example?

lucasnewman, Apr 01 '25 21:04

Ah, ok, I found it in the Colab for the pretrained-only model here.

Looks like we can just use an audio prefix prompt -- I'll see if I can add it this weekend.
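A rough sketch of what an "audio prefix prompt" looks like at the token level, using the special-token IDs that appear in the voice cloning script later in this thread. The constant names are my reading of what each ID marks (not confirmed by the Orpheus authors), and `build_clone_prompt` is a hypothetical helper that works on already-tokenized lists of ints:

```python
# Special-token IDs as used in the voice cloning script in this thread.
# The names are an interpretation of what each ID marks.
START_OF_SPEECH = 128257
END_OF_SPEECH = 128258
START_OF_HUMAN = 128259
END_OF_HUMAN = 128260
START_OF_AI = 128261
END_OF_AI = 128262
END_OF_TEXT = 128009  # Llama 3 end-of-turn token


def build_clone_prompt(ref_text_ids, ref_audio_codes, target_text_ids):
    """Reference turn (transcript + its audio codes), then the new text turn.

    The model is expected to continue after the final START_OF_SPEECH with
    audio codes for target_text_ids in the reference speaker's voice.
    """
    reference_turn = (
        [START_OF_HUMAN] + ref_text_ids + [END_OF_TEXT, END_OF_HUMAN]
        + [START_OF_AI, START_OF_SPEECH] + ref_audio_codes
        + [END_OF_SPEECH, END_OF_AI]
    )
    target_turn = (
        [START_OF_HUMAN] + target_text_ids
        + [END_OF_TEXT, END_OF_HUMAN, START_OF_AI, START_OF_SPEECH]
    )
    return reference_turn + target_turn
```

This mirrors the order in which the script concatenates `start_tokens`, the transcript IDs, `end_tokens`, the reference audio codes, `final_tokens`, and then the new prompt.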

lucasnewman, Apr 01 '25 21:04


@lucasnewman that notebook seems to be broken according to their repo.

Give it a try, but if it doesn't work we'll need to wait until they fix it.

Blaizzy, Apr 01 '25 22:04

I asked them to fix it on X a couple of weeks ago, when they were gathering feedback on the launch.

Blaizzy, Apr 01 '25 22:04

Here's a voice cloning script I put together based on that Colab. It works locally on my M4... you know, if you have the patience to wait :)

import argparse
import librosa
import soundfile as sf
import torch

from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer


model_name = "canopylabs/orpheus-tts-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")


# Download only model config and safetensors
model_path = snapshot_download(
    repo_id=model_name,
    allow_patterns=[
        "config.json",
        "*.safetensors",
        "model.safetensors.index.json",
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.*"
    ]
)

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)  # load from the snapshot downloaded above
model.cpu()

def redistribute_codes(code_list):
    layer_1 = []
    layer_2 = []
    layer_3 = []
    
    for i in range((len(code_list)+1)//7):
        layer_1.append(code_list[7*i])
        layer_2.append(code_list[7*i+1]-4096)
        layer_3.append(code_list[7*i+2]-(2*4096))
        layer_3.append(code_list[7*i+3]-(3*4096))
        layer_2.append(code_list[7*i+4]-(4*4096))
        layer_3.append(code_list[7*i+5]-(5*4096))
        layer_3.append(code_list[7*i+6]-(6*4096))
    
    codes = [torch.tensor(layer_1).unsqueeze(0),
            torch.tensor(layer_2).unsqueeze(0),
            torch.tensor(layer_3).unsqueeze(0)]
    
    audio_hat = snac_model.decode(codes)
    return audio_hat


def tokenise_audio(waveform):
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)
    waveform = waveform.unsqueeze(0)

    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item()+128266)
        all_codes.append(codes[1][0][2*i].item()+128266+4096)
        all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
        all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
        all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
        all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
        all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))

    return all_codes


def generate_cloned_speech(prompts: str | list[str], voice_clone_sample_path: str, voice_clone_sample_transcript: str):
    start_tokens = torch.tensor([[ 128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
    final_tokens = torch.tensor([[128258, 128262]], dtype=torch.int64)

    audio_array, sample_rate = librosa.load(voice_clone_sample_path, sr=24000)
    myts = tokenise_audio(audio_array)
    
    prompt_tokked = tokenizer(voice_clone_sample_transcript, return_tensors="pt")
    input_ids = prompt_tokked["input_ids"]

    zeroprompt_input_ids = torch.cat([start_tokens, input_ids, end_tokens, torch.tensor([myts]), final_tokens], dim=1) # SOH SOT Text EOT EOH

    all_modified_input_ids = []

    if isinstance(prompts, str):
        prompts = [prompts]
        
    for prompt in prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        second_input_ids = torch.cat([zeroprompt_input_ids, start_tokens, input_ids, end_tokens], dim=1)
        all_modified_input_ids.append(second_input_ids)
    
    all_padded_tensors = []
    all_attention_masks = []

    max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])

    for modified_input_ids in all_modified_input_ids:
        padding = max_length - modified_input_ids.shape[1]
        padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
        attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64), torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
        all_padded_tensors.append(padded_tensor)
        all_attention_masks.append(attention_mask)
    
    all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
    all_attention_masks = torch.cat(all_attention_masks, dim=0)

    input_ids = all_padded_tensors.to("cpu")
    attention_mask = all_attention_masks.to("cpu")
    
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # needed so left padding is ignored when batching prompts
            max_new_tokens=990,
            # max_new_tokens=1500,
            do_sample=True,
            temperature=0.5,
            # top_k=40,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,
            # end_token_id=128009
        )
    
    token_to_find = 128257
    token_to_remove = 128258

    # Check if the token exists in the tensor
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
    else:
        cropped_tensor = generated_ids
    
    processed_rows = []
    for row in cropped_tensor:
        # Drop every end-of-speech token from this row
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)
    
    code_lists = []
    for row in processed_rows:
        # row is a 1D tensor with its own length
        row_length = row.size(0)
        new_length = (row_length // 7) * 7  # largest multiple of 7 that fits in this row
        trimmed_row = row[:new_length]
        trimmed_row = [t - 128266 for t in trimmed_row]
        code_lists.append(trimmed_row)
    
    my_samples = []
    for code_list in code_lists:
        samples = redistribute_codes(code_list)
        my_samples.append(samples)
    
    for i, sample in enumerate(my_samples):
        try:
            output_file = f"generated_audio_{i}.wav"
            sample_rate = 24000  # Same sample rate used in the model
            sf.write(output_file, sample.detach().cpu().squeeze().numpy(), sample_rate)
        except Exception as e:
            # Report the failure but keep saving the remaining samples
            print(f"Error saving audio at index {i}: {e}")
    
    return my_samples


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Orpheus TTS inference with voice cloning")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt to synthesize")
    parser.add_argument("--clone_audio_path", type=str, required=True, help="Path to audio file for voice cloning")
    parser.add_argument("--clone_transcript", type=str, required=True, help="Transcript of the voice clone sample")
    
    args = parser.parse_args()
    
    samples = generate_cloned_speech(args.prompt, args.clone_audio_path, args.clone_transcript)
    
    print(f"Generated {len(samples)} audio samples")
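The post-processing at the end of `generate_cloned_speech` (keep only what follows the last start-of-speech token, drop end-of-speech tokens, trim to whole 7-token frames, subtract the audio-token offset) can be sketched in plain Python on a single list of token IDs. `extract_audio_codes` is a hypothetical name; the constants are copied from the script:

```python
START_OF_SPEECH = 128257     # marks where audio codes begin
END_OF_SPEECH = 128258       # stripped from the output
AUDIO_TOKEN_OFFSET = 128266  # first token ID used for SNAC codes
FRAME = 7                    # tokens per frame (1 coarse + 2 medium + 4 fine)


def extract_audio_codes(token_ids):
    """Turn one generated sequence into offset-free audio codes (sketch)."""
    # The prompt itself contains START_OF_SPEECH markers, so only the
    # tokens after the *last* one belong to the newly generated audio.
    if START_OF_SPEECH in token_ids:
        last = len(token_ids) - 1 - token_ids[::-1].index(START_OF_SPEECH)
        token_ids = token_ids[last + 1:]
    # Remove end-of-speech markers.
    codes = [t for t in token_ids if t != END_OF_SPEECH]
    # Keep a whole number of frames, then undo the vocabulary offset.
    codes = codes[: len(codes) // FRAME * FRAME]
    return [t - AUDIO_TOKEN_OFFSET for t in codes]
```

The result is what the script passes to `redistribute_codes`, which then removes the per-position `k * 4096` offsets.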

dwohlfahrt, Apr 02 '25 03:04

@Blaizzy Can you upload the base pretrained model to the mlx-community hub so we can try it?

lucasnewman, Apr 02 '25 15:04

Sure, will do ASAP.

Just finishing the ModernBERT upload.

Blaizzy, Apr 02 '25 16:04

Done :)

Blaizzy, Apr 02 '25 16:04

I put up a draft PR in https://github.com/Blaizzy/mlx-audio/pull/75, but it doesn't seem to work that well yet. When I find some time I'll compare against the reference @dwohlfahrt posted -- it's probably some kind of tokenization issue.

lucasnewman, Apr 02 '25 17:04

Sure, let me know when it's ready to review or if you need my help :)

Blaizzy, Apr 02 '25 18:04

Ok, I ran the implementation above in a Colab with CUDA, and it uses exactly the same tokenization scheme (good) and produces the same weird results, which no reasonable person would call voice matching (bad).

I'm not quite sure what to make of this -- I wonder if the authors declared this feature working prematurely and it's not really effective in practice. I think at this point it probably makes sense for the authors to prove that it actually works with the model as they claim, since as far as I can tell 1) their sample code doesn't work, and 2) all community reproductions of the feature have failed.

If anyone has a working example, please let us know!
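For reference, the "tokenization scheme" being compared is the 7-token-per-frame interleave in `tokenise_audio` / `redistribute_codes` from the script above: the SNAC 24 kHz codec produces three code layers in a 1:2:4 ratio, and each position within a frame gets its own 4096-wide slice of the vocabulary. A pure-Python round trip, with hypothetical function names and plain lists instead of tensors:

```python
CODEBOOK = 4096
OFFSET = 128266  # first audio token ID, as in the script
LAYER_STRIDE = (1, 2, 4)  # codes per frame in layers 0, 1, 2

# (layer, index-within-frame) for each of the 7 positions in a frame,
# in the same order tokenise_audio appends them.
FRAME_LAYOUT = [(0, 0), (1, 0), (2, 0), (2, 1), (1, 1), (2, 2), (2, 3)]


def interleave(l1, l2, l3):
    """Flatten three SNAC layers into Orpheus-style audio token IDs."""
    layers = (l1, l2, l3)
    tokens = []
    for i in range(len(l1)):  # one frame per coarse (layer 0) code
        for pos, (layer, j) in enumerate(FRAME_LAYOUT):
            code = layers[layer][i * LAYER_STRIDE[layer] + j]
            tokens.append(OFFSET + pos * CODEBOOK + code)
    return tokens


def deinterleave(tokens):
    """Inverse of interleave: recover the three layers in order."""
    layers = ([], [], [])
    for i in range(len(tokens) // 7):
        for pos, (layer, _) in enumerate(FRAME_LAYOUT):
            layers[layer].append(tokens[7 * i + pos] - OFFSET - pos * CODEBOOK)
    return layers
```

If two implementations agree on `FRAME_LAYOUT` and the offsets, differences in output quality have to come from somewhere other than tokenization.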

lucasnewman, Apr 02 '25 21:04

@lucasnewman I haven’t run all that many prompts through my script above, but I can confirm that voice cloning worked extremely well on some and not at all on others. The difference maker appears to be prompt length. Again, very limited testing, but all of my longer prompts (10-15 seconds) cloned great, while the shorter ones showed no similarity to the clone sample whatsoever. Take it fwiw.
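To put "10-15 seconds" in token terms: assuming the SNAC 24 kHz codec uses a 512-sample hop with its coarsest layer at a quarter of that rate (consistent with the 1:2:4 layer ratio in the script), each 7-token frame covers 2048 samples, roughly 85 ms. A back-of-the-envelope sketch, with hypothetical helper names:

```python
import math

SAMPLE_RATE = 24_000
SAMPLES_PER_FRAME = 512 * 4  # assumed SNAC hop (512) x coarsest-layer stride (4)
TOKENS_PER_FRAME = 7


def tokens_to_seconds(n_tokens: int) -> float:
    """Audio duration represented by n_tokens generated audio tokens."""
    frames = n_tokens // TOKENS_PER_FRAME
    return frames * SAMPLES_PER_FRAME / SAMPLE_RATE


def seconds_to_tokens(seconds: float) -> int:
    """Tokens needed for `seconds` of audio, rounded up to whole frames."""
    frames = math.ceil(seconds * SAMPLE_RATE / SAMPLES_PER_FRAME)
    return frames * TOKENS_PER_FRAME


print(tokens_to_seconds(990))   # 12.032
print(seconds_to_tokens(15.0))  # 1232
```

Under that assumption, the script's `max_new_tokens=990` caps a single generation at about 12 seconds of audio, so a target utterance much longer than that would be cut off.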

dwohlfahrt, Apr 03 '25 00:04


Thanks for looking into it @lucasnewman!

As I suspected and mentioned earlier in this thread, it's indeed broken. All we can do now is wait and watch their repo for a voice cloning fix.

Blaizzy, Apr 03 '25 15:04

@dwohlfahrt could you try the same prompts with the closed PR from @lucasnewman and see if the results match the PyTorch version?

If so, I think we can reopen it and merge it with a warning shown whenever users enable voice cloning.

Blaizzy, Apr 03 '25 15:04

Ok, I reopened https://github.com/Blaizzy/mlx-audio/pull/75 — you can merge it if people want to try it. It doesn't really hurt anything since it's only active if ref_audio/ref_text is passed.

lucasnewman, Apr 03 '25 15:04

Could you add a warning ⚠️ to let users know about potential issues?
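One way the requested warning could look, as a sketch only (this is not the code from the mlx-audio PR; `prepare_reference` and its arguments are hypothetical): cloning stays off unless both a reference clip and its transcript are given, and a `UserWarning` is emitted when it is switched on.

```python
import warnings
from typing import Optional


def prepare_reference(ref_audio: Optional[str], ref_text: Optional[str]) -> bool:
    """Return True if voice cloning should be enabled (hypothetical helper)."""
    if ref_audio is None and ref_text is None:
        return False  # default path: behaviour unchanged
    if ref_audio is None or ref_text is None:
        raise ValueError("ref_audio and ref_text must be provided together")
    warnings.warn(
        "Orpheus voice cloning is experimental: output may not match the "
        "reference voice, especially for short generations.",
        UserWarning,
        stacklevel=2,
    )
    return True
```

Gating on both arguments keeps the feature inert for existing users, which matches the point above that it is only active when `ref_audio`/`ref_text` are passed.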

Blaizzy, Apr 03 '25 15:04

Awesome, thanks guys! I'm tied up for the next few hours, but will pull this down and do some comparison testing vs torch this evening.

dwohlfahrt, Apr 03 '25 16:04

Voice cloning still busted?

sihayas, Apr 10 '25 20:04

Yes @sihayas

It seems to still be the case

https://github.com/canopyai/Orpheus-TTS/issues/134#issuecomment-2792282824

Blaizzy, Apr 10 '25 21:04


There's no release for it yet, so you need to install it from source:

pip install git+https://github.com/Blaizzy/mlx-audio.git@main

Then you can use the --ref_audio and --ref_text parameters to provide a voice sample to match. Note that it doesn't work for short generations due to what appears to be a limitation in the underlying model.

You may have better luck with Sesame CSM -- it's a better model imho and handles voice matching pretty well.

lucasnewman, Apr 10 '25 21:04

New release coming tomorrow :)

@Charmaineem please add the information in this issue to the upcoming docs you are working on 👌🏽

Blaizzy, Apr 10 '25 21:04