
Any examples of how to use the Spark model? Thanks for this amazing tool to use on Mac.

Open shengkaixuan opened this issue 9 months ago • 6 comments

shengkaixuan avatar May 13 '25 01:05 shengkaixuan

Voice cloning:

python -m mlx_audio.tts.generate --model Spark-TTS-0.5B-6bit --text " Get started today, P I P install M L X dash audio" --play --file_prefix spark  --sample_rate 16000 --pitch 1.0 --speed 1.5 --verbose --ref_text "Shoutout to Lucas Newman, Ivan Fioravanti, Andrei and Cheek Kim for their amazing contributions to MLX Audio." --ref_audio shoutout_000.mp3

TTS:

python -m mlx_audio.tts.generate --model Spark-TTS-0.5B-6bit --text " Get started today, P I P install M L X dash audio" --play --file_prefix spark  --sample_rate 16000 --pitch 1.0 --speed 1.5 --gender male --verbose 

Blaizzy avatar May 13 '25 20:05 Blaizzy

@Blaizzy
When I use the command " python -m mlx_audio.tts.generate --model mlx-community/Spark-TTS-0.5B-fp16 --text "嗯,今天是个特别的日子,天气嘛,大概是23度左右,挺舒服的。Well… it's sunny and bright, perfect for a walk, don’t you think? 顺便来学一句日语吧:「おはようございます」,也就是“早上好”的意思。然后是韩语:「안녕하세요」,就是“你好”。Now, let’s do a quick count — one, two, three,四,五,六,七,八!对啦,别忘了,今天是2025年5月12日,星期一。Let’s start the multilingual TTS test — ready? go!" ",

the generated audio is normal. However, when I use the code in server.py to generate audio, the synthesized audio speed changes. May I ask what the reason is?

hsoftxl avatar May 14 '25 02:05 hsoftxl

@Blaizzy
Here is my code:

from mlx_audio.tts.generate import generate_audio

from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
import logging
from mlx_audio.tts.utils import load_model
import uuid
import os
import sys
import numpy as np
import soundfile as sf
from fastrtc import ReplyOnPause, Stream, get_stt_model

# Configure logging
def setup_logging(verbose: bool = False):
    level = logging.DEBUG if verbose else logging.INFO
    format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    if verbose:
        format_str = "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s"

    logging.basicConfig(level=level, format=format_str)
    return logging.getLogger("mlx_audio_server")


logger = setup_logging()  # Will be updated with verbose setting in main()


# Load the model once on server startup.
# You can change the model path or pass arguments as needed.
# For performance, load once globally:
tts_model = None  # Will be loaded when the server starts
audio_player = None  # Will be initialized when the server starts
stt_model = get_stt_model()

OUTPUT_FOLDER = "outputs"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
logger.debug(f"Using output folder: {OUTPUT_FOLDER}")

def tts_endpoint(
    text: str ,
    voice: str = "af_heart",
    speed: float = 1.0,
    model: str = "mlx-community/Spark-TTS-0.5B-fp16",
):
    """
    POST an x-www-form-urlencoded form with 'text' (and optional 'voice', 'speed', and 'model').
    We run TTS on the text, save the audio in a unique file,
    and return JSON with the filename so the client can retrieve it.
    """
    global tts_model

    if not text.strip():
        return JSONResponse({"error": "Text is empty"}, status_code=400)

    # Validate speed parameter
    try:
        speed_float = float(speed)
        if speed_float < 0.5 or speed_float > 2.0:
            return JSONResponse(
                {"error": "Speed must be between 0.5 and 2.0"}, status_code=400
            )
    except ValueError:
        return JSONResponse({"error": "Invalid speed value"}, status_code=400)

    # Store current model repo_id for comparison
    current_model_repo_id = (
        getattr(tts_model, "repo_id", None) if tts_model is not None else None
    )

    # Load the model if it's not loaded or if a different model is requested
    if tts_model is None or current_model_repo_id != model:
        try:
            logger.debug(f"Loading TTS model from {model}")
            tts_model = load_model(model)
            logger.debug("TTS model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading TTS model: {str(e)}")
            return JSONResponse(
                {"error": f"Failed to load model: {str(e)}"}, status_code=500
            )

    # We'll do something like the code in model.generate() from the TTS library:
    # Generate the unique filename
    unique_id = str(uuid.uuid4())
    filename = f"tts_{unique_id}.wav"
    output_path = os.path.join(OUTPUT_FOLDER, filename)

    logger.debug(
        f"Generating TTS for text: '{text[:50]}...' with voice: {voice}, speed: {speed_float}, model: {model}"
    )
    logger.debug(f"Output file will be: {output_path}")

    # We'll use the high-level "model.generate" method:
    results = tts_model.generate(
        text=text,
        voice=voice,
        speed=speed_float,
        lang_code=voice[0],
        verbose=False,
    )

    # We'll just gather all segments (if any) into a single wav
    # It's typical for multi-segment text to produce multiple wave segments:
    audio_arrays = []
    for segment in results:
        audio_arrays.append(segment.audio)

    # If no segments, return error
    if not audio_arrays:
        logger.error("No audio segments generated")
        return JSONResponse({"error": "No audio generated"}, status_code=500)

    # Concatenate all segments
    cat_audio = np.concatenate(audio_arrays, axis=0)

    # Write the audio as a WAV
    try:
        sf.write(output_path, cat_audio, 24000)
        logger.debug(f"Successfully wrote audio file to {output_path}")

        # Verify the file exists
        if not os.path.exists(output_path):
            logger.error(f"File was not created at {output_path}")
            return JSONResponse(
                {"error": "Failed to create audio file"}, status_code=500
            )

        # Check file size
        file_size = os.path.getsize(output_path)
        logger.debug(f"File size: {file_size} bytes")

        if file_size == 0:
            logger.error("File was created but is empty")
            return JSONResponse(
                {"error": "Generated audio file is empty"}, status_code=500
            )

    except Exception as e:
        logger.error(f"Error writing audio file: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to save audio: {str(e)}"}, status_code=500
        )

    return {"filename": output_path}


text='''
嗯,今天是个特别的日子,天气嘛,大概是 23 度左右,挺舒服的。
Well… it's sunny and bright, perfect for a walk, don’t you think? 😄
顺便来学一句日语吧:「おはようございます」,也就是“早上好”的意思。
然后是韩语:「안녕하세요」,就是“你好”。
Now, let’s do a quick count — one, two, three, 四,五,六,七,八!
对啦,别忘了,今天是 2025 年 5 月 12 日,星期一。
Let’s start the multilingual TTS test — ready? go!
'''


# Example: Generate an audiobook chapter as mp3 audio
# generate_audio(
#     text=text,
#     # model_path="mlx-community/Spark-TTS-0.5B-fp16",
#     model_path="mlx-community/Llama-OuteTTS-1.0-1B-bf16",
#     stream=True,
#     join_audio=True
# )

result = tts_endpoint(text)
print(result)
print("Audiobook chapter successfully generated!")

hsoftxl avatar May 14 '25 02:05 hsoftxl

Came here because I noticed that when passing join_audio=True into mlx_audio.tts.generate.generate_audio() the pitch changes. (It's hard to gauge if the speed changes as the length of the output audio is different each time. Also, I didn't investigate that hard.)

And it looks like the same operation for joining audio that happens in generate_audio() when join_audio=True is done in server.py.

https://github.com/Blaizzy/mlx-audio/blob/2460d15586a1e945dcc20776aefe32c4a4e6ed62/mlx_audio/server.py#L179

    # We'll just gather all segments (if any) into a single wav
    # It's typical for multi-segment text to produce multiple wave segments:
    audio_arrays = []
    for segment in results:
        audio_arrays.append(segment.audio)

Maybe joining the audio segments could be optional as it is when calling generate_audio()? I don't use the server so I don't have an opinion on it.

-- Edit: Just realized that a bool in the server for joining segments wouldn't actually solve whatever "the problem" is that causes the pitch/speed change. It does seem to only happen with spark.

@hsoftxl is the generated audio faster/higher pitch if you run the command with the --join_audio flag? I'm willing to bet that it will be and if you call generate_audio() without passing in join_audio=True the generated audio file will be the appropriate speed/pitch.
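The faster, higher-pitched output mkell43 describes is exactly what a sample-rate mismatch sounds like. Per this thread, Spark-TTS emits 16 kHz samples, while the server's `sf.write()` call labels them as 24 kHz. A quick back-of-the-envelope sketch (rates taken from the thread, not from the library's source) shows the expected distortion:

```python
import math

model_rate = 16000   # Spark-TTS's native output rate (per this thread)
header_rate = 24000  # rate hard-coded in the server's sf.write() call

# Playing 16 kHz samples as if they were 24 kHz compresses time by 1.5x
# and shifts pitch up by 12 * log2(1.5) semitones.
factor = header_rate / model_rate
semitones = 12 * math.log2(factor)

print(factor)               # 1.5
print(round(semitones, 2))  # 7.02
```

A 1.5× speedup and a roughly seven-semitone pitch shift would be very audible, which matches the reports above.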

mkell43 avatar May 16 '25 18:05 mkell43

@mkell43

Change:

    try:
        sf.write(output_path, cat_audio, 24000)
        logger.debug(f"Successfully wrote audio file to {output_path}")

to:

    try:
        sf.write(output_path, cat_audio, 16000)
        logger.debug(f"Successfully wrote audio file to {output_path}")

This fixes it.
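To see why the header rate matters, here is a stdlib-only demonstration (no mlx-audio needed): the same one-second buffer of 16 kHz samples reports a 1.0 s duration when the WAV header says 16000, but only about 0.67 s when it says 24000, i.e. the faster, higher-pitched playback discussed above.

```python
import io
import math
import struct
import wave

def write_wav(buf, samples, rate):
    """Write 16-bit mono PCM samples with the given header sample rate."""
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(rate)
        w.writeframes(b"".join(struct.pack("<h", s) for s in samples))

# One second of a 440 Hz tone, actually sampled at 16 kHz.
true_rate = 16000
samples = [int(30000 * math.sin(2 * math.pi * 440 * n / true_rate))
           for n in range(true_rate)]

for header_rate in (16000, 24000):
    buf = io.BytesIO()
    write_wav(buf, samples, header_rate)
    buf.seek(0)
    with wave.open(buf, "rb") as w:
        duration = w.getnframes() / w.getframerate()
    print(header_rate, round(duration, 3))  # 16000 -> 1.0, 24000 -> 0.667
```

Writing 16000 in the header (as in the fix above) makes playback match the model's actual rate; the alternative would be resampling the audio to 24 kHz before writing.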

hsoftxl avatar May 26 '25 07:05 hsoftxl

> @Blaizzy When I use the command `python -m mlx_audio.tts.generate --model mlx-community/Spark-TTS-0.5B-fp16 --text "嗯,今天是个特别的日子,天气嘛,大概是23度左右,挺舒服的。Well… it's sunny and bright, perfect for a walk, don’t you think? 顺便来学一句日语吧:「おはようございます」,也就是“早上好”的意思。然后是韩语:「안녕하세요」,就是“你好”。Now, let’s do a quick count — one, two, three,四,五,六,七,八!对啦,别忘了,今天是2025年5月12日,星期一。Let’s start the multilingual TTS test — ready? go!"`, the generated audio is normal. However, when I use the code in server.py to generate audio, the synthesized audio speed changes. May I ask what the reason is?

This was fixed in #153. The initial server was built around Kokoro, but I redesigned it to support all the new models and tasks (STT, TTS, and STS).

Blaizzy avatar May 26 '25 09:05 Blaizzy