
🚀 Proposal: Optimize Data Preprocessing for >4x Faster TTS Training

Open havok2-htwo opened this issue 5 months ago · 33 comments

Problem: The current finetune_t3.py script's on-the-fly data processing (audio loading, resampling, tokenization, voice-encoder (VE) embedding) in the Dataset's __getitem__ creates a major CPU bottleneck. For a ~5000-sample dataset (1.4 GB of raw audio, generated with ElevenLabs, which conveniently provided corresponding .txt files for the speech content), this limited iteration time to ~2.2-2.9 s/it even with an optimized dataloader_num_workers (down from an initial ~5 s/it).

Proposed Solution: Offline Preprocessing. I've implemented a two-stage approach:

preprocess_data.py (new script): a standalone, parallelized (multi-CPU-core) script that processes the entire raw dataset once. It handles all expensive CPU tasks (audio ops, tokenization, VE) using the .wav and associated .txt files, and saves the resulting tensors for each sample into individual .pt files.

finetune_t3_preprocessed.py (modified training script): the Dataset now simply loads these precomputed .pt files, drastically reducing CPU work during training. (Code for both scripts will be embedded below for review.)

Key Benefits & Observed Results:

Training iteration time: reduced to ~1.2 seconds/iteration. This is a >4x speedup over the initial baseline and ~2x faster than optimized on-the-fly processing for this dataset.

Reduced CPU bottleneck: training CPU load is minimal; dataloader_num_workers can be lowered (e.g., 2-8).

Maximized GPU utilization: VRAM becomes the primary constraint for batch size.

Efficient workflow: preprocessing is a one-time cost, benefiting all subsequent training runs.

Data size: preprocessed data for the 5000 samples was <100 MB (from 1.4 GB of raw audio).

This optimization significantly speeds up fine-tuning, especially when working with datasets structured as audio files with corresponding text files. I'm happy to provide the scripts and discuss further.

For the code, see the posts below.
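For a sense of what the training-side change boils down to, here is a minimal sketch of a Dataset that only loads the precomputed .pt files (the class name and directory layout are illustrative, not the exact code from finetune_t3_preprocessed.py posted below):

from pathlib import Path
import torch
from torch.utils.data import Dataset

class PreprocessedTTSDataset(Dataset):
    """Loads the per-sample .pt files written by the preprocessing script."""

    def __init__(self, preprocessed_dir: str):
        self.files = sorted(Path(preprocessed_dir).glob("*.pt"))
        if not self.files:
            raise FileNotFoundError(f"No .pt files found in {preprocessed_dir}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Each file already holds the padded token tensors, speaker embedding and
        # prompt tokens, so no audio loading or tokenization happens during training.
        return torch.load(self.files[idx], map_location="cpu")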

havok2-htwo avatar Jun 20 '25 10:06 havok2-htwo

How would we run inference (using the Gradio app)? I mean, how do we point it to the new model we trained?

rikabi89 avatar Jun 23 '25 20:06 rikabi89

I asked Gemini:

Okay, here's a shorter way to explain how to use the fine-tuned model for inference, for example, in a Gradio app:

"To use your fine-tuned model for inference (e.g., with Gradio):

Locate your fine-tuned model directory. This is the --output_dir from your training (e.g., ../checkpoints/chatterbox_10k_DE_.../). This directory should contain: t3_cfg.safetensors (your T3 weights), ve.safetensors, s3gen.safetensors, and tokenizer.json.

In your Gradio script (or any inference script), load the model using ChatterboxTTS.from_local():

from chatterbox.tts import ChatterboxTTS
import torch

MODEL_DIR = "path/to/your/finetuned_model_dir"  # <-- CHANGE THIS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tts_model = ChatterboxTTS.from_local(ckpt_dir=MODEL_DIR, device=DEVICE)

# Now use tts_model.generate(...) in your Gradio functions

No need to edit tts.py from the library. The from_local method is designed to load models saved this way. Just make sure the directory you point to has all the necessary files. If you want to use a specific intermediate checkpoint (like checkpoint-8000), you'd first need to manually prepare a directory that contains all these components (extract T3 weights from the Trainer's save and copy the VE, S3Gen, and tokenizer files into it), and then point MODEL_DIR to that prepared directory."
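If you do want to use an intermediate Trainer checkpoint, a rough sketch of "preparing a directory" could look like the following. The checkpoint filename (model.safetensors) and the "t3." key prefix are assumptions that depend on how the Trainer saved the model, so check your checkpoint before relying on them:

import shutil
from pathlib import Path
from safetensors.torch import load_file, save_file

CKPT_DIR = Path("checkpoints/run/checkpoint-8000")  # Trainer checkpoint dir (assumed layout)
BASE_DIR = Path("path/to/base_chatterbox_model")    # directory with the ve/s3gen/tokenizer files
OUT_DIR = Path("finetuned_model_dir")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Assumption: the Trainer wrote a model.safetensors and the T3 weights carry a "t3." prefix.
state = load_file(str(CKPT_DIR / "model.safetensors"))
t3_state = {k[len("t3."):]: v for k, v in state.items() if k.startswith("t3.")}
save_file(t3_state, str(OUT_DIR / "t3_cfg.safetensors"))

# Reuse the unchanged components from the base model.
for fname in ["ve.safetensors", "s3gen.safetensors", "tokenizer.json"]:
    shutil.copy(BASE_DIR / fname, OUT_DIR / fname)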

I don't use Gradio in my project, but I hope it helps. I managed to get a much better German model by using 10,000 German samples created with ChatGPT and ElevenLabs. I just ran the fine-tuning for around 12 hours on my RTX 5090.

havok2-htwo avatar Jun 24 '25 00:06 havok2-htwo

@havok2-htwo thanks for providing the script! Are you able to share your trained weights for German Chatterbox? I would like to try it out! Did you observe the loss or anything while training?

rotatorotator avatar Jun 25 '25 09:06 rotatorotator

@havok2-htwo thanks for providing the script! Are you able to share your trained weights for German Chatterbox? I would like to try it out! Did you observe the loss or anything while training?

Hi! I've updated the scripts and will upload them here in the next few days (currently at work and not at home). I've optimized several parts and ran multiple training sessions with my 5090 over the past few days. The German output is getting pretty good – not perfect, but much more understandable now.

I ran into a few issues during training. Some of the training data was too short, which caused strange results. My current training data script now only creates samples if the input is long enough. I also cleaned most of the samples by hand. I have generated 12,000 samples with ElevenLabs.

I also added a lot of examples with tricky pronunciations like “Ziel,” “Generation,” and similar words to help reduce mispronunciations.

Right now, I'm still having trouble with the start and end of the generated audio: my fine-tuned model often skips the first word and doesn't stop speaking when the text ends. I asked Gemini about it, and it suggested that this might be due to missing silence at the beginning and end of the audio files. So I’m now adding 300ms of silence to both ends of each file.
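For reference, the silence padding itself is just a concatenation of zeros at the working sample rate; a minimal sketch (the 300 ms default mirrors what I describe above, everything else is illustrative):

import numpy as np

def pad_silence(wav: np.ndarray, sr: int, ms: int = 300) -> np.ndarray:
    # Prepend and append `ms` milliseconds of silence to a mono signal.
    pad = np.zeros(int(sr * ms / 1000.0), dtype=wav.dtype)
    return np.concatenate([pad, wav, pad])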

This is actually my first time fine-tuning a model – aside from playing with deepfakes back in the day. 😊

By the way – any suggestions on where best to upload the trained weights? Would GitHub be suitable, or should I use something like Hugging Face or Google Drive? I will try to retrain my model with the added silence, and then will share it. Give me some days.

havok2-htwo avatar Jun 25 '25 11:06 havok2-htwo

@rotatorotator I have updated the code above completely.

havok2-htwo avatar Jun 25 '25 12:06 havok2-htwo

@havok2-htwo thanks for providing the script! Are you able to share your trained weights for German Chatterbox? I would like to try it out! Did you observe the loss or anything while training?

Hi! I've updated the scripts and will upload them here in the next few days (currently at work and not at home). I've optimized several parts and ran multiple training sessions with my 5090 over the past few days. The German output is getting pretty good – not perfect, but much more understandable now.

I ran into a few issues during training. Some of the training data was too short, which caused strange results. My current training data script now only creates samples if the input is long enough. I also cleaned most of the samples by hand. I have generated 12,000 samples with ElevenLabs.

I also added a lot of examples with tricky pronunciations like “Ziel,” “Generation,” and similar words to help reduce mispronunciations.

Right now, I'm still having trouble with the start and end of the generated audio: my fine-tuned model often skips the first word and doesn't stop speaking when the text ends. I asked Gemini about it, and it suggested that this might be due to missing silence at the beginning and end of the audio files. So I’m now adding 300ms of silence to both ends of each file.

This is actually my first time fine-tuning a model – aside from playing with deepfakes back in the day. 😊

By the way – any suggestions on where best to upload the trained weights? Would GitHub be suitable, or should I use something like Hugging Face or Google Drive? I will try to retrain my model with the added silence, and then will share it. Give me some days.

Thank you, I will try these out. Your previous code seemed to work well for English. I was wondering if you know how we might approach training a non-Latin-based language such as Arabic. Is there a cut-and-dried way to do this without too much wrangling? P.S. Hugging Face is good for models.

rikabi89 avatar Jun 25 '25 12:06 rikabi89

@rotatorotator I have updated the code above completely.

Thanks, but I did not mean the code but the weights you have trained, e.g. the T3 file. As mentioned, Hugging Face is a good option for sharing these. FYI, there is someone else training a German Chatterbox: https://huggingface.co/SebastianBodza/Kartoffelbox-v0.1 If you need GPU resources for training, I would be happy to contribute to that.

rotatorotator avatar Jun 25 '25 13:06 rotatorotator

Great @havok2-htwo @rotatorotator, can I use it for Arabic? Can you guide me on the data and steps?

cod3r0k avatar Jun 25 '25 21:06 cod3r0k

Great @havok2-htwo @rotatorotator, can I use it for Arabic? Can you guide me on the data and steps?

Sure, I believe it should work if you have enough speech samples. I'm currently redesigning a lot of my code based on Gemini’s suggestions. I used the training code as a base from this repository, but I found a few issues in the tts.py file and also while reading through the cond_enc.py. Gemini provided me with the correct values for this model and the proper tokens for the start and stop points. I hope this will help me fix the issues and run a proper fine-tuning session to fully leverage the potential of this base model.

I also realized that I trained for 128 epochs, but that was way too much. I actually achieved the best results at epoch 5, after which the training suffered from overfitting. So I will address that as well. I've generated a sample where you can hear the result, but there are some issues at the end that still need improvement.

(https://jmp.sh/s/71S206JyeoDuZa0YK2ve)

havok2-htwo avatar Jun 25 '25 23:06 havok2-htwo

Does that mean the fine-tuning time (12 h on your 5090) covered the full 128 epochs, and that with training for only about 5 epochs we could be done in just a few minutes? With a bigger dataset we would probably need to train for more epochs.

rotatorotator avatar Jun 26 '25 13:06 rotatorotator

Does that mean the fine-tuning time (12 h on your 5090) covered the full 128 epochs, and that with training for only about 5 epochs we could be done in just a few minutes? With a bigger dataset we would probably need to train for more epochs.

Yes, the bottleneck was definitely the CPU, and the fact that each audio file had to be processed into tensors before the GPU could even start working. I just preprocessed all the WAV files once (basically extracted the features into .pt files) and then fed everything to the GPU. Now it's running steadily at 600 W under full load.

But I noticed that during training, the sequence length was limited to just 300 tokens, even though the original model was trained with 2048 tokens. When I try to match that original token length, my 32GB of VRAM isn’t sufficient — especially if I enable evaluation. Even with a batch size of 1, I’m already hitting 28GB of VRAM.

Still, I think it's important not to go below the original sequence length; that might actually be the reason why the model keeps rambling at the end of sentences.


But with a batch size of 1 it's slower than with 5.

havok2-htwo avatar Jun 26 '25 15:06 havok2-htwo

How much does the batch size influence VRAM usage? I am thinking about just renting an H100 with 80 GB of VRAM or something like that for training, if we have a dataset that is good enough.

JMLLR1 avatar Jun 26 '25 16:06 JMLLR1

Great @havok2-htwo @rotatorotator, can I use it for Arabic? Can you guide me on the data and steps?

Sure, I believe it should work if you have enough speech samples. I'm currently redesigning a lot of my code based on Gemini’s suggestions. I used the training code as a base from this repository, but I found a few issues in the tts.py file and also while reading through the cond_enc.py. Gemini provided me with the correct values for this model and the proper tokens for the start and stop points. I hope this will help me fix the issues and run a proper fine-tuning session to fully leverage the potential of this base model.

I also realized that I trained for 128 epochs, but that was way too much. I actually achieved the best results at epoch 5, after which the training suffered from overfitting. So I will address that as well. I've generated a sample where you can hear the result, but there are some issues at the end that still need improvement.

(https://jmp.sh/s/71S206JyeoDuZa0YK2ve)

Thanks. Could you share a step-by-step walkthrough of your fine-tuning code?

cod3r0k avatar Jun 26 '25 17:06 cod3r0k

Here comes my new code. preprocess_data.py: this creates the precomputed tensors for the GPU, with some augmentation variants.

# preprocess_data.py
import argparse
import logging
from pathlib import Path
import os
from typing import Any, Dict, List, Tuple, Optional, Union
import functools
import gc  # for garbage collection

import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

# Try to import audiotsm
try:
    from audiotsm import wsola
    from audiotsm.io.array import ArrayReader, ArrayWriter
    AUDIOTSM_AVAILABLE = True
    logger_init = logging.getLogger(__name__)  # temporary logger for import status
    logger_init.info("audiotsm found; it will be used for pitch shifting (combined with resampling).")
except ImportError:
    AUDIOTSM_AVAILABLE = False
    logger_init = logging.getLogger(__name__)
    logger_init.warning("audiotsm not found. Falling back to librosa.effects.pitch_shift. For audiotsm-based pitching, install it via 'pip install audiotsm'.")


# Chatterbox specific imports
try:
    from chatterbox.tts import ChatterboxTTS, punc_norm
    from chatterbox.models.t3.modules.t3_config import T3Config
    from chatterbox.models.s3tokenizer import S3_SR, S3Tokenizer
except ImportError:
    print("Stelle sicher, dass das Chatterbox-Paket korrekt installiert ist und im PYTHONPATH liegt.")
    raise

logger = logging.getLogger(__name__)  # main logger
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")

# --- Augmentation functions ---
def change_speed(audio_array: np.ndarray, sr: int, speed_factor: float) -> Optional[np.ndarray]:
    if speed_factor == 1.0: return audio_array.copy()
    if not (0.5 <= speed_factor <= 2.0):
        logger.debug(f"Speed factor {speed_factor} out of range [0.5, 2.0]. Skipping.")
        return None
    try:
        audio_array_copy = np.ascontiguousarray(audio_array)
        # If audiotsm is available, it can also be used here for better quality
        if AUDIOTSM_AVAILABLE and audio_array_copy.ndim == 1:  # audiotsm ArrayReader expects 1D for mono
            try:
                # audiotsm expects samples as int16 or float32, but ArrayReader/Writer take care
                # of that when channels=1. librosa delivers float32.
                reader = ArrayReader(np.array([audio_array_copy.T], dtype=np.float32))  # must be [channels, samples]
                writer = ArrayWriter(channels=1)  # will return a 1D array
                tsm = wsola(channels=1, speed=speed_factor)
                tsm.run(reader, writer)
                stretched_audio = writer.data.T.flatten()  # back to 1D
                return stretched_audio if len(stretched_audio) > 0 else None
            except Exception as e_tsm_speed:
                logger.warning(f"audiotsm speed change failed for factor {speed_factor}: {e_tsm_speed}. Falling back to librosa.")
                stretched_audio = librosa.effects.time_stretch(y=audio_array_copy, rate=speed_factor)
        else:
            stretched_audio = librosa.effects.time_stretch(y=audio_array_copy, rate=speed_factor)
        return stretched_audio if len(stretched_audio) > 0 else None
    except Exception as e:
        logger.warning(f"change_speed (librosa) failed for factor {speed_factor}: {e}")
        return None

def change_pitch(audio_array: np.ndarray, sr: int, n_steps: float) -> Optional[np.ndarray]:
    if n_steps == 0.0: return audio_array.copy()
    if not (-12.0 <= n_steps <= 12.0):
        logger.debug(f"Pitch steps {n_steps} out of range [-12, 12]. Skipping.")
        return None

    audio_array_copy = np.ascontiguousarray(audio_array)

    # Note on audiotsm: the original idea was to pitch-shift by combining WSOLA time-stretching
    # with resampling: stretch by 1/P (with P = 2**(n_steps/12), so the audio gets longer when
    # the pitch should go up) and then resample so that the original duration is restored while
    # the pitch ends up shifted. Getting the stretch/resample factors right with acceptable
    # quality turned out to be more complex than expected, so even when audiotsm is available
    # this function falls back to librosa.effects.pitch_shift; bins_per_octave=24 is used for
    # potentially better quality than librosa's default.
    try:
        pitched_audio = librosa.effects.pitch_shift(y=audio_array_copy, sr=sr, n_steps=n_steps, bins_per_octave=24)
        return pitched_audio if len(pitched_audio) > 0 else None
    except Exception as e:
        logger.error(f"librosa.effects.pitch_shift failed for steps {n_steps}: {e}")
        return None

def change_volume(audio_array: np.ndarray, gain_db: float) -> Optional[np.ndarray]:
    if gain_db == 0.0: return audio_array.copy()
    if not (-20.0 <= gain_db <= 20.0):
        logger.debug(f"Volume gain_db {gain_db} out of range [-20, 20]. Skipping.")
        return None
    try:
        gain_factor = np.power(10, gain_db / 20.0)
        vol_changed_audio = audio_array * gain_factor
        return vol_changed_audio
    except Exception as e:
        logger.warning(f"change_volume failed for gain_db {gain_db}: {e}")
        return None
# --- End of augmentation functions ---

def preprocess_sample_wrapper(
    task_input: Tuple[Dict[str, Any], int, int],
    text_tokenizer: Any,
    speech_tokenizer: S3Tokenizer,
    voice_encoder: Any,
    t3_config_from_model: T3Config,
    max_text_len_arg: int,
    max_speech_len_arg: int,
    audio_prompt_duration_s: float,
    output_dir: Path,
    silence_padding_ms: int,
    do_augment: bool,
    arg_speeds: List[float],
    arg_pitches: List[float],
    arg_volumes_db: List[float]
) -> List[Tuple[int, bool, str]]:

    original_pair, original_file_index, global_sample_offset = task_input
    audio_fpath = original_pair["audio"]
    text_content = original_pair["text"]

    augmentation_results = []
    processed_sample_counter_for_this_file = 0

    param_combinations: List[Tuple[float, float, float]] = []
    param_combinations.append((1.0, 0.0, 0.0)) 

    if do_augment:
        for speed_val in arg_speeds: param_combinations.append((speed_val, 0.0, 0.0))
        for pitch_val in arg_pitches: param_combinations.append((1.0, pitch_val, 0.0))
        for volume_val in arg_volumes_db: param_combinations.append((1.0, 0.0, volume_val))
    
    param_combinations = sorted(list(set(param_combinations)))
    logger.debug(f"OrigIdx {original_file_index} ({audio_fpath.name}): {len(param_combinations)} individual augmentations planned (incl. original).")

    wav_array_orig_sr_float32: Optional[np.ndarray] = None
    sr_orig_raw: Optional[int] = None
    try:
        wav_array_orig_sr_float32, sr_orig_raw_loaded = librosa.load(str(audio_fpath), sr=None, mono=True, dtype=np.float32)
        if wav_array_orig_sr_float32 is None or len(wav_array_orig_sr_float32) == 0:
            raise ValueError("Geladene Audiodatei ist leer oder ungĂŒltig.")
        sr_orig_raw = sr_orig_raw_loaded 
    except Exception as e_load:
        logger.error(f"Error loading original audio {audio_fpath}: {e_load}", exc_info=True)
        num_expected_variants_for_this_file = 1
        if do_augment:
             num_expected_variants_for_this_file += len(arg_speeds) + len(arg_pitches) + len(arg_volumes_db)
        for i in range(num_expected_variants_for_this_file):
             augmentation_results.append((global_sample_offset + i, False, f"{audio_fpath.name}_original_load_exception"))
        return augmentation_results

    test_audio_output_dir = None
    if original_file_index == 0:
        test_audio_output_dir = output_dir / "test_audio_samples_debug"
        test_audio_output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving test audios for {audio_fpath.name} (OriginalIndex 0) in {test_audio_output_dir}")

    for speed_factor, pitch_steps, gain_db in param_combinations:
        current_global_sample_index = global_sample_offset + processed_sample_counter_for_this_file
        
        is_original_params = (speed_factor == 1.0 and pitch_steps == 0.0 and gain_db == 0.0)
        aug_name_parts = []
        if not is_original_params:
            if speed_factor != 1.0: aug_name_parts.append(f"s{str(speed_factor).replace('.', 'p')}")
            if pitch_steps != 0.0: aug_name_parts.append(f"p{str(pitch_steps).replace('.', 'p').replace('-', 'm')}")
            if gain_db != 0.0: aug_name_parts.append(f"v{str(gain_db).replace('.', 'p').replace('-', 'm')}")
            aug_name = "_".join(aug_name_parts)
        else:
            aug_name = "original"
        if not aug_name: aug_name = f"s{str(speed_factor).replace('.', 'p')}_p{str(pitch_steps).replace('.', 'p').replace('-', 'm')}_v{str(gain_db).replace('.', 'p').replace('-', 'm')}"

        try:
            if sr_orig_raw is None or wav_array_orig_sr_float32 is None:
                raise RuntimeError("Original sample rate or audio data not loaded correctly.")

            current_audio_to_process = wav_array_orig_sr_float32.copy()
            current_sr = sr_orig_raw
            
            if speed_factor != 1.0:
                augmented_audio = change_speed(current_audio_to_process, current_sr, speed_factor)
                if augmented_audio is None: raise ValueError(f"Speed change (factor {speed_factor}) failed.")
                current_audio_to_process = augmented_audio
            
            if pitch_steps != 0.0:
                augmented_audio = change_pitch(current_audio_to_process, current_sr, pitch_steps)
                if augmented_audio is None: raise ValueError(f"Pitch change ({pitch_steps} steps) failed.")
                current_audio_to_process = augmented_audio

            if gain_db != 0.0:
                augmented_audio = change_volume(current_audio_to_process, gain_db)
                if augmented_audio is None: raise ValueError(f"Volume change ({gain_db}dB) failed.")
                current_audio_to_process = augmented_audio
            
            if current_sr != S3_SR:
                final_audio_for_preprocess = librosa.resample(y=current_audio_to_process, orig_sr=current_sr, target_sr=S3_SR)
            else:
                final_audio_for_preprocess = current_audio_to_process

            if len(final_audio_for_preprocess) == 0:
                raise ValueError("Audio nach Augmentierung/Resampling leer.")

            if test_audio_output_dir and original_file_index == 0:
                test_audio_filename = f"{audio_fpath.stem}_AUG_{aug_name}_origidx{original_file_index}_gidx{current_global_sample_index}_SR{S3_SR}.wav"
                try:
                    sf.write(str(test_audio_output_dir / test_audio_filename), final_audio_for_preprocess, S3_SR, subtype='FLOAT')
                except Exception as e_sf:
                    logger.warning(f"Could not save test audio {test_audio_filename}: {e_sf}")

            processed_data = preprocess_sample(
                audio_array_input=final_audio_for_preprocess,
                original_sr_input_after_augmentation=S3_SR, 
                audio_fpath_for_logging=audio_fpath, text=text_content,
                text_tokenizer=text_tokenizer, speech_tokenizer=speech_tokenizer,
                voice_encoder=voice_encoder, t3_config_from_model=t3_config_from_model,
                max_text_len_arg=max_text_len_arg, max_speech_len_arg=max_speech_len_arg,
                audio_prompt_duration_s=audio_prompt_duration_s, silence_padding_ms=silence_padding_ms,
            )

            if processed_data:
                base_filename = audio_fpath.stem
                safe_base_filename = "".join(c if c.isalnum() or c in ['-', '_'] else '_' for c in base_filename)
                output_filename_pt = f"{safe_base_filename}_aug_{aug_name}_{current_global_sample_index:07d}.pt"
                torch.save(processed_data, output_dir / output_filename_pt)
                augmentation_results.append((current_global_sample_index, True, output_filename_pt))
            else:
                err_msg = f"{audio_fpath.name}_aug_{aug_name}_preprocess_failed"
                augmentation_results.append((current_global_sample_index, False, err_msg))
        
        except ValueError as ve: 
            logger.warning(f"Augmentation variant '{aug_name}' for {audio_fpath.name} (GlobalIdx {current_global_sample_index}) failed: {ve}. Skipping.")
            augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_value_error"))
        except RuntimeError as rterr:
             logger.error(f"Runtime error for {audio_fpath.name} (Aug: {aug_name}, GlobalIdx {current_global_sample_index}): {rterr}", exc_info=False)
             augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_runtime_error"))
        except Exception as e_outer:
            logger.error(f"Unexpected error for {audio_fpath.name} (Aug: {aug_name}, GlobalIdx {current_global_sample_index}): {e_outer}", exc_info=True)
            augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_exception"))
        finally:
            if 'current_audio_to_process' in locals() and current_audio_to_process is not None: del current_audio_to_process
            if 'final_audio_for_preprocess' in locals() and final_audio_for_preprocess is not None: del final_audio_for_preprocess
            if 'augmented_audio' in locals() and augmented_audio is not None: del augmented_audio
        
        processed_sample_counter_for_this_file += 1
        if processed_sample_counter_for_this_file >= 150:
            logger.warning(f"Max augmentation limit (150) for {audio_fpath.name} reached.")
            remaining_planned_slots = len(param_combinations) - processed_sample_counter_for_this_file
            for i in range(remaining_planned_slots):
                augmentation_results.append((global_sample_offset + processed_sample_counter_for_this_file + i, False, f"{audio_fpath.name}_max_aug_limit_reached"))
            break
    
    expected_results_count = len(param_combinations)
    while len(augmentation_results) < expected_results_count:
        missing_idx = len(augmentation_results)
        augmentation_results.append((global_sample_offset + missing_idx, False, f"{audio_fpath.name}_processing_loop_incomplete_at_idx_{missing_idx}"))
        # logger.debug(f"Padded missing result for {audio_fpath.name} at local variant index {missing_idx}")

    if wav_array_orig_sr_float32 is not None: del wav_array_orig_sr_float32
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return augmentation_results


def preprocess_sample(
    audio_array_input: np.ndarray, 
    original_sr_input_after_augmentation: int, 
    audio_fpath_for_logging: Path,
    text: str,
    text_tokenizer: Any,
    speech_tokenizer: S3Tokenizer,
    voice_encoder: Any,
    t3_config_from_model: T3Config,
    max_text_len_arg: int,
    max_speech_len_arg: int,
    audio_prompt_duration_s: float,
    silence_padding_ms: int,
) -> Optional[Dict[str, torch.Tensor]]:
    try:
        if len(audio_array_input) == 0: return None

        if original_sr_input_after_augmentation != S3_SR:
            wav_16k_final = librosa.resample(y=audio_array_input, orig_sr=original_sr_input_after_augmentation, target_sr=S3_SR)
        else:
            wav_16k_final = audio_array_input.copy()

        if wav_16k_final.ndim > 1: wav_16k_final = librosa.to_mono(wav_16k_final)
        if wav_16k_final.dtype != np.float32: wav_16k_final = wav_16k_final.astype(np.float32)
        if len(wav_16k_final) == 0: return None

        if silence_padding_ms > 0:
            padding_samples = int((silence_padding_ms / 1000.0) * S3_SR)
            silence_array = np.zeros(padding_samples, dtype=np.float32)
            wav_16k_padded_silence = np.concatenate((silence_array, wav_16k_final, silence_array))
        else:
            wav_16k_padded_silence = wav_16k_final

        if len(wav_16k_padded_silence) == 0: return None
        
        MAX_AUDIO_SAMPLES_FOR_S3TOKENIZER_INPUT = int(18.0 * S3_SR) 
        wav_16k_for_tokenizer_input = wav_16k_padded_silence[:MAX_AUDIO_SAMPLES_FOR_S3TOKENIZER_INPUT]
        
        MIN_VE_AUDIO_LEN_SAMPLES = S3_SR // 10 
        if len(wav_16k_padded_silence) < MIN_VE_AUDIO_LEN_SAMPLES:
            padding_needed = MIN_VE_AUDIO_LEN_SAMPLES - len(wav_16k_padded_silence)
            wav_for_ve = np.pad(wav_16k_padded_silence, (0, padding_needed), 'constant')
        else:
            wav_for_ve = wav_16k_padded_silence

        speaker_emb_np = voice_encoder.embeds_from_wavs([wav_for_ve], sample_rate=S3_SR)
        speaker_emb = torch.from_numpy(speaker_emb_np[0]).cpu() 

        start_text_token_id = t3_config_from_model.start_text_token
        stop_text_token_id = t3_config_from_model.stop_text_token
        start_speech_token_id = t3_config_from_model.start_speech_token
        stop_speech_token_id = t3_config_from_model.stop_speech_token
        speech_prompt_len_from_config = t3_config_from_model.speech_cond_prompt_len

        MIN_TOKENIZER_AUDIO_LEN_SAMPLES = S3_SR // 20 
        if len(wav_16k_for_tokenizer_input) < MIN_TOKENIZER_AUDIO_LEN_SAMPLES: return None

        raw_speech_tokens_batch, speech_token_lengths_batch = speech_tokenizer.forward([wav_16k_for_tokenizer_input])
        if raw_speech_tokens_batch is None or speech_token_lengths_batch is None or speech_token_lengths_batch.squeeze(0).item() == 0: return None
        raw_speech_tokens = raw_speech_tokens_batch.cpu().squeeze(0)[:speech_token_lengths_batch.cpu().squeeze(0).item()]
        
        speech_tokens_unpadded_full_sequence = F.pad(raw_speech_tokens, (1, 0), value=start_speech_token_id)
        speech_tokens_unpadded_full_sequence = F.pad(speech_tokens_unpadded_full_sequence, (0, 1), value=stop_speech_token_id)
        current_total_speech_len_incl_start_stop = len(speech_tokens_unpadded_full_sequence)

        MIN_PREDICTABLE_CONTENT_TOKENS = 1
        if current_total_speech_len_incl_start_stop < (1 + MIN_PREDICTABLE_CONTENT_TOKENS + 1): return None
        
        speech_token_actual_len = torch.tensor(current_total_speech_len_incl_start_stop, dtype=torch.long).cpu()
        padding_token_id_speech = stop_speech_token_id
        if current_total_speech_len_incl_start_stop > max_speech_len_arg:
            padded_speech_tokens = speech_tokens_unpadded_full_sequence[:max_speech_len_arg-1]
            padded_speech_tokens = torch.cat([padded_speech_tokens, torch.tensor([stop_speech_token_id],dtype=torch.long).cpu()])
            speech_token_actual_len = torch.tensor(max_speech_len_arg, dtype=torch.long).cpu()
        else:
            padded_speech_tokens = F.pad(speech_tokens_unpadded_full_sequence, (0, max_speech_len_arg - current_total_speech_len_incl_start_stop), value=padding_token_id_speech)
        
        if padded_speech_tokens.size(0) != max_speech_len_arg: return None

        enc_cond_audio_len_samples_prompt = int(audio_prompt_duration_s * S3_SR)
        cond_audio_segment_for_prompt_np = wav_16k_padded_silence[:enc_cond_audio_len_samples_prompt]
        
        target_prompt_tensor_len = speech_prompt_len_from_config
        cond_prompt_speech_tokens_padded: torch.Tensor

        if len(cond_audio_segment_for_prompt_np) < MIN_TOKENIZER_AUDIO_LEN_SAMPLES:
            cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()
        else:
            try:
                cond_prompt_tokens_batch, cond_prompt_lens_batch = speech_tokenizer.forward([cond_audio_segment_for_prompt_np], max_len=None)
                if cond_prompt_tokens_batch is not None and cond_prompt_lens_batch is not None and cond_prompt_lens_batch.squeeze(0).item() > 0:
                    cond_prompt_unpadded_temp = cond_prompt_tokens_batch.cpu().squeeze(0)[:cond_prompt_lens_batch.cpu().squeeze(0).item()]
                else:
                    cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()
            except Exception:
                cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()

        current_prompt_tokens_len = cond_prompt_unpadded_temp.size(0)
        if current_prompt_tokens_len > target_prompt_tensor_len:
            cond_prompt_speech_tokens_padded = cond_prompt_unpadded_temp[:target_prompt_tensor_len]
        elif current_prompt_tokens_len < target_prompt_tensor_len:
            cond_prompt_speech_tokens_padded = F.pad(cond_prompt_unpadded_temp, (0, target_prompt_tensor_len - current_prompt_tokens_len), value=start_speech_token_id)
        else:
            cond_prompt_speech_tokens_padded = cond_prompt_unpadded_temp
        
        if cond_prompt_speech_tokens_padded.size(0) != target_prompt_tensor_len: return None

        normalized_text = punc_norm(text)
        if not normalized_text.strip(): return None
            
        raw_text_tokens_for_text = text_tokenizer.text_to_tokens(normalized_text).cpu().squeeze(0)
        text_tokens_unpadded_full_text = F.pad(raw_text_tokens_for_text, (1, 0), value=start_text_token_id)
        text_tokens_unpadded_full_text = F.pad(text_tokens_unpadded_full_text, (0, 1), value=stop_text_token_id)
        current_text_len_unpadded = len(text_tokens_unpadded_full_text)
        text_token_actual_len = torch.tensor(current_text_len_unpadded, dtype=torch.long).cpu()
        
        padding_token_id_text = stop_text_token_id
        if current_text_len_unpadded > max_text_len_arg:
            padded_text_tokens = text_tokens_unpadded_full_text[:max_text_len_arg-1]
            padded_text_tokens = torch.cat([padded_text_tokens, torch.tensor([stop_text_token_id], dtype=torch.long).cpu()])
            text_token_actual_len = torch.tensor(max_text_len_arg, dtype=torch.long).cpu()
        else:
            padded_text_tokens = F.pad(text_tokens_unpadded_full_text, (0, max_text_len_arg - current_text_len_unpadded), value=padding_token_id_text)

        if padded_text_tokens.size(0) != max_text_len_arg: return None

        emotion_adv_scalar_tensor = torch.tensor(0.5, dtype=torch.float).cpu()
        
        return {
            "text_tokens": padded_text_tokens.long().cpu(), 
            "text_token_lens": text_token_actual_len.long().cpu(), 
            "speech_tokens": padded_speech_tokens.long().cpu(), 
            "speech_token_lens": speech_token_actual_len.long().cpu(), 
            "t3_cond_speaker_emb": speaker_emb.float().cpu(), 
            "t3_cond_prompt_speech_tokens": cond_prompt_speech_tokens_padded.long().cpu(), 
            "t3_cond_emotion_adv": emotion_adv_scalar_tensor.cpu()
        }
    except Exception as e:
        logger.error(f"Critical error in preprocess_sample for {audio_fpath_for_logging} (Text: '{text[:30]}...'): {e}", exc_info=False)
        return None


def main():
    parser = argparse.ArgumentParser(description="Preprocess TTS data with augmentation.")
    # ... (arguments unchanged) ...
    parser.add_argument("--model_load_path", type=str, required=True)
    parser.add_argument("--is_hub_model", action="store_true")
    parser.add_argument("--dataset_dir", type=str)
    parser.add_argument("--metadata_file", type=str)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--max_text_len", type=int, default=2048)
    parser.add_argument("--max_speech_len", type=int, default=4096)
    parser.add_argument("--audio_prompt_duration_s", type=float, default=3.0)
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--audio_ext", type=str, default="wav")
    parser.add_argument("--text_ext", type=str, default="txt")
    parser.add_argument("--silence_padding_ms", type=int, default=300, help="Silence in ms to add at start/end of audio *after* augmentations and resampling, before tokenization.")
    parser.add_argument("--augment_data", action="store_true", help="Enable audio data augmentation.")
    parser.add_argument("--aug_speeds", type=float, nargs='*', default=[], help="Speed factors (e.g., 0.9 1.1). Neutral 1.0 ignored.")
    parser.add_argument("--aug_pitches", type=float, nargs='*', default=[], help="Pitch shifts in semitones (e.g., -1 1). Neutral 0.0 ignored.")
    parser.add_argument("--aug_volumes_db", type=float, nargs='*', default=[], help="Volume changes dB (e.g., -3 3). Neutral 0.0 ignored.")

    args = parser.parse_args()
    output_dir_path_obj = Path(args.output_dir)
    output_dir_path_obj.mkdir(parents=True, exist_ok=True)

    valid_speeds = sorted(list(set(s for s in args.aug_speeds if 0.5 <= s <= 2.0 and s != 1.0)))
    valid_pitches = sorted(list(set(p for p in args.aug_pitches if -12.0 <= p <= 12.0 and p != 0.0)))
    valid_volumes_db = sorted(list(set(v for v in args.aug_volumes_db if -20.0 <= v <= 20.0 and v != 0.0)))
    
    augmentation_params_for_wrapper = {
        "speeds": valid_speeds,
        "pitches": valid_pitches,
        "volumes_db": valid_volumes_db,
    }

    pitch_method_used = "librosa.effects.pitch_shift (bins_per_octave=24)"
    if AUDIOTSM_AVAILABLE:
        # Note: even when audiotsm is available, change_pitch currently falls back to librosa,
        # because the audiotsm variant of pure pitch shifting turned out to be more complex than
        # originally expected and its quality can vary. If you want a dedicated audiotsm
        # pitch-shift implementation, change_pitch needs to be adapted.
        logger.info("audiotsm is available, but pitch shifting will internally use librosa.effects.pitch_shift (bins_per_octave=24), since the audiotsm pitch-shift logic was simplified.")


    if args.augment_data:
        logger.info(f"Data augmentation ENABLED. Pitch shifting verwendet: {pitch_method_used}.")
        # ... (Rest der Log-Ausgaben bleibt gleich) ...
        logger.info(f"  Individual augmentations (plus original):")
        if augmentation_params_for_wrapper["speeds"]: logger.info(f"    Speeds: {augmentation_params_for_wrapper['speeds']}")
        if augmentation_params_for_wrapper["pitches"]: logger.info(f"    Pitches (semitones): {augmentation_params_for_wrapper['pitches']}")
        if augmentation_params_for_wrapper["volumes_db"]: logger.info(f"    Volumes (dB): {augmentation_params_for_wrapper['volumes_db']}")
        if not any(augmentation_params_for_wrapper.values()):
            logger.warning("WARNING: --augment_data set, but no valid non-neutral augmentation parameters provided. Only originals processed.")
    else:
        logger.info("Data augmentation DEACTIVATED. Only original samples processed.")
        augmentation_params_for_wrapper = {"speeds": [], "pitches": [], "volumes_db": []}

    # ... (rest of the main function remains unchanged through to the end) ...
    logger.info(f"Loading ChatterboxTTS model from: {args.model_load_path} ...")
    model_source_path: str
    if args.is_hub_model:
        from huggingface_hub import snapshot_download
        try: model_source_path = snapshot_download(repo_id=args.model_load_path, cache_dir=args.cache_dir, allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json", "*.model"])
        except Exception as e: logger.error(f"Hub Download Error: {e}"); return
    else:
        model_source_path = args.model_load_path
        if not Path(model_source_path).is_dir(): logger.error(f"Local path not found: {model_source_path}"); return
    
    chatterbox_model: ChatterboxTTS
    try:
        chatterbox_model = ChatterboxTTS.from_local(
            ckpt_dir_for_model_weights=model_source_path,
            base_model_dir_for_static_components=model_source_path,
            device="cpu",
            target_t3_config_for_model=None
        )
        chatterbox_model.t3.to("cpu")
        chatterbox_model.s3gen.to("cpu")
        chatterbox_model.ve.to("cpu")
        if hasattr(chatterbox_model, 'tokenizer') and hasattr(chatterbox_model.tokenizer, 'to'):
             chatterbox_model.tokenizer.to("cpu")

        config_used = chatterbox_model.t3.hp
        logger.info(f"Using T3-Config (Base): PromptLen={getattr(config_used, 'speech_cond_prompt_len', 'N/A')}")
        logger.info(f"ChatterboxTTS model and components loaded to CPU.")
    except Exception as e: logger.error(f"Model loading error: {e}", exc_info=True); return

    text_tokenizer_instance = chatterbox_model.tokenizer
    speech_tokenizer_instance = chatterbox_model.s3gen.tokenizer
    voice_encoder_instance = chatterbox_model.ve
    t3_config_from_loaded_model = chatterbox_model.t3.hp
    
    file_pairs: List[Dict[str, Any]] = []
    if args.metadata_file:
        metadata_p = Path(args.metadata_file).resolve(); dataset_root = metadata_p.parent
        logger.info(f"Reading metadata from: {metadata_p}")
        try:
            with open(metadata_p, 'r', encoding='utf-8') as f:
                for i, line in enumerate(f):
                    parts = line.strip().split('|'); text_content_meta = ""
                    if len(parts) < 2: parts = line.strip().split('\t') 
                    if len(parts) >= 1: audio_file_rel = parts[0]
                    if len(parts) >= 2: text_content_meta = "|".join(parts[1:])
                    else: logger.debug(f"Line {i+1} in meta has no text. Skip."); continue
                    
                    audio_p_candidate1 = dataset_root / audio_file_rel
                    audio_p_candidate2 = Path(audio_file_rel) 
                    
                    audio_p_resolved = None
                    if audio_p_candidate1.exists() and audio_p_candidate1.is_file():
                        audio_p_resolved = audio_p_candidate1.resolve()
                    elif audio_p_candidate2.exists() and audio_p_candidate2.is_file() and audio_p_candidate2.is_absolute():
                         audio_p_resolved = audio_p_candidate2.resolve()

                    if audio_p_resolved: file_pairs.append({"audio": audio_p_resolved, "text": text_content_meta.strip()})
                    else: logger.debug(f"Audio not found for meta entry '{audio_file_rel}'. Tried {audio_p_candidate1}, {audio_p_candidate2}. Skip.")
        except Exception as e_meta: logger.error(f"Meta loading error: {e_meta}", exc_info=True); return
    elif args.dataset_dir:
        dataset_p = Path(args.dataset_dir).resolve(); logger.info(f"Searching for pairs in: {dataset_p}")
        if not dataset_p.is_dir(): logger.error(f"Dataset directory not found: {dataset_p}"); return
        found_audio_files = list(dataset_p.rglob(f"*.{args.audio_ext}"))
        for audio_path_obj in tqdm(found_audio_files, desc="Checking audio/text pairs"):
            text_path = audio_path_obj.with_suffix(f".{args.text_ext}")
            if text_path.exists():
                try:
                    text_content = text_path.read_text(encoding='utf-8').strip()
                    if text_content: file_pairs.append({"audio": audio_path_obj.resolve(), "text": text_content})
                    else: logger.debug(f"Text file {text_path} is empty. Skip.")
                except Exception as e_read_text: 
                    logger.warning(f"Could not read text file {text_path}: {e_read_text}. Skip.")
    else: logger.error("Neither --dataset_dir nor --metadata_file provided. Aborting."); return
    if not file_pairs: logger.error("No audio/text pairs found. Aborting."); return


    logger.info(f"Found {len(file_pairs)} original audio-text pairs. Starting preprocessing...")
    
    num_workers_requested = args.num_workers if args.num_workers is not None and args.num_workers > 0 else os.cpu_count()
    if num_workers_requested is None: num_workers_requested = 1
    
    num_aug_variants_per_file = 1 
    if args.augment_data:
        num_aug_variants_per_file += len(augmentation_params_for_wrapper["speeds"])
        num_aug_variants_per_file += len(augmentation_params_for_wrapper["pitches"])
        num_aug_variants_per_file += len(augmentation_params_for_wrapper["volumes_db"])
    if num_aug_variants_per_file == 0 : num_aug_variants_per_file = 1

    num_workers = min(num_workers_requested, len(file_pairs), 60 if os.name == 'nt' else (os.cpu_count() or 1) * 5)
    if num_workers <= 0: num_workers = 1
    logger.info(f"Using {num_workers} worker processes. Each original file generates up to {num_aug_variants_per_file} samples.")

    tasks_for_executor = []
    current_global_offset_for_filenames = 0
    for i, pair in enumerate(file_pairs):
        tasks_for_executor.append((pair, i, current_global_offset_for_filenames))
        current_global_offset_for_filenames += num_aug_variants_per_file 

    logger.info(f"Total expected samples to attempt (incl. augmentations): {current_global_offset_for_filenames}")

    process_fn_partial = functools.partial(
        preprocess_sample_wrapper,
        text_tokenizer=text_tokenizer_instance,
        speech_tokenizer=speech_tokenizer_instance,
        voice_encoder=voice_encoder_instance,
        t3_config_from_model=t3_config_from_loaded_model,
        max_text_len_arg=args.max_text_len,
        max_speech_len_arg=args.max_speech_len,
        audio_prompt_duration_s=args.audio_prompt_duration_s,
        output_dir=output_dir_path_obj,
        silence_padding_ms=args.silence_padding_ms,
        do_augment=args.augment_data,
        arg_speeds=augmentation_params_for_wrapper["speeds"],
        arg_pitches=augmentation_params_for_wrapper["pitches"],
        arg_volumes_db=augmentation_params_for_wrapper["volumes_db"]
    )

    total_successful_pt_files = 0
    total_failed_or_skipped_variants = 0 
    total_original_files_submitted_to_workers = 0

    executor = None
    try:
        executor = ProcessPoolExecutor(max_workers=num_workers)
        futures = []
        for task_input_tuple in tasks_for_executor:
            futures.append(executor.submit(process_fn_partial, task_input_tuple))
            total_original_files_submitted_to_workers +=1
        
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing original files (incl. augmentations)"):
            try:
                augmentation_batch_results = future.result() 
                for _, success, _ in augmentation_batch_results:
                    if success:
                        total_successful_pt_files += 1
                    else:
                        total_failed_or_skipped_variants += 1
            except Exception as e_future:
                logger.error(f"A worker process failed critically: {e_future}", exc_info=True)
                total_failed_or_skipped_variants += num_aug_variants_per_file 
    except Exception as e_pool:
        logger.error(f"Error occurred in ProcessPoolExecutor: {e_pool}", exc_info=True)
    finally:
        if executor:
            logger.info("Shutting down worker processes...")
            executor.shutdown(wait=True)
        del chatterbox_model, text_tokenizer_instance, speech_tokenizer_instance, voice_encoder_instance
        del t3_config_from_loaded_model, file_pairs, tasks_for_executor, process_fn_partial
        gc.collect()
        logger.info("Main process cleanup done.")

    logger.info(f"Preprocessing COMPLETE.")
    logger.info(f"  {total_original_files_submitted_to_workers} original files submitted to workers.")
    logger.info(f"  Successfully created {total_successful_pt_files} .pt files.")
    if total_failed_or_skipped_variants > 0:
        logger.warning(f"  {total_failed_or_skipped_variants} variants failed or were skipped.")
    logger.info(f"  Saved .pt files in: '{args.output_dir}'.")
    test_audio_dir_check = output_dir_path_obj / "test_audio_samples_debug"
    if test_audio_dir_check.exists() and any(test_audio_dir_check.iterdir()):
        logger.info(f"  Test audio samples saved in: '{test_audio_dir_check}'.")

if __name__ == "__main__":
    main()
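
For reference, an example invocation of the script above, using only flags defined in its argparse section (all paths are placeholders):

python preprocess_data.py \
    --model_load_path /path/to/base_chatterbox_model \
    --dataset_dir /path/to/wav_plus_txt_dataset \
    --output_dir /path/to/preprocessed_pt_files \
    --num_workers 8 \
    --silence_padding_ms 300 \
    --augment_data --aug_speeds 0.9 1.1 --aug_pitches -1.0 1.0 --aug_volumes_db -3.0 3.0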

havok2-htwo avatar Jun 26 '25 21:06 havok2-htwo

finetune_t3_preprocessed.py: this is where the GPU magic starts and heats up your room.

# finetune_t3_preprocessed.py

import argparse
import logging
import os
import json
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Any
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import librosa
import numpy as np
import soundfile as sf
from torch.utils.tensorboard import SummaryWriter
import gc  # added for memory cleanup

# --- BEGIN MODIFICATION FOR PICKLE ERROR ---
try:
    import numpy as pynumpy_for_pickle  # renamed alias
    import torch.serialization
    safe_globals_to_add = [
        pynumpy_for_pickle.core.multiarray._reconstruct,
        pynumpy_for_pickle.dtype,
        pynumpy_for_pickle.ndarray,
    ]
    try:
        safe_globals_to_add.append(pynumpy_for_pickle.dtypes.UInt32DType)
    except AttributeError:
        logging.warning("numpy.dtypes.UInt32DType not found, not adding to safe globals.")
    torch.serialization.add_safe_globals(safe_globals_to_add)
    logging.info(
        f"Attempted to add NumPy types to torch.serialization.get_safe_globals() for checkpoint loading: "
        f"{[g.__name__ if hasattr(g, '__name__') else str(g) for g in safe_globals_to_add]}"
    )
except ImportError:
    logging.warning("Could not import numpy or torch.serialization to add safe globals.")
except AttributeError as e:
    logging.warning(f"General AttributeError while trying to add safe globals: {e}.")
# --- END MODIFICATION FOR PICKLE ERROR ---

from transformers import (
    HfArgumentParser,
    EarlyStoppingCallback,
    set_seed,
    Trainer,
    PretrainedConfig,
    TrainerCallback, # added
    TrainerState,    # added
    TrainerControl   # added
)
from transformers import TrainingArguments as HfTrainingArguments
from datasets import load_dataset, DatasetDict, VerificationMode, Audio
import datasets

from chatterbox.tts import ChatterboxTTS, Conditionals, punc_norm, REPO_ID
from chatterbox.models.t3.t3 import T3, T3Cond
from chatterbox.models.t3.modules.t3_config import T3Config
from chatterbox.models.s3tokenizer import S3_SR, S3Tokenizer

logger = logging.getLogger(__name__)

# --- Detailed Memory Clearing Callback ---
logger_callback = logging.getLogger(__name__ + ".DetailedMemoryClearingCallback")

class DetailedMemoryClearingCallback(TrainerCallback):
    def _log_memory_usage(self, stage: str, device: Union[str, torch.device]):
        if torch.cuda.is_available():
            torch.cuda.synchronize(device) # important for accurate measurements
            allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
            reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
            # max_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 2) # reset for the next phase
            # max_reserved = torch.cuda.max_memory_reserved(device) / (1024 ** 2) # reset for the next phase
            # torch.cuda.reset_peak_memory_stats(device) # reset peak values for the next phase

            logger_callback.info(
                f"CUDA Memory ({stage}) on device {device}: \n"
                f"  Allocated: {allocated:.2f} MB\n"
                f"  Reserved: {reserved:.2f} MB"
                # f"  Max Allocated (since last reset): {max_allocated:.2f} MB\n"
                # f"  Max Reserved (since last reset): {max_reserved:.2f} MB"
            )
            # More detailed info if needed:
            # logger_callback.debug(torch.cuda.memory_summary(device=device, abbreviated=False))

    def _clear_memory(self, stage: str, device: Union[str, torch.device]):
        logger_callback.info(f"Attempting to clear memory ({stage}) on device {device}...")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger_callback.info(f"Memory clearing attempt ({stage}) finished.")

    def on_train_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            logger_callback.info("Training started. Initial memory state.")
            self._log_memory_usage("on_train_begin", args.device)

    def on_epoch_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self._log_memory_usage(f"Epoch {state.epoch:.0f} begin, before cleanup", args.device)
            self._clear_memory(f"Epoch {state.epoch:.0f} begin", args.device)
            self._log_memory_usage(f"Epoch {state.epoch:.0f} begin, after cleanup", args.device)

    def on_step_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Only log when a logging step is reached, to avoid flooding the logs
        if state.is_world_process_zero and state.global_step > 0 and \
           args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
            self._log_memory_usage(f"Step {state.global_step} begin", args.device)


    def on_evaluate(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self._log_memory_usage("Before on_evaluate cleanup", args.device)
            self._clear_memory("on_evaluate", args.device)
            self._log_memory_usage("After on_evaluate cleanup", args.device)

    def on_log(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if state.is_world_process_zero and logs is not None:
            is_eval_log = any(k.startswith("eval_") for k in logs)
            current_stage = "on_log (post-eval)" if is_eval_log else f"on_log (step {state.global_step})"

            self._log_memory_usage(f"Before {current_stage} cleanup", args.device)
            self._clear_memory(current_stage, args.device)
            self._log_memory_usage(f"After {current_stage} cleanup", args.device)

    def on_save(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self._log_memory_usage("Before on_save cleanup", args.device)
            self._clear_memory("on_save", args.device)
            self._log_memory_usage("After on_save cleanup", args.device)

    def on_train_end(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self._log_memory_usage("Before on_train_end cleanup", args.device)
            self._clear_memory("on_train_end", args.device)
            self._log_memory_usage("After on_train_end cleanup", args.device)

# --- Argument classes ---
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
    early_stopping_patience: Optional[int] = field(
        default=None, metadata={"help": "Enable early stopping with specified patience. Default: None (disabled)."}
    )
    generate_samples_on_finish: bool = field(
        default=True, metadata={"help": "Generate audio samples after training/evaluation."}
    )
    num_samples_to_generate: int = field(
        default=3, metadata={"help": "Number of audio samples to generate PER CHECKPOINT."}
    )
    sample_texts: Optional[List[str]] = field(
        default_factory=lambda: ["Hallo Welt, dies ist ein Test.", "Wie klingt meine Stimme nach dem Training?", "Die kĂŒnstliche Intelligenz macht Fortschritte."],
        metadata={"help": "List of texts to use for generating audio samples."}
    )
    generate_from_all_checkpoints: bool = field(
        default=False, metadata={"help": "If true, iterate through all checkpoint subfolders to generate samples, not just the final/best model."}
    )
    # HfTrainingArguments already defines per_device_eval_batch_size.
    # Defining it here with a different default overrides it.
    # The HfTrainingArguments default is 8; for TTS, 1 or 2 is often safer.
    per_device_eval_batch_size: int = field(
        default=1, metadata={"help": "Batch size for evaluation per device. Lower to save VRAM."}
    )


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"})
    local_model_dir: Optional[str] = field(default=None, metadata={"help": "Path to local directory containing ve.safetensors, t3_cfg.safetensors, etc. Overrides model_name_or_path for loading."})
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"})
    freeze_voice_encoder: bool = field(default=True, metadata={"help": "Freeze the Voice Encoder."})
    freeze_s3gen: bool = field(default=True, metadata={"help": "Freeze the S3Gen model (speech token to waveform)."})

@dataclass
class DataArguments:
    preprocessed_data_dir: str = field(metadata={"help": "Path to the directory containing preprocessed .pt files."})
    eval_split_size: float = field(default=0.05, metadata={"help": "Fraction of preprocessed data to use for evaluation if splitting."})
    max_text_len: int = field(default=300, metadata={"help": "Max text len the model was trained with / expects."})
    max_speech_len: int = field(default=750, metadata={"help": "Max speech len the model was trained with / expects."})
    original_dataset_dir: Optional[str] = field(default=None, metadata={"help": "Path to the original raw audio dataset_dir (for sample generation conditioning)."})
    original_metadata_file: Optional[str] = field(default=None, metadata={"help": "Path to original raw metadata_file (for sample generation conditioning)."})
    original_dataset_name: Optional[str] = field(default=None, metadata={"help": "Original HF dataset name (for sample generation conditioning)."})
    original_dataset_config_name: Optional[str] = field(default=None, metadata={"help": "Original HF dataset config name (for sample generation conditioning)."})
    original_audio_column_name: str = field(default="audio", metadata={"help": "Audio column in original HF dataset (for sample generation conditioning)."})
    original_text_column_name: str = field(default="text", metadata={"help": "Text column in original HF dataset."})

# --- Dataset and collator ---
class PreprocessedSpeechDataset(Dataset):
    def __init__(self, preprocessed_files: List[Path]):
        self.preprocessed_files = preprocessed_files
        if not self.preprocessed_files:
            raise ValueError("No preprocessed files provided to PreprocessedSpeechDataset.")
        logger.info(f"Initialized PreprocessedSpeechDataset with {len(self.preprocessed_files)} files.")

    def __len__(self):
        return len(self.preprocessed_files)

    def __getitem__(self, idx) -> Optional[Dict[str, torch.Tensor]]:
        file_path = self.preprocessed_files[idx]
        try:
            data_dict = torch.load(file_path, weights_only=None) # weights_only=None for older PyTorch versions or when non-tensor data is included
            required_keys = ["text_tokens", "text_token_lens", "speech_tokens", "speech_token_lens",
                             "t3_cond_speaker_emb", "t3_cond_prompt_speech_tokens", "t3_cond_emotion_adv"]
            for key in required_keys:
                if key not in data_dict:
                    logger.error(f"Missing key '{key}' in preprocessed file {file_path}. Skipping.")
                    return None
            return data_dict
        except Exception as e:
            logger.error(f"Error loading or validating preprocessed file {file_path}: {e}. Skipping.")
            return None

@dataclass
class SpeechDataCollator:
    t3_config: T3Config
    def __call__(self, features: List[Optional[Dict[str, Any]]]) -> Dict[str, Any]:
        valid_features = [f for f in features if f is not None]
        if not valid_features:
            dummy_device = torch.device("cpu") # or the model's device, if known
            speaker_emb_dim = self.t3_config.llama_hidden_size
            prompt_len_collator = self.t3_config.speech_cond_prompt_len
            max_text_len_collator = self.t3_config.max_text_tokens
            max_speech_len_collator = self.t3_config.max_speech_tokens
            logger.warning("COLLATOR_DEBUG: SpeechDataCollator received no valid features. Returning empty batch structure.")
            return {
                "text_tokens": torch.empty(0, max_text_len_collator, dtype=torch.long, device=dummy_device),
                "text_token_lens": torch.empty(0, dtype=torch.long, device=dummy_device),
                "speech_tokens": torch.empty(0, max_speech_len_collator, dtype=torch.long, device=dummy_device),
                "speech_token_lens": torch.empty(0, dtype=torch.long, device=dummy_device),
                "t3_cond_speaker_emb": torch.empty(0, speaker_emb_dim, dtype=torch.float, device=dummy_device),
                "t3_cond_prompt_speech_tokens": torch.empty(0, prompt_len_collator, dtype=torch.long, device=dummy_device),
                "t3_cond_emotion_adv": torch.empty(0, 1, 1, dtype=torch.float, device=dummy_device),
                "labels_text": torch.empty(0, max_text_len_collator -1 if max_text_len_collator > 0 else 0, dtype=torch.long, device=dummy_device),
                "labels_speech": torch.empty(0, max_speech_len_collator -1 if max_speech_len_collator > 0 else 0, dtype=torch.long, device=dummy_device),
            }
        features = valid_features
        first_tensor_device = features[0]["text_tokens"].device # keep the device of the first feature
        padded_text_tokens = torch.stack([f["text_tokens"] for f in features]).to(first_tensor_device)
        text_token_lens = torch.stack([f["text_token_lens"] for f in features]).to(first_tensor_device)
        padded_speech_tokens = torch.stack([f["speech_tokens"] for f in features]).to(first_tensor_device)
        speech_token_lens = torch.stack([f["speech_token_lens"] for f in features]).to(first_tensor_device)
        t3_cond_speaker_emb = torch.stack([f["t3_cond_speaker_emb"] for f in features]).to(first_tensor_device)
        t3_cond_prompt_speech_tokens = torch.stack([f["t3_cond_prompt_speech_tokens"] for f in features]).to(first_tensor_device)
        emotion_adv_scalars = torch.stack([f["t3_cond_emotion_adv"] for f in features]).to(first_tensor_device)
        t3_cond_emotion_adv = emotion_adv_scalars.view(len(features), 1, 1) # Ensure [B, 1, 1]
        IGNORE_ID = -100
        prompt_len = self.t3_config.speech_cond_prompt_len
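        # Build next-token-prediction labels: shift the padded token sequences left by one,
        # then mask padding positions (beyond each sample's true length) and, for speech,
        # the conditioning prompt positions with IGNORE_ID so they are excluded from the loss.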
        shifted_text = padded_text_tokens[:, 1:].contiguous()
        T_text = shifted_text.size(1)
        text_lens_minus_one = (text_token_lens - 1).clamp(min=0)
        arange_text = torch.arange(T_text, device=shifted_text.device)
        mask_pad_text = arange_text[None, :] >= text_lens_minus_one[:, None]
        labels_text = shifted_text.clone()
        labels_text[mask_pad_text] = IGNORE_ID
        shifted_speech = padded_speech_tokens[:, 1:].contiguous()
        T_speech = shifted_speech.size(1)
        speech_lens_minus_one = (speech_token_lens - 1).clamp(min=0)
        arange_speech = torch.arange(T_speech, device=shifted_speech.device)
        mask_pad_speech = arange_speech[None, :] >= speech_lens_minus_one[:, None]
        mask_prompt = arange_speech[None, :] < prompt_len
        mask_prompt = mask_prompt.expand(len(features), T_speech)
        mask_speech_total = mask_pad_speech | mask_prompt
        labels_speech = shifted_speech.clone()
        labels_speech[mask_speech_total] = IGNORE_ID
        num_valid_labels_speech = (labels_speech != IGNORE_ID).sum().item()
        if padded_text_tokens.size(0) > 0 and num_valid_labels_speech == 0:
            logger.warning(f"COLLATOR_WARNUNG: Batch (Size {padded_text_tokens.size(0)}) enthÀlt KEINE validen Sprach-Labels! "
                           f"Speech Lens (vor Padding): {speech_token_lens.tolist()}, Prompt Len (t3_config): {prompt_len}. ")
        return {"text_tokens": padded_text_tokens, "text_token_lens": text_token_lens,
                "speech_tokens": padded_speech_tokens, "speech_token_lens": speech_token_lens,
                "t3_cond_speaker_emb": t3_cond_speaker_emb,
                "t3_cond_prompt_speech_tokens": t3_cond_prompt_speech_tokens,
                "t3_cond_emotion_adv": t3_cond_emotion_adv,
                "labels_text": labels_text, "labels_speech": labels_speech}

# --- Model wrapper for the Hugging Face Trainer ---
class T3ForFineTuning(torch.nn.Module):
    def __init__(self, t3_model: T3, chatterbox_t3_config: T3Config):
        super().__init__()
        self.t3 = t3_model
        self.chatterbox_t3_config = chatterbox_t3_config # should be the target config
        class HFCompatibleConfig(PretrainedConfig): # Minimal config for HF Trainer compatibility
            model_type = "chatterbox_t3_finetune"
            def __init__(self, **kwargs):
                self.llama_config_name = kwargs.pop("llama_config_name", None)
                self.text_tokens_dict_size = kwargs.pop("text_tokens_dict_size", None)
                self.speech_tokens_dict_size = kwargs.pop("speech_tokens_dict_size", None)
                self.max_text_tokens = kwargs.pop("max_text_tokens", 256)
                self.max_speech_tokens = kwargs.pop("max_speech_tokens", 800)
                self.speech_cond_prompt_len = kwargs.pop("speech_cond_prompt_len", 100)
                self.start_text_token = kwargs.pop("start_text_token", 0)
                self.stop_text_token = kwargs.pop("stop_text_token", 1)
                self.start_speech_token = kwargs.pop("start_speech_token", 0)
                self.stop_speech_token = kwargs.pop("stop_speech_token", 1)
                # Further fields from T3Config that may be relevant for the HF integration
                self.llama_hidden_size = kwargs.pop("llama_hidden_size", 4096) # example value
                super().__init__(**kwargs)
        self.config = HFCompatibleConfig(**chatterbox_t3_config.__dict__)

    def forward(self, text_tokens, text_token_lens, speech_tokens, speech_token_lens,
                t3_cond_speaker_emb, t3_cond_prompt_speech_tokens, t3_cond_emotion_adv,
                labels_text = None, labels_speech=None):
        target_device = text_tokens.device # all inputs should be on the same device
        current_t3_cond = T3Cond(
            speaker_emb=t3_cond_speaker_emb.to(target_device), # make sure the conditioning tensors are on the right device
            cond_prompt_speech_tokens=t3_cond_prompt_speech_tokens.to(target_device),
            cond_prompt_speech_emb=None, # typically not passed during training; created internally if needed
            emotion_adv=t3_cond_emotion_adv.to(target_device) # shape [B, 1, 1]
        ).to(device=target_device) # make sure the T3Cond object itself is on the right device

        loss_text, loss_speech, speech_logits = self.t3.loss(
            t3_cond=current_t3_cond, text_tokens=text_tokens, text_token_lens=text_token_lens,
            speech_tokens=speech_tokens, speech_token_lens=speech_token_lens,
            labels_text=labels_text, labels_speech=labels_speech
        )
        if loss_text is None or loss_speech is None: # should not happen when labels are present
            zero_loss_device = speech_logits.device if speech_logits is not None else target_device
            zero_loss = torch.tensor(0.0, device=zero_loss_device, requires_grad=self.training)
            if loss_text is None: loss_text = zero_loss.clone()
            if loss_speech is None: loss_speech = zero_loss.clone()
        total_loss = loss_text + loss_speech
        # The Trainer expects a tuple whose first element is the loss.
        # Further elements (e.g. logits) can be returned for metric computation etc.
        return (total_loss, speech_logits) if speech_logits is not None else (total_loss,)

# --- Functions for sample generation ---
def find_checkpoint_folders_for_generation(base_dir: str) -> list[Path]:
    base_path = Path(base_dir).resolve()
    if not base_path.is_dir():
        logger.error(f"Basis-Checkpoint-Verzeichnis fĂŒr Sample-Generierung nicht gefunden: {base_path}")
        return []
    checkpoint_folders = [d for d in base_path.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")]
    checkpoint_folders.append(base_path) # for the "best" model in the base folder
    def sort_key(p: Path):
        if p == base_path: return -1 # Ensure base_path (best model) is processed, perhaps first or last based on need
        try: return int(p.name.split('-')[-1])
        except (ValueError, IndexError): return float('inf') # Put malformed names at the end
    return sorted(list(set(checkpoint_folders)), key=sort_key) # Use set to remove duplicates if base_path was also a checkpoint

def set_model_components_to_eval_mode(model_instance: ChatterboxTTS):
    if model_instance is None: return
    logger.info("  Setze Modellkomponenten in den Eval-Modus...")
    if hasattr(model_instance, 't3') and model_instance.t3 is not None: model_instance.t3.eval()
    if hasattr(model_instance, 's3gen') and model_instance.s3gen is not None: model_instance.s3gen.eval()
    if hasattr(model_instance, 've') and model_instance.ve is not None: model_instance.ve.eval()

def get_display_path(path_obj: Path, base_path_for_display: Path) -> str:
    try:
        common = Path(os.path.commonpath([str(path_obj.resolve()), str(base_path_for_display.resolve())]))
        if common == base_path_for_display.resolve():
            return str(path_obj.relative_to(base_path_for_display))
        return path_obj.name
    except ValueError:
        return path_obj.name

def generate_and_save_samples_from_checkpoints(
    output_dir_str: str,
    data_args_original: DataArguments,
    training_args_custom: CustomTrainingArguments,
    base_model_path_for_static_components_str: str
):
    if not training_args_custom.generate_samples_on_finish:
        logger.info("Überspringe Sample-Generierung gemĂ€ĂŸ Konfiguration.")
        return

    main_training_output_dir = Path(output_dir_str).resolve()
    paths_to_generate_from: List[Path]
    if training_args_custom.generate_from_all_checkpoints:
        logger.info("Starte erweiterte Sample-Generierung fĂŒr verschiedene Checkpoints...")
        paths_to_generate_from = find_checkpoint_folders_for_generation(str(main_training_output_dir))
    else:
        paths_to_generate_from = [main_training_output_dir]
        logger.info(f"Generiere Samples nur vom finalen/besten Modell in: {main_training_output_dir}")

    if not paths_to_generate_from:
        logger.warning(f"Keine Checkpoint-Pfade in {main_training_output_dir} gefunden zum Testen.")
        return

    overall_sample_output_dir = main_training_output_dir / "generated_samples_all_checkpoints_audio"
    overall_sample_output_dir.mkdir(parents=True, exist_ok=True)
    
    script_dir_parent_for_display = Path(__file__).parent.resolve().parent
    logger.info(f"Gefundene Checkpoint-Pfade zum Testen (relativ zu {script_dir_parent_for_display}):")
    for p in paths_to_generate_from:
        logger.info(f"  - {get_display_path(p, script_dir_parent_for_display)}")

    generation_device = training_args_custom.device
    selected_conditioning_audio = None

    if data_args_original.original_dataset_dir:
        orig_dir = Path(data_args_original.original_dataset_dir)
        if not orig_dir.is_absolute():
             script_execution_dir = Path.cwd()
             orig_dir = (script_execution_dir / orig_dir).resolve()

        if orig_dir.is_dir():
            try:
                audio_files = list(orig_dir.rglob("*.wav")) + list(orig_dir.rglob("*.mp3")) + list(orig_dir.rglob("*.flac"))
                if audio_files:
                    selected_conditioning_audio = str(np.random.choice(audio_files).resolve())
                    logger.info(f"ZufÀllige Konditionierungs-Audiodatei ausgewÀhlt: {selected_conditioning_audio}")
                else: logger.warning(f"Keine .wav/.mp3/.flac Dateien in {orig_dir} fĂŒr Konditionierung gefunden.")
            except Exception as e:  logger.warning(f"Fehler beim Suchen nach Konditionierungs-Audio in {orig_dir}: {e}")
    
    if not selected_conditioning_audio and training_args_custom.sample_texts:
        logger.warning("Kein Konditionierungs-Audio gefunden. Versuche mit Default-Konditionierung des Modells (falls vorhanden).")
            
    texts_to_generate = training_args_custom.sample_texts
    if not texts_to_generate: texts_to_generate = ["Hallo, dies ist eine Standard-Sprachausgabe."]
    num_to_gen_per_ckpt = min(training_args_custom.num_samples_to_generate, len(texts_to_generate))

    for ckpt_path_for_t3 in paths_to_generate_from:
        ckpt_name_for_folder = ckpt_path_for_t3.name
        if ckpt_path_for_t3 == main_training_output_dir and not ckpt_path_for_t3.name.startswith("checkpoint-"):
            ckpt_name_for_folder = "final_best_model"

        logger.info(f"\n--- Teste Checkpoint: {ckpt_name_for_folder} (T3-Gewichte von: {get_display_path(ckpt_path_for_t3, script_dir_parent_for_display)}) ---")
        
        current_model_instance = None
        try:
            current_model_instance = ChatterboxTTS.from_local(
                ckpt_dir_for_model_weights=str(ckpt_path_for_t3),
                base_model_dir_for_static_components=base_model_path_for_static_components_str,
                device=generation_device
            )
            set_model_components_to_eval_mode(current_model_instance)
            logger.info(f"  Modell erfolgreich fĂŒr Checkpoint '{ckpt_name_for_folder}' geladen.")

            if selected_conditioning_audio:
                try:
                    logger.info(f"    Bereite Konditionierung mit Audio-Prompt vor: {Path(selected_conditioning_audio).name}")
                    current_model_instance.prepare_conditionals(audio_prompt_path=selected_conditioning_audio)
                except TypeError as te:
                     logger.error(f"    TypeError bei Konditionierung fĂŒr {ckpt_name_for_folder}: {te}. PrĂŒfen Sie Signatur von prepare_conditionals.", exc_info=True)
                except Exception as e_cond:
                    logger.error(f"    Anderer Fehler bei Konditionierung fĂŒr {ckpt_name_for_folder}: {e_cond}. Versuche mit Default.", exc_info=True)
            elif hasattr(current_model_instance, 'conds') and current_model_instance.conds is not None and current_model_instance.conds.is_prepared:
                 logger.info("    Verwende Default/geladene Konditionierung aus dem Modell.")
            else:
                logger.warning(f"    Kein Audio-Prompt & keine Default-Konditionierung fĂŒr {ckpt_name_for_folder}. Es wird versucht, ohne explizite Konditionierung zu generieren oder die prepare_conditionals-Methode des Modells wird Default-Verhalten haben.")


            ckpt_specific_sample_dir = overall_sample_output_dir / ckpt_name_for_folder
            ckpt_specific_sample_dir.mkdir(parents=True, exist_ok=True)

            for i in range(num_to_gen_per_ckpt):
                text_to_synth = texts_to_generate[i]
                logger.info(f"      Generiere Audio fĂŒr Text: \"{text_to_synth}\"")
                try:
                    wav_out_tensor, sr_out = current_model_instance.generate(text_to_synth)
                    wav_out_numpy = wav_out_tensor.squeeze().cpu().numpy()
                    filename_safe_text = "".join(c if c.isalnum() else "_" for c in text_to_synth[:30]).strip("_")
                    output_path = ckpt_specific_sample_dir / f"sample_{i+1}_{filename_safe_text}.wav"
                    sf.write(str(output_path), wav_out_numpy, sr_out)
                    logger.info(f"      Audio gespeichert unter: {output_path.relative_to(Path.cwd())}")
                except ValueError as ve:
                    logger.error(f"      ValueError bei Generierung fĂŒr '{text_to_synth}' mit {ckpt_name_for_folder}: {ve}.", exc_info=True)
                except Exception as e_gen:
                    logger.error(f"      Fehler bei Generierung fĂŒr '{text_to_synth}' mit {ckpt_name_for_folder}: {e_gen}", exc_info=True)
        except FileNotFoundError as e_load_fnf:
            logger.error(f"FileNotFoundError beim Laden von Checkpoint {ckpt_name_for_folder}: {e_load_fnf}.")
        except Exception as e_load:
            logger.error(f"Allgemeiner Fehler beim Laden/Verarbeiten von {ckpt_name_for_folder}: {e_load}", exc_info=True)
        finally:
            if current_model_instance is not None: del current_model_instance
            if isinstance(generation_device, torch.device) and generation_device.type == "cuda":
                 torch.cuda.empty_cache()
            elif isinstance(generation_device, str) and "cuda" in generation_device:
                 torch.cuda.empty_cache()
            logger.info(f"  Verarbeitung fĂŒr Checkpoint {ckpt_name_for_folder} abgeschlossen.")
    logger.info(f"\nAlle Sample-Generierungen abgeschlossen. Samples in: {overall_sample_output_dir.relative_to(Path.cwd())}")

# --- Main function main() ---
def main():
    if os.name == 'nt':
        try:
            import torch.multiprocessing as mp
            mp.set_sharing_strategy('file_system')
            logger.info("Set multiprocessing sharing strategy to 'file_system' for Windows.")
        except Exception as e:
            logger.warning(f"Could not set multiprocessing sharing strategy to 'file_system': {e}.")

    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
        handlers=[logging.StreamHandler()] # make sure logs go to stdout/stderr
    )
    # Also set the root logger level, in case other libraries log beneath it
    logging.getLogger().setLevel(logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN)
    datasets.utils.logging.set_verbosity_info()


    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)
    set_seed(training_args.seed)

    # Important note on evaluation and early stopping:
    if training_args.early_stopping_patience is not None and training_args.early_stopping_patience > 0:
        if training_args.evaluation_strategy == "no" or \
           (training_args.evaluation_strategy == "steps" and training_args.eval_steps >= training_args.max_steps and training_args.num_train_epochs > 0) or \
           (training_args.evaluation_strategy == "steps" and training_args.eval_steps > 100000) : # Heuristik fĂŒr sehr hohe eval_steps
            logger.warning(
                f"EarlyStoppingCallback is enabled (patience={training_args.early_stopping_patience}), "
                f"but evaluation_strategy is '{training_args.evaluation_strategy}' "
                f"with eval_steps={training_args.eval_steps}. "
                "Early stopping requires regular evaluations to function correctly. "
                "Consider setting evaluation_strategy to 'steps' and eval_steps to a reasonable "
                "value (e.g., equal to save_steps or logging_steps)."
            )

    base_model_path_for_static_and_initial_t3_str: Optional[str] = None
    original_model_dir_for_copy_logic: Optional[Path] = None

    script_dir = Path(__file__).parent.resolve()

    if model_args.local_model_dir:
        lm_path = Path(model_args.local_model_dir)
        if not lm_path.is_absolute(): lm_path = (script_dir / lm_path).resolve()
        logger.info(f"Verwende local_model_dir als Basis: {lm_path}")
        if not lm_path.is_dir():
             logger.error(f"local_model_dir {lm_path} nicht gefunden! Abbruch.")
             return
        base_model_path_for_static_and_initial_t3_str = str(lm_path)
        original_model_dir_for_copy_logic = lm_path
    else:
        repo_to_download = model_args.model_name_or_path or REPO_ID
        logger.info(f"Verwende Modell von Hugging Face Hub ({repo_to_download}) als Basis.")
        download_base_parent_dir = script_dir.parent / "model_hub_cache"
        repo_name_for_dir = repo_to_download.replace("/", "_").replace("\\", "_")
        snapshot_dir_path = download_base_parent_dir / repo_name_for_dir
        snapshot_dir_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Stelle sicher, dass das Originalmodell in {snapshot_dir_path} vorhanden ist.")
        from huggingface_hub import snapshot_download
        try:
            actual_snapshot_dir = snapshot_download(
                repo_id=repo_to_download, cache_dir=model_args.cache_dir,
                local_dir=str(snapshot_dir_path), local_dir_use_symlinks=False, # local_dir must be a str
                allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json"],
            )
            base_model_path_for_static_and_initial_t3_str = actual_snapshot_dir
            original_model_dir_for_copy_logic = Path(actual_snapshot_dir)
            logger.info(f"Originalmodell-Komponenten sind in {actual_snapshot_dir} verfĂŒgbar.")
        except Exception as e:
            logger.error(f"Fehler beim Herunterladen des Originalmodells von {repo_to_download}: {e}", exc_info=True)
            return
            
    if not base_model_path_for_static_and_initial_t3_str:
        logger.error("Konnte keinen Pfad fĂŒr Basismodellkomponenten bestimmen. Abbruch.")
        return

    target_t3_config = T3Config()
    logger.info(f"T3Config Standardwerte: max_text={target_t3_config.max_text_tokens}, max_speech={target_t3_config.max_speech_tokens}")
    target_t3_config.max_text_tokens = data_args.max_text_len
    target_t3_config.max_speech_tokens = data_args.max_speech_len
    logger.info(f"Ziel-T3Config fĂŒr Training und Initialisierung: "
                f"max_text_tokens={target_t3_config.max_text_tokens}, "
                f"max_speech_tokens={target_t3_config.max_speech_tokens}, "
                f"start_text_token={target_t3_config.start_text_token}, "
                f"speech_cond_prompt_len={target_t3_config.speech_cond_prompt_len}")

    logger.info(f"Lade initiales T3-Modell und passe Gewichte an Ziel-Config an (Quelle: {base_model_path_for_static_and_initial_t3_str})")
    try:
        temp_full_model_for_t3_init = ChatterboxTTS.from_local(
            ckpt_dir_for_model_weights=base_model_path_for_static_and_initial_t3_str,
            base_model_dir_for_static_components=base_model_path_for_static_and_initial_t3_str,
            device="cpu",
            target_t3_config_for_model=target_t3_config
        )
        t3_model_to_finetune = temp_full_model_for_t3_init.t3
        t3_config_for_trainer = t3_model_to_finetune.hp

        if t3_config_for_trainer.max_text_tokens != data_args.max_text_len or \
           t3_config_for_trainer.max_speech_tokens != data_args.max_speech_len:
            logger.error(
                f"FATAL: Mismatch in T3 Konfiguration nach dem Laden! "
                f"Modell max_text: {t3_config_for_trainer.max_text_tokens} vs DataArgs: {data_args.max_text_len}. "
                f"Modell max_speech: {t3_config_for_trainer.max_speech_tokens} vs DataArgs: {data_args.max_speech_len}."
            )
            return
        del temp_full_model_for_t3_init
        logger.info("Initiales T3-Modell und Konfiguration erfolgreich geladen und an Ziel-LĂ€ngen angepasst.")
    except TypeError as te:
        logger.error(f"TypeError beim initialen Laden von ChatterboxTTS.from_local: {te}.", exc_info=True)
        return
    except Exception as e_load_init:
        logger.error(f"Allgemeiner Fehler beim initialen Laden des T3-Modells: {e_load_init}", exc_info=True)
        return

    for param in t3_model_to_finetune.parameters(): param.requires_grad = True
    if model_args.freeze_voice_encoder: # these should be handled by ChatterboxTTS itself when only T3 is trained
        logger.warning("freeze_voice_encoder/freeze_s3gen are not implemented directly in this script, "
                       "since only the T3 module is trained. The ChatterboxTTS class should handle freezing "
                       "the other components internally if needed.")

    preprocessed_data_path = Path(data_args.preprocessed_data_dir)
    if not preprocessed_data_path.is_absolute():
        preprocessed_data_path = (script_dir / preprocessed_data_path).resolve()
    logger.info(f"Lade vorverarbeitete Daten von: {preprocessed_data_path}")
    if not preprocessed_data_path.is_dir():
        raise FileNotFoundError(f"Vorverarbeitetes Datenverzeichnis nicht gefunden: {preprocessed_data_path}")
    all_pt_files = sorted(list(preprocessed_data_path.glob("*.pt")))
    if not all_pt_files:
        raise FileNotFoundError(f"Keine .pt Dateien in {preprocessed_data_path} gefunden")
    
    rng = np.random.default_rng(training_args.seed)
    rng.shuffle(all_pt_files)
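    # Note: shuffling with a seeded RNG on the sorted file list keeps the train/eval split reproducible across runs.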

    train_pt_files: List[Path]
    eval_pt_files: Optional[List[Path]] = None
    if training_args.do_eval and data_args.eval_split_size > 0 and len(all_pt_files) > 1 :
        num_eval_samples = int(len(all_pt_files) * data_args.eval_split_size)
        if num_eval_samples == 0 and len(all_pt_files) > 1 : num_eval_samples = 1
        if num_eval_samples >= len(all_pt_files): num_eval_samples = max(0, len(all_pt_files) -1)
        
        if num_eval_samples > 0 and (len(all_pt_files) - num_eval_samples) > 0 :
            train_pt_files = all_pt_files[:-num_eval_samples]
            eval_pt_files = all_pt_files[-num_eval_samples:]
            logger.info(f"Dateien aufgeteilt: {len(train_pt_files)} Training, {len(eval_pt_files)} Evaluation.")
        else:
            train_pt_files = all_pt_files
            logger.warning("Split nicht möglich oder resultiert in 0 Trainings-/Evaluationssamples. Verwende alle fĂŒr Training.")
    else:
        train_pt_files = all_pt_files
        logger.info(f"Verwende alle {len(train_pt_files)} Dateien fĂŒr Training (kein Eval-Split).")

    train_dataset = PreprocessedSpeechDataset(train_pt_files) if train_pt_files else None
    eval_dataset = PreprocessedSpeechDataset(eval_pt_files) if eval_pt_files and training_args.do_eval else None

    if not train_dataset:
        logger.error("Keine Trainingsdaten. Abbruch.")
        return

    data_collator = SpeechDataCollator(t3_config_for_trainer)
    hf_trainable_model = T3ForFineTuning(t3_model_to_finetune, t3_config_for_trainer)

    callbacks_list = []
    if training_args.early_stopping_patience is not None and training_args.early_stopping_patience > 0:
        callbacks_list.append(EarlyStoppingCallback(early_stopping_patience=training_args.early_stopping_patience))

    # Add the DetailedMemoryClearingCallback
    if training_args.do_train or training_args.do_eval: # the callback is useful for both
        callbacks_list.append(DetailedMemoryClearingCallback())

    trainer_instance = Trainer(
        model=hf_trainable_model, args=training_args,
        train_dataset=train_dataset, eval_dataset=eval_dataset,
        data_collator=data_collator, callbacks=callbacks_list if callbacks_list else None,
    )

    if training_args.do_train:
        logger.info("*** Training des T3-Modells ***")
        resume_from_ckt = training_args.resume_from_checkpoint
        if isinstance(resume_from_ckt, str):
            resume_path = Path(resume_from_ckt)
            if not resume_path.is_absolute(): resume_path = (script_dir / resume_path).resolve()
            if not resume_path.exists():
                logger.warning(f"resume_from_checkpoint Pfad {resume_path} existiert nicht. Starte neu.")
                resume_from_ckt = None
            else: resume_from_ckt = str(resume_path)
        
        train_result = trainer_instance.train(resume_from_checkpoint=resume_from_ckt)
        # Manually save the bare T3 model, since Trainer.save_model() saves the wrapper
        trainer_instance.save_model() # saves the HF wrapper (T3ForFineTuning)
        logger.info(f"HF Trainer model (wrapper) saved in: {training_args.output_dir}")

        logger.info("Saving bare T3 weights (t3_cfg.safetensors)...")
        model_to_save_t3 = trainer_instance.model.t3 if hasattr(trainer_instance.model, 't3') else \
                           (trainer_instance.model.module.t3 if hasattr(trainer_instance.model, 'module') and hasattr(trainer_instance.model.module, 't3') else None)

        if model_to_save_t3:
            finetuned_t3_state_dict = model_to_save_t3.state_dict()
            output_t3_safetensor_path = Path(training_args.output_dir).resolve() / "t3_cfg.safetensors"
            from safetensors.torch import save_file as save_safetensors_file
            save_safetensors_file(finetuned_t3_state_dict, str(output_t3_safetensor_path))
            logger.info(f"Reine T3-Gewichte gespeichert: {output_t3_safetensor_path}")
            
            # Speichere auch die T3Config des trainierten Modells
            output_t3_config_path = Path(training_args.output_dir).resolve() / "t3_model_config.json"
            with open(output_t3_config_path, 'w') as f:
                json.dump(model_to_save_t3.hp.__dict__, f, indent=4) # Speichere T3Config (hp)
            logger.info(f"T3-Modellkonfiguration (aus t3.hp) gespeichert: {output_t3_config_path}")

        else: logger.error("Konnte T3-Submodul nicht extrahieren zum separaten Speichern!")


        if original_model_dir_for_copy_logic and original_model_dir_for_copy_logic.is_dir():
            logger.info(f"Kopiere statische Komponenten von {original_model_dir_for_copy_logic} nach {training_args.output_dir}")
            final_output_dir = Path(training_args.output_dir).resolve()
            files_to_copy = ["ve.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]
            for f_name in files_to_copy:
                src_path = original_model_dir_for_copy_logic / f_name
                dest_path = final_output_dir / f_name
                if src_path.exists():
                    try:
                        shutil.copy2(src_path, dest_path)
                        logger.info(f"  '{f_name}' kopiert nach {dest_path}")
                    except Exception as e_copy: logger.warning(f"  Fehler beim Kopieren von {src_path} nach {dest_path}: {e_copy}")
                elif f_name != "conds.pt": # conds.pt ist optional
                    logger.warning(f"  Kritische Quelldatei {src_path} nicht gefunden zum Kopieren.")
                else:
                    logger.info(f"  Optionale Quelldatei {src_path} (conds.pt) nicht gefunden, wird nicht kopiert.")
            logger.info(f"Modellkomponenten strukturiert in {final_output_dir}")
        else: logger.warning(f"Kein gĂŒltiger Pfad fĂŒr original_model_dir_for_copy_logic ({original_model_dir_for_copy_logic}). Statische Komponenten nicht kopiert.")

        metrics = train_result.metrics
        trainer_instance.log_metrics("train", metrics)
        trainer_instance.save_metrics("train", metrics)
        trainer_instance.save_state()

    if training_args.do_eval and eval_dataset:
        logger.info("*** Evaluiere T3-Modell ***")
        metrics = trainer_instance.evaluate()
        trainer_instance.log_metrics("eval", metrics)
        trainer_instance.save_metrics("eval", metrics)

    if training_args.generate_samples_on_finish:
        can_generate_samples = False
        final_output_dir_path = Path(training_args.output_dir).resolve()
        # Check whether the model was trained or a checkpoint exists that can be loaded,
        # and whether the files required for sample generation are present.
        model_files_exist = (final_output_dir_path / "t3_cfg.safetensors").exists() or \
                            (final_output_dir_path / "model.safetensors").exists() # the HF Trainer saves model.safetensors

        if (training_args.do_train or training_args.resume_from_checkpoint or training_args.do_eval) and \
           final_output_dir_path.is_dir() and model_files_exist:
            can_generate_samples = True
            logger.info("Bedingungen fĂŒr Sample-Generierung nach Training/Evaluation/Resume erfĂŒllt.")
        else:
            logger.warning(f"Kein valides Modell/Output-Verzeichnis ({final_output_dir_path}, Files exist: {model_files_exist}) fĂŒr Sample-Generierung gefunden oder kein Training/Eval/Resume durchgefĂŒhrt.")
        
        if can_generate_samples:
            logger.info("Starte Sample-Generierung...")
            # FĂŒr die Sample-Generierung verwenden wir den Output-Ordner des Trainings,
            # da dort alle Komponenten (inkl. der kopierten statischen) liegen sollten.
            path_for_statics_for_generation = str(final_output_dir_path)

            # Wenn nur evaluiert wurde ohne Training und ohne Resume von einem *Trainings*-Checkpoint,
            # könnten statische Komponenten fehlen, falls nicht explizit dorthin kopiert.
            # Aber die Logik oben sollte die statischen Teile in den final_output_dir_path kopieren.
            if not (final_output_dir_path / "ve.safetensors").exists():
                 logger.warning(f"ve.safetensors nicht in {final_output_dir_path} gefunden. Sample-Generierung könnte fehlschlagen oder auf Basismodell zurĂŒckgreifen.")
                 # Fallback, falls die Kopierlogik nicht alles erfasst hat oder do_train=False war
                 if base_model_path_for_static_and_initial_t3_str and \
                    Path(base_model_path_for_static_and_initial_t3_str, "ve.safetensors").exists():
                    path_for_statics_for_generation = base_model_path_for_static_and_initial_t3_str
                    logger.info(f"  Verwende statische Komponenten von: {path_for_statics_for_generation} fĂŒr Sample Generierung")
            
            try:
                generate_and_save_samples_from_checkpoints(
                    output_dir_str=str(final_output_dir_path), # where the checkpoints of the current run live
                    data_args_original=data_args,
                    training_args_custom=training_args,
                    base_model_path_for_static_components_str=path_for_statics_for_generation # where VE, S3Gen etc. can be found
                )
            except ValueError as ve_path:
                 logger.error(f"Error in the path logic of generate_and_save_samples_from_checkpoints: {ve_path}", exc_info=True)
                 overall_sample_dir_name = "generated_samples_all_checkpoints_audio" # known from the function above
                 logger.info(f"  The absolute path for generated samples would be: {final_output_dir_path / overall_sample_dir_name}")
            except Exception as e_gen_samples:
                 logger.error(f"General error during sample generation: {e_gen_samples}", exc_info=True)

        else:
            logger.info("Überspringe Sample-Generierung: Bedingungen nicht erfĂŒllt.")
            
    logger.info("Finetuning-Skript beendet.")

if __name__ == "__main__":
    main()

How to start preprocessing: python preprocess_data.py --model_load_path "ResembleAI/chatterbox" --is_hub_model --dataset_dir ./audio_data --output_dir ./processed_tts_data_AUGMENTED_WITH_PARAMS_2048_4096 --max_text_len 2048 --max_speech_len 4096 --silence_padding_ms 500 --num_workers 14 --augment_data --aug_speeds 0.9 --aug_volumes_db -2 2

How to start training on an RTX 5090 (approx. 29 GB of 32 GB VRAM; with per-device batch size 1 and 32 gradient accumulation steps, the effective batch size is 32): python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5

How I save an intermediate checkpoint for testing (note the --max_steps value: it is set just slightly above the last checkpoint's step, so training resumes, runs a few steps, and saves again right away): python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5 --resume_from_checkpoint "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD/checkpoint-1000" --max_steps 1010

How to continue (from the last checkpoint): python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5 --resume_from_checkpoint "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD/checkpoint-1000"
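
If you would rather not hard-code the checkpoint number when resuming, a small helper like the sketch below (the output_dir is the example path from the commands above) picks the newest checkpoint-&lt;step&gt; folder written by the Trainer:

```python
# latest_checkpoint.py -- sketch for finding the newest Trainer checkpoint (example output_dir)
from pathlib import Path

output_dir = Path("../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD")
checkpoints = [d for d in output_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")]
latest = max(checkpoints, key=lambda d: int(d.name.split("-")[-1]), default=None)
print(latest)  # pass this path to --resume_from_checkpoint
```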

havok2-htwo avatar Jun 26 '25 21:06 havok2-htwo

I also edited tts.py:

# tts.py

from dataclasses import dataclass, field
from pathlib import Path
import logging
from typing import Optional, Tuple, Dict, Any, Union
import inspect # for debugging

import librosa
import numpy as np
import torch
import perth
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from .models.t3 import T3
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config # import T3Config


logger = logging.getLogger(__name__)

REPO_ID = "ResembleAI/chatterbox" # Standard-Repo, kann ĂŒberschrieben werden

def punc_norm(text: str) -> str:
    if not text:
        logger.warning("punc_norm erhielt leeren Text. Standardtext wird zurĂŒckgegeben.")
        return "You need to add some text for me to talk."
    if text[0].islower():
        text = text[0].upper() + text[1:]
    text = " ".join(text.split())
    punc_to_replace = [
        ("...", ", "), ("
", ", "), (":", ","), (" - ", ", "), (";", ", "),
        ("—", "-"), ("–", "-"), (" ,", ","), ("“", "\""), ("”", "\""),
        ("‘", "'"), ("’", "'"),
    ]
    for old, new in punc_to_replace:
        text = text.replace(old, new)
    text = text.rstrip(" ")
    if not any(text.endswith(p) for p in {".", "!", "?", "-", ","}):
        text += "."
    return text

@dataclass
class Conditionals:
    t3: T3Cond
    gen: Dict[str, Any]

    def to(self, device: Union[torch.device, str]) -> 'Conditionals':
        target_device = device if isinstance(device, torch.device) else torch.device(device)
        # Use the .to() method of T3Cond
        self.t3 = self.t3.to(device=target_device)

        for k, v in self.gen.items():
            if torch.is_tensor(v):
                self.gen[k] = v.to(target_device)
        return self

    def save(self, fpath: Path):
        # T3Cond has its own .save(), but that stores __dict__.
        # We also store __dict__ here for consistency with T3Cond.load(**kwargs)
        arg_dict = dict(t3=self.t3.__dict__, gen=self.gen)
        torch.save(arg_dict, fpath)
        logger.info(f"Conditionals gespeichert unter: {fpath}")

    @classmethod
    def load(cls, fpath: Path, map_location: str = "cpu") -> 'Conditionals':
        map_location_device = torch.device(map_location)
        loaded_data = torch.load(fpath, map_location=map_location_device, weights_only=None)
        logger.info(f"Lade Conditionals von: {fpath} auf {map_location_device}")

        t3_cond_attributes_dict = loaded_data['t3']
        if not isinstance(t3_cond_attributes_dict, dict):
            raise TypeError(f"Unerwarteter Typ fĂŒr 't3' Daten in Conditionals: {type(t3_cond_attributes_dict)}. Erwartet dict.")

        processed_t3_cond_attrs = {}
        for key, value in t3_cond_attributes_dict.items():
            if isinstance(value, torch.Tensor):
                processed_t3_cond_attrs[key] = value.to(map_location_device)
            else:
                processed_t3_cond_attrs[key] = value
        
        try:
            t3_instance = T3Cond(**processed_t3_cond_attrs)
        except TypeError as e:
            logger.error(f"Fehler beim Instanziieren von T3Cond mit **kwargs: {e}")
            logger.error(f"VerfĂŒgbare SchlĂŒssel im verarbeiteten T3Cond Daten-Dictionary: {list(processed_t3_cond_attrs.keys())}")
            try:
                sig = inspect.signature(T3Cond)
                logger.error(f"Erwartete Parameter fĂŒr T3Cond: {list(sig.parameters.keys())}")
            except Exception: pass
            raise

        gen_data_dict = loaded_data['gen']
        processed_gen_data_dict = {
            k: v.to(map_location_device) if isinstance(v, torch.Tensor) else v
            for k, v in gen_data_dict.items()
        }
        return cls(t3_instance, processed_gen_data_dict)


class ChatterboxTTS:
    ENC_COND_LEN = 6 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: EnTokenizer,
        device: torch.device,
        conds: Optional[Conditionals] = None,
        s3_tokenizer_sample_rate: int = S3_SR
    ):
        self.sr = S3GEN_SR
        self.t3 = t3 # should already have the correct T3Config (hp attribute)
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()
        self.s3_tokenizer_sample_rate = s3_tokenizer_sample_rate
        logger.info(f"ChatterboxTTS initialisiert auf GerÀt: {self.device}. T3 Config Max Text: {getattr(self.t3.hp, 'max_text_tokens', 'N/A')}, Max Speech: {getattr(self.t3.hp, 'max_speech_tokens', 'N/A')}")

    def to(self, device: Union[torch.device, str]) -> 'ChatterboxTTS':
        target_device = device if isinstance(device, torch.device) else torch.device(device)
        if self.device == target_device: # already on the target device
            return self
        logger.info(f"Moving ChatterboxTTS from {self.device} to device: {target_device}")
        self.device = target_device
        self.t3.to(target_device)
        self.s3gen.to(target_device)
        self.ve.to(target_device)
        if self.conds:
            self.conds.to(target_device)
        return self

    @classmethod
    def from_local(cls,
                   ckpt_dir_for_model_weights: str,
                   base_model_dir_for_static_components: str,
                   device: str, # passed as a string by the caller
                   target_t3_config_for_model: Optional[T3Config] = None # changed: target_t3_config_for_MODEL
                   ) -> 'ChatterboxTTS':

        logger.info(f"ChatterboxTTS.from_local:")
        logger.info(f"  Lade T3-spezifische Gewichte von: {ckpt_dir_for_model_weights}")
        logger.info(f"  Lade statische Komponenten von: {base_model_dir_for_static_components}")
        if target_t3_config_for_model:
            logger.info(f"  Verwende ĂŒbergebene target_t3_config fĂŒr T3-Initialisierung: "
                        f"max_text={target_t3_config_for_model.max_text_tokens}, "
                        f"max_speech={target_t3_config_for_model.max_speech_tokens}")
        else:
            logger.info("  Keine target_t3_config ĂŒbergeben, T3 wird mit Default-Config initialisiert.")

        target_device_obj = torch.device(device)
        if target_device_obj.type == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA angefordert, aber nicht verfĂŒgbar. Fallback auf CPU.")
            target_device_obj = torch.device("cpu")
        elif target_device_obj.type == "mps" and not torch.backends.mps.is_available(): # Korrigiert zu target_device_obj.type
            logger.warning("MPS angefordert, aber nicht verfĂŒgbar. Fallback auf CPU.")
            target_device_obj = torch.device("cpu")

        path_for_t3_specific_weights = Path(ckpt_dir_for_model_weights).resolve()
        path_for_static_files = Path(base_model_dir_for_static_components).resolve()

        # Validate the paths
        if not path_for_static_files.is_dir():
            raise FileNotFoundError(f"Base model path for static components not found: {path_for_static_files}")
        if not path_for_t3_specific_weights.is_dir() and not path_for_t3_specific_weights.is_file(): # may also be a file
            # Check whether it is a file, in case ckpt_dir_for_model_weights points directly at t3_cfg.safetensors
            if not (path_for_t3_specific_weights.is_file() and \
                any(str(path_for_t3_specific_weights).endswith(x) for x in [".safetensors", ".bin"])):
                 raise FileNotFoundError(f"Path for T3 model weights is neither a directory nor a valid file: {path_for_t3_specific_weights}")


        load_device_cpu_str = "cpu" # FĂŒr safetensors device-Argument

        # Load static components (VE, S3Gen, Tokenizer) - always from base_model_dir...
        ve = VoiceEncoder()
        ve.load_state_dict(load_file(path_for_static_files / "ve.safetensors", device=load_device_cpu_str))
        ve.eval()

        s3gen = S3Gen()
        s3gen.load_state_dict(load_file(path_for_static_files / "s3gen.safetensors", device=load_device_cpu_str), strict=False)
        s3gen.eval()

        tokenizer = EnTokenizer(str(path_for_static_files / "tokenizer.json"))

        # Load conditionals - always from base_model_dir...
        conds = None
        conds_path_static = path_for_static_files / "conds.pt"
        if conds_path_static.exists():
            try:
                conds = Conditionals.load(conds_path_static, map_location=load_device_cpu_str)
            except Exception as e_conds:
                logger.warning(f"Konnte conds.pt nicht laden von {conds_path_static}: {e_conds}", exc_info=True)
        else:
            logger.info(f"Keine conds.pt in {path_for_static_files} (Basis fĂŒr statische Komponenten) gefunden.")

        # --- Initialize T3 and load its weights ---
        # Initialize T3 with the target config if given, otherwise with the default from T3Config.py.
        # This instance (final_t3_instance) is the one that will be returned.
        if target_t3_config_for_model:
            final_t3_instance = T3(hp=target_t3_config_for_model)
        else:
            final_t3_instance = T3() # uses the default config from T3Config.py

        # Locate the T3 weights file
        t3_weights_file_to_load: Optional[Path] = None
        if path_for_t3_specific_weights.is_file() and any(str(path_for_t3_specific_weights).endswith(x) for x in [".safetensors", ".bin"]):
            t3_weights_file_to_load = path_for_t3_specific_weights
            logger.info(f"  T3-Gewichtsdatei direkt als Pfad angegeben: {t3_weights_file_to_load.name}")
        elif path_for_t3_specific_weights.is_dir():
            potential_ckpt_names = ["t3_cfg.safetensors", "model.safetensors", "pytorch_model.bin"]
            for name in potential_ckpt_names:
                if (path_for_t3_specific_weights / name).exists():
                    t3_weights_file_to_load = path_for_t3_specific_weights / name
                    logger.info(f"  T3-Gewichtsdatei in '{path_for_t3_specific_weights.name}' gefunden: {t3_weights_file_to_load.name}")
                    break
        
        # Fallback if ckpt_dir == base_dir and a t3_cfg.safetensors lives there (typical for the base model)
        if not t3_weights_file_to_load and path_for_t3_specific_weights == path_for_static_files:
            if (path_for_static_files / "t3_cfg.safetensors").exists():
                 t3_weights_file_to_load = path_for_static_files / "t3_cfg.safetensors"
                 logger.info(f"  No specific T3 file in '{path_for_t3_specific_weights.name}', using 't3_cfg.safetensors' from base: {path_for_static_files.name}")

        if not t3_weights_file_to_load:
            raise FileNotFoundError(f"Keine passende T3-Gewichtsdatei in {path_for_t3_specific_weights} oder als Fallback in {path_for_static_files} gefunden.")

        logger.info(f"  Lade T3-Gewichte von: {t3_weights_file_to_load} auf CPU.")
        if str(t3_weights_file_to_load).endswith(".safetensors"):
            t3_state_raw = load_file(t3_weights_file_to_load, device=load_device_cpu_str)
        elif str(t3_weights_file_to_load).endswith(".bin"):
            t3_state_raw = torch.load(t3_weights_file_to_load, map_location=load_device_cpu_str)
        else:
            raise ValueError(f"Unbekanntes Dateiformat fĂŒr T3-Gewichte: {t3_weights_file_to_load}")

        # Prefix cleanup for HF Trainer models
        cleaned_t3_state = {}
        if t3_weights_file_to_load.name in ["model.safetensors", "pytorch_model.bin"]: # names typical for HF Trainer
            if any(key.startswith("t3.") for key in t3_state_raw.keys()):
                logger.info("  Stripping 't3.' prefix from T3 state dict keys.")
                for k, v_tensor in t3_state_raw.items():
                    if k.startswith("t3."): cleaned_t3_state[k[3:]] = v_tensor
                if not cleaned_t3_state: # fallback if no keys matched
                    logger.warning("  Prefix stripping produced an empty state dict for T3, using the raw state dict.")
                    cleaned_t3_state = t3_state_raw
            else: # no "t3." prefixes, but it is an HF Trainer file
                logger.info(f"  Using the state dict from '{t3_weights_file_to_load.name}' directly for T3 (no 't3.' prefix found).")
                cleaned_t3_state = t3_state_raw
        else: # e.g. t3_cfg.safetensors, should already have the correct keys
            logger.info(f"  Using the T3 state dict from '{t3_weights_file_to_load.name}' directly.")
            cleaned_t3_state = t3_state_raw

        if not cleaned_t3_state:
            raise ValueError(f"T3 state dict is empty after loading and cleanup attempts for {t3_weights_file_to_load}.")

        # --- RESIZE POSITIONAL EMBEDDINGS IF NEEDED ---
        # `final_t3_instance` was initialized with the target config (or the default).
        # Its embedding layers therefore already have the target shapes.
        pos_emb_map = {
            "text_pos_emb.emb.weight": final_t3_instance.text_pos_emb.emb if hasattr(final_t3_instance, "text_pos_emb") else None,
            "speech_pos_emb.emb.weight": final_t3_instance.speech_pos_emb.emb if hasattr(final_t3_instance, "speech_pos_emb") else None
        }

        for emb_key_in_checkpoint, target_embedding_layer in pos_emb_map.items():
            if target_embedding_layer is None:
                logger.debug(f"Ziel-Embedding-Layer fĂŒr '{emb_key_in_checkpoint}' nicht in final_t3_instance gefunden. Überspringe Anpassung.")
                continue
            
            if emb_key_in_checkpoint in cleaned_t3_state:
                ckpt_emb_weights = cleaned_t3_state[emb_key_in_checkpoint]
                target_emb_weights_tensor = target_embedding_layer.weight # nn.Embedding.weight

                ckpt_shape = ckpt_emb_weights.shape
                target_shape = target_emb_weights_tensor.shape

                if ckpt_shape[0] != target_shape[0]: # sequence-length dimension differs
                    logger.warning(
                        f"Resizing positional embedding '{emb_key_in_checkpoint}': "
                        f"checkpoint shape {ckpt_shape} -> model shape {target_shape}."
                    )
                    # Clone the target model's freshly initialized weights;
                    # these are then partially overwritten with the checkpoint weights.
                    new_weights_for_model = target_emb_weights_tensor.data.clone()

                    len_to_copy = min(ckpt_shape[0], target_shape[0])
                    # The embedding dimension (dim=1) should match.
                    if ckpt_shape[1] != target_shape[1]:
                        logger.error(f"FATAL: embedding-dimension mismatch for '{emb_key_in_checkpoint}'. "
                                     f"Checkpoint: {ckpt_shape[1]}, model: {target_shape[1]}. Loading will probably fail.")
                        # One could abort here, or only copy when the dims actually match.
                        # For now we continue and rely on load_state_dict to catch the error if the dims do not fit.

                    dim_to_copy = min(ckpt_shape[1], target_shape[1]) # just to be safe

                    new_weights_for_model[:len_to_copy, :dim_to_copy] = ckpt_emb_weights[:len_to_copy, :dim_to_copy]
                    cleaned_t3_state[emb_key_in_checkpoint] = new_weights_for_model
                    logger.info(f"  '{emb_key_in_checkpoint}' resized to length {len_to_copy}.")
                # If the shapes already match, nothing to do.
            else:
                logger.warning(f"Positional embedding '{emb_key_in_checkpoint}' not found in the checkpoint state dict. "
                               "It will keep its initialized values in the model.")
        # --- END OF RESIZING ---

        missing_keys, unexpected_keys = final_t3_instance.load_state_dict(cleaned_t3_state, strict=False)
        if missing_keys: logger.warning(f"  Missing keys when loading the T3 state dict: {missing_keys}")
        if unexpected_keys: logger.warning(f"  Unexpected keys when loading the T3 state dict: {unexpected_keys}")
        final_t3_instance.eval() # T3 is now on CPU with (hopefully) correct weights

        # Build the ChatterboxTTS instance with all components on CPU
        model_instance = cls(final_t3_instance, s3gen, ve, tokenizer, torch.device(load_device_cpu_str), conds=conds)

        # Finally move the ENTIRE model to the target device
        return model_instance.to(target_device_obj)


    @classmethod
    def from_pretrained(cls, device: str, target_t3_config_for_model: Optional[T3Config] = None) -> 'ChatterboxTTS': # target_config here as well
        logger.info(f"Loading pretrained model '{REPO_ID}' from the Hugging Face Hub.")
        try:
            snapshot_dir = snapshot_download(
                repo_id=REPO_ID,
                allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json"] # config.json for T3Config
            )
            logger.info(f"Model snapshot downloaded to: {snapshot_dir}")
        except Exception as e:
            logger.error(f"Error downloading the model from the Hub {REPO_ID}: {e}", exc_info=True)
            raise
        
        # If target_t3_config_for_model was not passed explicitly, we would try to load
        # the config from the downloaded snapshot, in case the user has not done so already.
        # If target_t3_config_for_model is None here, T3 is initialized with the default config.
        # That is fine as long as the user adjusts the lengths later during fine-tuning.
        # The positional-embedding resizing in from_local then works with the default T3 config.
        # Ideally, the caller of from_pretrained supplies the target config.
        if target_t3_config_for_model is None:
            logger.info("No explicit target_t3_config passed to from_pretrained. "
                        "T3 will be initialized with the default config unless the base snapshot contains a loadable config "
                        "(which from_local does not pick up automatically unless target_t3_config is passed).")


        return cls.from_local(
            ckpt_dir_for_model_weights=snapshot_dir, # T3 weights also come from here
            base_model_dir_for_static_components=snapshot_dir, # static parts too
            device=device,
            target_t3_config_for_model=target_t3_config_for_model # pass through
        )

    def prepare_conditionals(self, wav_fpath: str, exaggeration: float = 0.5) -> Optional[Conditionals]:
        logger.info(f"Bereite Konditionierung mit Audio-Prompt vor: {wav_fpath}, Exaggeration: {exaggeration}")
        try:
            s3gen_ref_wav, sr_orig = librosa.load(wav_fpath, sr=None)
            if sr_orig != self.sr: # self.sr is S3GEN_SR
                s3gen_ref_wav = librosa.resample(s3gen_ref_wav, orig_sr=sr_orig, target_sr=self.sr)

            s3gen_ref_wav_tensor = torch.from_numpy(s3gen_ref_wav)
            if s3gen_ref_wav_tensor.shape[0] < self.DEC_COND_LEN:
                logger.warning(f"Audio-Prompt {wav_fpath} kĂŒrzer ({s3gen_ref_wav_tensor.shape[0]}) als DEC_COND_LEN ({self.DEC_COND_LEN}). Padding.")
                s3gen_ref_wav_cut_tensor = F.pad(s3gen_ref_wav_tensor, (0, self.DEC_COND_LEN - s3gen_ref_wav_tensor.shape[0]))
            else:
                s3gen_ref_wav_cut_tensor = s3gen_ref_wav_tensor[:self.DEC_COND_LEN]
            s3gen_ref_wav_cut_np = s3gen_ref_wav_cut_tensor.numpy()

            s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav_cut_np, self.sr, device=self.device)

            if self.s3_tokenizer_sample_rate != self.sr:
                ref_16k_wav_np = librosa.resample(s3gen_ref_wav, orig_sr=self.sr, target_sr=self.s3_tokenizer_sample_rate)
            else:
                ref_16k_wav_np = s3gen_ref_wav

            t3_cond_prompt_tokens = None
            # Use the config of the current T3 instance (self.t3.hp)
            t3_config_current = self.t3.hp # this should be the correctly initialized T3Config
            plen = getattr(t3_config_current, 'speech_cond_prompt_len', 0)

            if plen > 0:
                s3_tokzr = self.s3gen.tokenizer
                ref_16k_wav_tensor_prompt = torch.from_numpy(ref_16k_wav_np)
                if ref_16k_wav_tensor_prompt.shape[0] < self.ENC_COND_LEN:
                    prompt_audio_segment_tensor = F.pad(ref_16k_wav_tensor_prompt, (0, self.ENC_COND_LEN - ref_16k_wav_tensor_prompt.shape[0]))
                else:
                    prompt_audio_segment_tensor = ref_16k_wav_tensor_prompt[:self.ENC_COND_LEN]
                prompt_audio_segment_np = prompt_audio_segment_tensor.numpy()

                tokens_from_16k_batch, _ = s3_tokzr.forward([prompt_audio_segment_np], max_len=plen)
                if tokens_from_16k_batch is not None and tokens_from_16k_batch.numel() > 0:
                    t3_cond_prompt_tokens = torch.atleast_2d(tokens_from_16k_batch).to(self.device)
                else:
                    t3_cond_prompt_tokens = torch.zeros((1, plen), dtype=torch.long, device=self.device)
            else:
                logger.info("T3 speech_cond_prompt_len ist 0. Keine Sprachprompt-Tokens fĂŒr T3 Konditionierung.")

            ve_embed_np = self.ve.embeds_from_wavs([ref_16k_wav_np], sample_rate=self.s3_tokenizer_sample_rate)
            ve_embed = torch.from_numpy(ve_embed_np).mean(axis=0, keepdim=True).to(self.device)
            emotion_tensor = exaggeration * torch.ones(1, 1, 1, device=self.device)

            t3_cond_instance = T3Cond(
                speaker_emb=ve_embed,
                cond_prompt_speech_tokens=t3_cond_prompt_tokens,
                emotion_adv=emotion_tensor
                # clap_emb and cond_prompt_speech_emb stay None (default from T3Cond)
            ).to(device=self.device) # make sure the T3Cond instance is on the device

            processed_s3gen_ref_dict = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in s3gen_ref_dict.items()
            }

            new_conds = Conditionals(t3_cond_instance, processed_s3gen_ref_dict)
            self.conds = new_conds # important: update self.conds
            logger.info(f"Conditioning prepared successfully and stored on the instance.")
            return new_conds
        except Exception as e:
            logger.error(f"Fehler in prepare_conditionals fĂŒr {wav_fpath}: {e}", exc_info=True)
            self.conds = None
            return None

    def generate(
        self,
        text: str,
        audio_prompt_path: Optional[str] = None,
        exaggeration: Optional[float] = None,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
    ) -> Tuple[torch.Tensor, int]:
        if exaggeration is None: exaggeration = 0.5

        current_conds_to_use = self.conds

        if audio_prompt_path:
            prepared_conds = self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
            if prepared_conds is None:
                raise RuntimeError(f"Konnte Konditionierung nicht mit {audio_prompt_path} vorbereiten.")
            current_conds_to_use = prepared_conds
        elif not current_conds_to_use:
            raise ValueError("Keine Konditionierung vorhanden. Bitte audio_prompt_path angeben oder Default laden.")
        elif current_conds_to_use.t3.emotion_adv is None or exaggeration != current_conds_to_use.t3.emotion_adv.item():
            logger.info(f"Aktualisiere Exaggeration in bestehenden Konditionierungen auf: {exaggeration}")
            _cond_t3_old: T3Cond = current_conds_to_use.t3
            new_emotion_adv = exaggeration * torch.ones(1, 1, 1, device=self.device)
            current_conds_to_use.t3 = T3Cond( # create a new instance, since T3Cond is a dataclass
                speaker_emb=_cond_t3_old.speaker_emb,
                clap_emb=_cond_t3_old.clap_emb,
                cond_prompt_speech_tokens=_cond_t3_old.cond_prompt_speech_tokens,
                cond_prompt_speech_emb=_cond_t3_old.cond_prompt_speech_emb,
                emotion_adv=new_emotion_adv # updated
            ).to(device=self.device)

        current_conds_to_use = current_conds_to_use.to(self.device) # make sure it is on the device

        text_normalized = punc_norm(text)
        text_tokens_single = self.tokenizer.text_to_tokens(text_normalized).to(self.device)
        text_tokens_batch = text_tokens_single.unsqueeze(0)

        if cfg_weight > 0.0:
            text_tokens_batch = torch.cat([text_tokens_batch, text_tokens_batch], dim=0)

        # Use the config of the current T3 instance
        t3_config_current = self.t3.hp
        sot = getattr(t3_config_current, 'start_text_token', 0)
        eot = getattr(t3_config_current, 'stop_text_token', 1)
        if not (hasattr(t3_config_current, 'start_text_token')): # simpler check
             logger.warning("Using default SOT/EOT for text, since they were not found in T3Config.")

        text_tokens_padded = F.pad(text_tokens_batch, (1, 0), value=sot)
        text_tokens_padded = F.pad(text_tokens_padded, (0, 1), value=eot)

        with torch.inference_mode():
            t3_cond_for_inference: T3Cond = current_conds_to_use.t3
            current_batch_size = text_tokens_padded.shape[0]

            # Adjust the batch sizes of the conditioning tensors
            for attr_name in ["speaker_emb", "clap_emb", "cond_prompt_speech_tokens", "cond_prompt_speech_emb", "emotion_adv"]:
                attr_val = getattr(t3_cond_for_inference, attr_name)
                if isinstance(attr_val, torch.Tensor) and attr_val.shape[0] == 1 and current_batch_size > 1:
                    # If the batch dimension is 1 but current_batch_size > 1 (e.g. for CFG), repeat it.
                    num_repeats = [current_batch_size] + [1] * (attr_val.dim() - 1)
                    setattr(t3_cond_for_inference, attr_name, attr_val.repeat(*num_repeats))
                elif isinstance(attr_val, torch.Tensor) and attr_val.shape[0] != current_batch_size:
                    # This case should ideally not occur if the prompt was already prepared for CFG.
                    logger.warning(f"Unexpected batch-size mismatch for {attr_name} in T3Cond. "
                                   f"Tensor shape: {attr_val.shape}, expected batch: {current_batch_size}")


            plen_from_hp = getattr(t3_config_current, 'speech_cond_prompt_len', 0)
            if t3_cond_for_inference.cond_prompt_speech_tokens is None and plen_from_hp > 0:
                dummy_prompt = torch.zeros((current_batch_size, plen_from_hp), dtype=torch.long, device=self.device)
                t3_cond_for_inference.cond_prompt_speech_tokens = dummy_prompt
            
            max_speech_tokens = getattr(t3_config_current, 'max_speech_tokens', 1000)
            max_new_tokens = max_speech_tokens - (plen_from_hp if t3_cond_for_inference.cond_prompt_speech_tokens is not None else 0)
            max_new_tokens = max(50, max_new_tokens)

            speech_tokens_batched = self.t3.inference(
                t3_cond=t3_cond_for_inference,
                text_tokens=text_tokens_padded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                cfg_weight=cfg_weight,
            )

            speech_tokens_single = speech_tokens_batched[0] # first (conditional) half for CFG
            speech_tokens_clean = drop_invalid_tokens(speech_tokens_single.squeeze()).to(self.device)

            if speech_tokens_clean.numel() == 0:
                wav_np = np.zeros(int(self.sr * 0.5), dtype=np.float32) # short silence
            else:
                s3gen_conds = current_conds_to_use.gen
                # Make sure s3gen_conds are on the right device
                for k_gen, v_gen in s3gen_conds.items():
                    if isinstance(v_gen, torch.Tensor): s3gen_conds[k_gen] = v_gen.to(self.device)
                
                wav_batch, _ = self.s3gen.inference(
                    speech_tokens=speech_tokens_clean.unsqueeze(0),
                    ref_dict=s3gen_conds,
                )
                wav_np = wav_batch.squeeze(0).detach().cpu().numpy()

            watermarked_wav_np = self.watermarker.apply_watermark(wav_np, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav_np).unsqueeze(0), self.sr

and t3_config.py:

python
from ..llama_configs import LLAMA_CONFIGS


class T3Config:
    start_text_token = 255
    stop_text_token = 0
    text_tokens_dict_size = 704
    max_text_tokens = 2048

    start_speech_token = 6561
    stop_speech_token = 6562
    speech_tokens_dict_size = 8194
    max_speech_tokens = 4096

    llama_config_name = "Llama_520M"
    input_pos_emb = "learned"
    speech_cond_prompt_len = 10 #was 150, allows shorter input samples

    # For T3CondEnc
    encoder_type = "voice_encoder"
    speaker_embed_size = 256
    use_perceiver_resampler = True
    emotion_adv = True

    @property
    def n_channels(self):
        return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]

My training is currently running. I don't know if my code helps you here, but I added a little bit of memory management, so it's more efficient. You may have to reduce the sample and token lengths from 2048/4096 to lower values by editing the function call parameters to get it running on 24 GB or 16 GB; a rough sketch of that is below.
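
A minimal sketch of what such a reduction could look like, assuming the modified tts.py and t3_config.py shown above; the 1024/2048 values and the paths are only illustrative placeholders, and whatever you pick has to match the lengths used during preprocessing and training:

```python
from chatterbox.models.t3.modules.t3_config import T3Config
from chatterbox.tts import ChatterboxTTS

cfg = T3Config()
cfg.max_text_tokens = 1024     # assumed value; the class above defaults to 2048
cfg.max_speech_tokens = 2048   # assumed value; the class above defaults to 4096

tts = ChatterboxTTS.from_local(
    ckpt_dir_for_model_weights="path/to/checkpoint_dir",             # placeholder path
    base_model_dir_for_static_components="path/to/base_model_dir",   # placeholder path
    device="cuda",
    target_t3_config_for_model=cfg,  # from_local above resizes the positional embeddings to fit
)
```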

havok2-htwo avatar Jun 26 '25 21:06 havok2-htwo

How big is the influence of batch size on VRAM usage? I am thinking about just renting an H100 with 80 GB VRAM or something like that for training, if we have a dataset that is good enough.

I use around 28 GB of VRAM with a batch size of 1. I think with 80 GB you are ready to go with higher values, mate.

havok2-htwo avatar Jun 26 '25 21:06 havok2-htwo

Great @havok2-htwo @rotatorotator, can I use it for Arabic? Can you guide me on the data and the steps?

Sure, I believe it should work if you have enough speech samples. I'm currently redesigning a lot of my code based on Gemini's suggestions. I used the training code from this repository as a base, but I found a few issues in the tts.py file and also while reading through cond_enc.py. Gemini provided me with the correct values for this model and the proper tokens for the start and stop points. I hope this will help me fix the issues and run a proper fine-tuning session to fully leverage the potential of this base model. I also realized that I trained for 128 epochs, but that was way too much; I actually achieved the best results at epoch 5, after which the training suffered from overfitting. So I will address that as well. I've generated a sample where you can hear the result, but there are some issues at the end that still need improvement. (https://jmp.sh/s/71S206JyeoDuZa0YK2ve)

Thanks. Could you share a step-by-step walkthrough of your fine-tuning code?

I hope so :D First I cloned this repo. Then I edited and added some code. You may have to install some additional modules via pip and so on; I think I installed two or three things. Then you need a ...anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\audio_data folder with .wav and .txt pairs, i.e. the audio file and the transcription of the speech; both need the same base name. I generated approx. 12k samples with the ElevenLabs API, filtered them by hand, and deleted unwanted audio/text pairs.

Then you run preprocess_data.py as I described above. You may have to adjust some parameters to match your folders and so on. This tool creates a ton of .pt files. It also adds 300 ms, or whatever you set (I used 500 ms), of silence to the front and end of the audio files before converting them. It also creates variations: louder, quieter, slower, faster, whatever you enter; pitch up/down and faster did not sound good, so I left them out. A rough sketch of this padding/augmentation step is below.
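
A rough sketch of the padding/augmentation idea described here (function and parameter names are made up for illustration; this is not the actual preprocess_data.py API):

```python
import numpy as np
import librosa

def load_and_augment(wav_path: str, sr: int = 16000, pad_ms: int = 500):
    """Load one clip, add leading/trailing silence, and build simple volume/speed variants."""
    wav, _ = librosa.load(wav_path, sr=sr)
    pad = np.zeros(int(sr * pad_ms / 1000), dtype=wav.dtype)
    wav = np.concatenate([pad, wav, pad])  # silence at the front and at the end

    variants = {
        "orig": wav,
        "quiet": wav * 0.7,                                   # quieter variant
        "loud": np.clip(wav * 1.3, -1.0, 1.0),                # louder variant, clipped to [-1, 1]
        "slow": librosa.effects.time_stretch(wav, rate=0.9),  # slower variant
    }
    return variants

# Each variant would then be tokenized / embedded once and saved as its own .pt file,
# e.g. torch.save({...}, "sample_0001_loud.pt"), so training never touches raw audio again.
```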

Your CPU then creates these files once, which is much more efficient. Otherwise your CPU would have to redo this work in every epoch while your GPU idles a lot. A minimal example of the dataset that consumes these files follows.
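
For the training side, a minimal sketch of a Dataset over the precomputed .pt files (the field layout inside each file depends on your preprocess_data.py and is assumed here):

```python
from pathlib import Path

import torch
from torch.utils.data import Dataset

class PreprocessedTTSDataset(Dataset):
    """Serves samples that were fully prepared offline (audio ops, tokenization, VE)."""

    def __init__(self, data_dir: str):
        self.files = sorted(Path(data_dir).glob("*.pt"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # All the expensive CPU work already happened offline,
        # so a sample is just a small dict of tensors to deserialize.
        return torch.load(self.files[idx], map_location="cpu")
```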

After that you can run the training, again as described above, by starting finetune_t3_preprocessed.py. In the end you will find t3_cfg.safetensors and the other files somewhere like ...anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\checkpoints\chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD. I copied them into my other project, where I used chatterbox_stream to access this model, but it won't work out of the box; I also had to edit some files in the model. It may be a good idea to just ask Gemini, it always helps me create or edit code. Just drop in the files and say what you need :D

Hope this helps. I am not a pro, I just have access to all the gear, from AI Pro tools to a 5090 and a tiny cluster of 4x3090 at work. Good luck!

havok2-htwo avatar Jun 26 '25 21:06 havok2-htwo

@havok2-htwo thank you very much for your code. I created new versions of preprocess_data.py and finetune_t3_preprocessed.py, and replaced tts.py and t3_config.py.

I successfully ran preprocess_data.py; I ran it with --metadata_file because I use a metadata.csv.

Image

Now when I run the other script, finetune_t3_preprocessed.py, I get:

  File "C:\PythonApps\ChatterboxNew\chatterbox\src\finetune_t3_preprocessed.py", line 312, in forward
    loss_text, loss_speech, speech_logits = self.t3.loss(
TypeError: T3.loss() got an unexpected keyword argument 'labels_text'

I also see that in t3.py there is no labels_text argument: https://github.com/resemble-ai/chatterbox/blob/eb90621fa748f341a5b768aed0c0c12fc561894b/src/chatterbox/models/t3/t3.py#L167

Can you please tell me how to fix this? Thank you very much.

lpscr avatar Jun 27 '25 18:06 lpscr

Hey, this is my current t3.py:

python
# Copyright (c) 2025 Resemble AI
# MIT License
import logging
from typing import Union, Optional, List

from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor

from .modules.learned_pos_emb import LearnedPositionEmbeddings

from .modules.cond_enc import T3CondEnc, T3Cond
from .modules.t3_config import T3Config
from .llama_configs import LLAMA_CONFIGS
from .inference.t3_hf_backend import T3HuggingfaceBackend
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer


logger = logging.getLogger(__name__)


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def _ensure_BOT_EOT(text_tokens: Tensor, hp):
    B = text_tokens.size(0)
    assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
    assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"


class T3(nn.Module):
    """
    Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
        * tokenization, including start / stop tokens are always added externally to this class
        * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
        * careful! this class assumes relative positional encoding -- with absolute PE, we would at
            least want to reset the position to 0 when speech tokens begin, and optionally use a
            different PE embedding space for speech.
    """

    def __init__(self, hp=T3Config()):
        super().__init__()
        self.hp = hp
        self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
        self.tfmr = LlamaModel(self.cfg)
        self.dim = self.cfg.hidden_size
        self.deepspeed_patch_applied = False

        # conditioning / embedding
        self.cond_enc = T3CondEnc(hp)
        self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
        self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)

        # custom position embedding
        if hp.input_pos_emb == "learned":
            max_text_seq_len = hp.max_text_tokens + 2
            self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)

            max_mel_seq_len = hp.max_speech_tokens + 2 + 2
            self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)

        # logit projection
        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
        self.compiled = False

    @property
    def device(self):
        return self.speech_head.weight.device

    def prepare_conditioning(self, t3_cond: T3Cond):
        """
        Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
        """
        if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
                self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
        return self.cond_enc(t3_cond)  # (B, len_cond, dim)

    def prepare_input_embeds(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        cfg_weight: float = 0.0,
    ):
        # prepare input embeddings (skip backbone transformer embeddings)
        cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
        text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
        if cfg_weight > 0.0:
            text_emb[1].zero_()  # CFG uncond

        speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
        if self.hp.input_pos_emb == "learned":
            text_emb = text_emb + self.text_pos_emb(text_tokens)
            speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
        len_cond = cond_emb.size(1)

        if cond_emb.size(0) != text_emb.size(0):
             cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)

        # concat
        embeds = torch.stack([
            torch.cat((ce, te, se))
            for ce, te, se in zip(cond_emb, text_emb, speech_emb)
        ])  # (B, length, dim)
        return embeds, len_cond

    def forward(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        speech_token_lens: torch.LongTensor,
        training=False,
    ):
        _ensure_BOT_EOT(text_tokens, self.hp)

        # prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_tokens,
        )

        # backbone transformer forward
        tfmr_out = self.tfmr.forward(
            input_ids=None,
            # position_ids=position_ids, # TODO? ROPE should be fine?
            inputs_embeds=embeds,
            output_hidden_states=True,
            return_dict=True,
            use_cache=(not training),
        )
        hidden_states = tfmr_out.hidden_states[-1]  # final tfmr layer output, (B, seq, dim)

        # post-processing: splice out text and speech parts of hidden states
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        B, _, dim = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
        speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
        ttl, stl = text_token_lens, speech_token_lens
        for i in range(B):
            text_end = len_cond + ttl[i].item()
            speech_start = len_cond + text_tokens.size(1)
            speech_end = speech_start + stl[i].item()
            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]

        # logit projection
        text_logits = self.text_head(text_latents)
        speech_logits = self.speech_head(speech_latents)

        return AttrDict(
            text_logits=text_logits,
            text_latents=text_latents,
            speech_logits=speech_logits,
            speech_latents=speech_latents,
            hidden_states=hidden_states,
        )

    def loss_old(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        speech_token_lens: torch.LongTensor,
    ):
        "training method"
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        assert len_text == text_token_lens.max()
        assert len_speech == speech_token_lens.max()

        out = self.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
        )  # (B, seq, vocab_size)

        # Calc CCE losses
        IGNORE_ID = -100
        device = out.text_logits.device
        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]  # (B, len_text)
        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
        masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
        masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)

        loss_text   = F.cross_entropy(out.text_logits.transpose(1, 2),    masked_text,    ignore_index=IGNORE_ID)
        loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)
        #loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
        #loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)

        return loss_text, loss_speech

    def loss(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,        # (B, S_text_padded), includes BOS & EOS
        text_token_lens: torch.LongTensor,    # (B,), actual lengths including BOS & EOS
        speech_tokens: torch.LongTensor,      # (B, S_speech_padded), includes BOS & EOS
        speech_token_lens: torch.LongTensor,  # (B,), actual lengths including BOS & EOS
        labels_text: torch.LongTensor,        # (B, S_text_padded-1), already masked with -100
        labels_speech: torch.LongTensor       # (B, S_speech_padded-1), already masked with -100
    ):
        """
        Compute text and speech cross-entropy using pre-masked labels from the collator.
        Assumes:
        - labels_text[t] corresponds to predicting text_tokens[:, 1:] with -100 where ignored
        - labels_speech[t] corresponds to predicting speech_tokens[:, 1:] with -100 where ignored
        """

        # 1) Run model to get logits
        out = self.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
        )
        # out.text_logits: (B, S_text_padded, V_text)
        # out.speech_logits: (B, S_speech_padded, V_speech)
        device = out.text_logits.device
        IGNORE_ID = -100

        # --- Text Loss (use labels_text directly) ---
        # Align logits: predict t₁..EOS from inputs [BOS, t₁..]
        logits_for_text = out.text_logits[:, :-1, :].contiguous()  # (B, S_text_padded-1, V_text)
        # labels_text already has shape (B, S_text_padded-1) with -100 where masked
        if logits_for_text.size(1) == 0:
            loss_text = torch.tensor(0.0, device=device, requires_grad=self.training)
        else:
            loss_text = F.cross_entropy(
                logits_for_text.transpose(1, 2),  # (B, V_text, S_text_padded-1)
                labels_text,                      # (B, S_text_padded-1), ignore_index=-100
                ignore_index=IGNORE_ID
            )

        # --- Speech Loss (use labels_speech directly) ---
        logits_for_speech = out.speech_logits[:, :-1, :].contiguous()  # (B, S_speech_padded-1, V_speech)
        # labels_speech already has shape (B, S_speech_padded-1) with -100 where masked
        if logits_for_speech.size(1) == 0:
            loss_speech = torch.tensor(0.0, device=device, requires_grad=self.training)
        else:
            loss_speech = F.cross_entropy(
                logits_for_speech.transpose(1, 2),  # (B, V_speech, S_speech_padded-1)
                labels_speech,                      # (B, S_speech_padded-1), ignore_index=-100
                ignore_index=IGNORE_ID
            )

        return loss_text, loss_speech, out.speech_logits

    @torch.inference_mode()
    def inference(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: Tensor,
        initial_speech_tokens: Optional[Tensor]=None,

        # misc conditioning
        prepend_prompt_speech_tokens: Optional[Tensor]=None,

        # HF generate args
        num_return_sequences=1,
        max_new_tokens=None,
        stop_on_eos=True,
        do_sample=True,
        temperature=0.8,
        top_p=0.8,
        length_penalty=1.0,
        repetition_penalty=2.0,
        cfg_weight=0,
    ):
        """
        Args:
            text_tokens: a 1D (unbatched) or 2D (batched) tensor.
        """
        # Validate / sanitize inputs
        assert prepend_prompt_speech_tokens is None, "not implemented"
        _ensure_BOT_EOT(text_tokens, self.hp)
        text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)

        # Default initial speech to a single start-of-speech token
        if initial_speech_tokens is None:
            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])

        # Prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=initial_speech_tokens,
            cfg_weight=cfg_weight,
        )

        # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
        # Note the llama-specific logic. Other tfmr types can be added later.

        self.compiled = False

        # TODO? synchronize the expensive compile function
        # with self.compile_lock:
        if not self.compiled:
            alignment_stream_analyzer = AlignmentStreamAnalyzer(
                self.tfmr,
                None,
                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
                alignment_layer_idx=9, # TODO: hparam or something?
                eos_idx=self.hp.stop_speech_token,
            )
            patched_model = T3HuggingfaceBackend(
                config=self.cfg,
                llama=self.tfmr,
                speech_enc=self.speech_emb,
                speech_head=self.speech_head,
                alignment_stream_analyzer=alignment_stream_analyzer,
            )
            self.patched_model = patched_model
            self.compiled = True

        # # Run normal generate method, which calls our custom extended methods
        # return self.patched_model.generate(
        #     inputs=initial_speech_tokens,
        #     decoder_cond=embeds,
        #     bos_token_id=self.hp.start_speech_token,
        #     eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
        #     pad_token_id=self.hp.stop_speech_token,
        #     max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
        #     num_return_sequences=num_return_sequences,
        #     temperature=temperature,
        #     top_p=top_p,
        #     length_penalty=length_penalty,
        #     repetition_penalty=repetition_penalty,
        #     do_sample=do_sample,
        #     # cache_implementation=None if not self.compiled else "static",
        # )

        device = embeds.device

        bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
        bos_embed = self.speech_emb(bos_token)  # shape: (B, 1, embed_dim)
        bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)

        # batch_size=2 for CFG
        bos_embed = torch.cat([bos_embed, bos_embed])

        # Combine condition and BOS token for the initial input if cfg_weight > 0
        if cfg_weight > 0:
            inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
        else:
            inputs_embeds = embeds

        # Track generated token ids; start with the BOS token.
        generated_ids = bos_token.clone()
        predicted = []  # To store the predicted tokens

        # Instantiate the logits processors.
        top_p_warper = TopPLogitsWarper(top_p=top_p)
        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)

        # ---- Initial Forward Pass (no kv_cache yet) ----
        output = self.patched_model(
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        # Initialize kv_cache with the full context.
        past = output.past_key_values

        # ---- Generation Loop using kv_cache ----
        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
            logits = output.logits[:, -1, :]

            # CFG
            if cfg_weight > 0.0:
                logits_cond = logits[0:1]
                logits_uncond = logits[1:2]
                logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)

            logits = logits.squeeze(1)

            # Apply temperature scaling.
            if temperature != 1.0:
                logits = logits / temperature

            # Apply repetition penalty and top‑p filtering.
            logits = repetition_penalty_processor(generated_ids, logits)
            logits = top_p_warper(None, logits)

            # Convert logits to probabilities and sample the next token.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (B, 1)

            predicted.append(next_token)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Check for EOS token.
            if next_token.view(-1) == self.hp.stop_speech_token:
                break

            # Get embedding for the new token.
            next_token_embed = self.speech_emb(next_token)
            next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)

            #  For CFG
            if cfg_weight > 0.0:
                next_token_embed = torch.cat([next_token_embed, next_token_embed])

            # Forward pass with only the new token and the cached past.
            output = self.patched_model(
                inputs_embeds=next_token_embed,
                past_key_values=past,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
            )
            # Update the kv_cache.
            past = output.past_key_values

        # Concatenate all predicted tokens along the sequence dimension.
        predicted_tokens = torch.cat(predicted, dim=1)  # shape: (B, num_tokens)
        return predicted_tokens

Hope this helps.

My training has now been running for 72 h. It sounds very German now, but still has problems with the correct order of words. I will continue the training and then share my weights.

Hello lpscr! You have analyzed the problem perfectly. The TypeError occurs because your new finetune_t3_preprocessed.py script is calling an updated loss function in the T3 model, but your local t3.py file still contains the old version of that function. The new fine-tuning script passes the arguments labels_text and labels_speech to make the process more efficient, but the original loss function doesn't recognize these arguments. The solution is exactly what havok2-htwo provided in their latest comment: you need to update your t3.py file as well.

How to fix the error:

  1. Find the correct file: navigate to the path you already identified, chatterbox/src/chatterbox/models/t3/t3.py.
  2. Copy the new code: go to havok2-htwo's last comment (the one that starts with # Copyright (c) 2025 Resemble AI).
  3. Replace the content: copy the entire code block from that comment and use it to completely replace the contents of your local t3.py file.

Why this fixes the error: if you look at the new code for t3.py, you will see that havok2-htwo has refactored the loss function. It now has the exact signature that your fine-tuning script expects:

In the new t3.py from havok2-htwo

python
def loss(
    self,
    *,
    t3_cond: T3Cond,
    text_tokens: torch.LongTensor,
    text_token_lens: torch.LongTensor,
    speech_tokens: torch.LongTensor,
    speech_token_lens: torch.LongTensor,
    labels_text: torch.LongTensor,        # <--- Here is the expected argument!
    labels_speech: torch.LongTensor       # <--- And here is the second one!
):
    # ... function logic

After you update your t3.py file with this new code, the T3.loss() function will correctly accept the labels_text and labels_speech arguments, and the TypeError will be resolved. Your training script should run without this error after making that change. Good luck. Dear Gemini
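
For reference, a minimal sketch (not the actual collator from finetune_t3_preprocessed.py) of how labels_text / labels_speech matching the loss() signature above could be built, i.e. shifted by one position with padded positions masked as -100:

```python
import torch

IGNORE_ID = -100

def build_labels(tokens: torch.LongTensor, lens: torch.LongTensor) -> torch.LongTensor:
    """tokens: (B, S) padded ids incl. BOS/EOS; lens: (B,) true lengths incl. BOS/EOS."""
    labels = tokens[:, 1:].clone()                          # predict token t+1 from token t
    positions = torch.arange(labels.size(1), device=tokens.device)[None, :]
    labels[positions >= (lens[:, None] - 1)] = IGNORE_ID    # mask everything past the last real token
    return labels

# labels_text = build_labels(text_tokens, text_token_lens)
# labels_speech = build_labels(speech_tokens, speech_token_lens)
```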

havok2-htwo avatar Jun 29 '25 09:06 havok2-htwo

@havok2-htwo Thank you very much for checking; this is very helpful.

lpscr avatar Jul 01 '25 09:07 lpscr

@havok2-htwo Thank you very much for checking; this is very helpful.

You're welcome. My training runs well even at a room temperature of 36°C :D thank god for custom water cooling. Yesterday I tried to merge it with Kartoffel's model and the result was pretty good. I will let the training run another day and then release the German fine-tuned model, as well as a model merged with Kartoffel's (a rough sketch of such a weight merge is below).
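
For anyone wanting to try something similar, one way such a merge can be done is a weighted average of the two t3_cfg.safetensors state dicts. This is only a rough sketch; the paths are placeholders and the 0.65 blend factor is just an example value, not necessarily what was used:

```python
from safetensors.torch import load_file, save_file

alpha = 0.65
a = load_file("Kartoffelbox/t3_cfg.safetensors")   # placeholder path
b = load_file("my_finetune/t3_cfg.safetensors")    # placeholder path

merged = {}
for key, tensor_a in a.items():
    tensor_b = b.get(key)
    if tensor_b is not None and tensor_b.shape == tensor_a.shape:
        merged[key] = alpha * tensor_a + (1.0 - alpha) * tensor_b
    else:
        merged[key] = tensor_a  # fall back to one model where a key is missing or shapes differ

save_file(merged, "merged/t3_cfg.safetensors")
```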

havok2-htwo avatar Jul 01 '25 10:07 havok2-htwo

@rotatorotator here I uploaded the trained model, weights and also a merged version with Kartoffelbox and mine: https://huggingface.co/havok2/Kartoffelbox-v0.1_0.65h2

have fun with it =)

havok2-htwo avatar Jul 02 '25 23:07 havok2-htwo

@rotatorotator here I uploaded the trained model, weights and also a merged version with Kartoffelbox and mine: https://huggingface.co/havok2/Kartoffelbox-v0.1_0.65h2

have fun with it =)

Great @havok2-htwo, how can I do the same for Arabic? And is the merge of Kartoffelbox with Chatterbox valuable? I don't know why you mention this merge.

cod3r0k avatar Jul 03 '25 06:07 cod3r0k

@rotatorotator here I uploaded the trained model, weights and also a merged version with Kartoffelbox and mine: https://huggingface.co/havok2/Kartoffelbox-v0.1_0.65h2 have fun with it =)

Great @havok2-htwo, how can I do the same for Arabic? And is the merge of Kartoffelbox with Chatterbox valuable? I don't know why you mention this merge.

Hey, I described the workflow above and also shared the adjusted code. All you need is a lot of Arabic .wav files with their .txt pairs containing the transcriptions. Then you train on that.

My model was able to produce good German-sounding voices but struggled with the right order of words. So my model was good at producing fake German but not so good at German grammar. Kartoffel's model, on the other hand, was good at German grammar, so I mixed it with mine and got a better model. I only had 12,000 German samples for training; Kartoffel used ~2,500,000 samples. So if you want to train Arabic, I think you have to generate a lot of data, with ElevenLabs for example. Alternatively, if you have access to a big Arabic voice database, you can use a Whisper model to transcribe it all to .txt files (a rough sketch of that follows).
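
If you go the Whisper route, a rough sketch of turning a folder of Arabic .wav files into matching .txt transcripts for the preprocessing step (folder name and model size are placeholders):

```python
from pathlib import Path

import whisper  # pip install openai-whisper

model = whisper.load_model("large")  # smaller models also work, with lower accuracy

for wav in sorted(Path("audio_data").glob("*.wav")):        # placeholder folder
    result = model.transcribe(str(wav), language="ar")
    wav.with_suffix(".txt").write_text(result["text"].strip(), encoding="utf-8")
```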

havok2-htwo avatar Jul 03 '25 07:07 havok2-htwo

@havok2-htwo Hi there,

I'm sorry to bother you, but I've been having some trouble. I'm trying to train my model, but I keep running into an out-of-memory error. I have an RTX 4090 (24 GB VRAM). Could you please let me know which settings I should use to fit within the 24 GB limit?

Thank you very much for all your work and your time!

lpscr avatar Jul 04 '25 10:07 lpscr

❗ Issue: Help Needed to Use finetune_t3_preprocessed.py Step by Step

Hi, I'm trying to fine-tune a model using your codebase and would appreciate step-by-step guidance. Here's what I'm doing and where I'm stuck:


🔧 My Setup and Steps So Far

  1. Python Version:

    • Which version of Python do you recommend for this project?
      I'm currently using Python 3.11.
  2. Dataset Format:

    • My data is formatted similarly to LJSpeech:
      • A wavs/ directory containing audio files
      • A metadata.csv file structured as:
        id|transcription
        
  3. Script Used:

    ‱ I'm trying to run the training using finetune_t3_preprocessed.py.

⚠ Current Problem

When I execute the script, I get the following error:

/python3.11/site-packages/transformers/utils/import_utils.py", line 1780, in _get_module
raise RuntimeError(
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
operator torchvision::nms does not exist

❓ Questions

  1. What version of Python, PyTorch, and Transformers should I be using?
  2. Is there a specific preprocessing step I might be missing?
  3. How can I resolve the torchvision::nms does not exist error?
  4. Could you provide a step-by-step guide to fine-tune using finetune_t3_preprocessed.py on LJSpeech-style data?

Thanks in advance for your support @havok2-htwo !

cod3r0k avatar Jul 04 '25 13:07 cod3r0k

❗ Issue: Help Needed to Use finetune_t3_preprocessed.py Step by Step

Hi, I'm trying to fine-tune a model using your codebase and would appreciate step-by-step guidance. [...]

Thanks in advance for your support @havok2-htwo !

Hey, to train the model with the parameters as I have written them there, you need at least 27 or 28 GB of graphics memory. You would then have to reduce the maximum token length, i.e. set the 2048 to 1024, for example, both when preparing the material and during the training itself. But then only shorter training samples can be used: as far as I understand it, the limit is around 18 seconds at the moment, and it would then only be about nine, if I'm interpreting it correctly. In any case, it has to match between preparation and training. I'm already at a batch size of one during training; anything smaller is unfortunately not possible. Some memory-saving trainer settings are sketched below.
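
Beyond shrinking the maximum token lengths, a few generic Hugging Face Trainer settings usually help squeeze the run onto a 24 GB card. A hedged sketch only; the actual argument names wired into finetune_t3_preprocessed.py may differ:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="checkpoints/chatterbox_finetune_24gb",  # placeholder
    per_device_train_batch_size=1,    # batch size 1, as mentioned above
    gradient_accumulation_steps=8,    # simulates a larger effective batch without extra VRAM
    gradient_checkpointing=True,      # trades compute for activation memory
    fp16=True,                        # or bf16=True on cards that support it
    dataloader_num_workers=2,
    logging_steps=50,
    save_steps=1000,
)
```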

havok2-htwo avatar Jul 04 '25 14:07 havok2-htwo

I prepared my data and ran preprocess_data.py, and it takes a long time; I think it got stuck! @havok2-htwo I have 120 CPU cores and a V100.

2025-07-05 00:23:37,643 - INFO - __main__ - Found 5325 original audio-text pairs. Starting preprocessing...
2025-07-05 00:23:37,643 - INFO - __main__ - Using 100 worker processes. Each original file generates up to 4 samples.
2025-07-05 00:23:37,645 - INFO - __main__ - Total expected samples to attempt (incl. augmentations): 21300
Processing original files (incl. augmentations):   0%| | 0/5325 [00:00<?, ?it/s]
2025-07-05 00:24:00,512 - INFO - __main__ - Saving test audios for LJmine-1.wav (OriginalIndex 0) in processed_tts_data_AUGMENTED_WITH_PARAMS_2048_4096/test_audio_samples_debug

Another question about G2P: do I need to change anything for Arabic?

cod3r0k avatar Jul 05 '25 07:07 cod3r0k

Hey, I am currently away from home, but as far as I can remember I have Python 3.10 and a Torch 2.8 nightly build; it was the only one I found that worked with my 5090 so far.

@cod3r0k Can you post one example .txt file? How long are your .wavs? Try only 1 worker; it sounds strange, but the original file uses only 1 worker under Windows. I installed something and then it worked with more than 1, so it's worth a try. My 12k samples were preprocessed in around 30 minutes on my 9800X3D with a worker count of 14. I have 64 GB of RAM; maybe 100 workers ran out of memory? Maybe try without the augmentation parameters? (A sketch of capping the worker pool is below.)
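
A hedged sketch of capping the preprocessing worker count (the real preprocess_data.py may structure its pool differently):

```python
from multiprocessing import Pool

from tqdm import tqdm

def process_one(pair):
    wav_path, txt_path = pair
    # ... load audio, tokenize/embed, and save the resulting tensors as .pt files here ...
    return wav_path

def run(pairs, num_workers: int = 8):  # try small values first (1, 4, 8) before scaling up
    with Pool(processes=num_workers) as pool:
        for _ in tqdm(pool.imap_unordered(process_one, pairs), total=len(pairs)):
            pass
```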

I can give more help next week. Greetings

havok2-htwo avatar Jul 05 '25 07:07 havok2-htwo