chatterbox
Proposal: Optimize Data Preprocessing for >4x Faster TTS Training
Problem: The current finetune_t3.py script's on-the-fly data processing (audio loading, resampling, tokenization, voice-encoder embedding) in Dataset.__getitem__ creates a major CPU bottleneck. For a ~5000-sample dataset (1.4 GB of raw audio generated with ElevenLabs, which conveniently provided corresponding .txt files for the speech content), this limited iteration time to ~2.2-2.9 s/it even with an optimized dataloader_num_workers (down from an initial ~5 s/it).
Proposed Solution: Offline Preprocessing. I've implemented a two-stage approach:

preprocess_data.py (new script): a standalone, parallelized (multi-CPU-core) script that processes the entire raw dataset once. It handles all expensive CPU tasks (audio ops, tokenization, voice-encoder embedding) using the .wav and associated .txt files, and saves the resulting tensors for each sample into individual .pt files.

finetune_t3_preprocessed.py (modified training script): the Dataset now simply loads these precomputed .pt files, drastically reducing CPU work during training. (Code for both scripts is embedded below for review.)

Key Benefits & Observed Results:
- Training iteration time: reduced to ~1.2 s/iteration, a >4x speedup over the initial baseline and ~2x faster than optimized on-the-fly processing for this dataset.
- Reduced CPU bottleneck: training CPU load is minimal; dataloader_num_workers can be lowered (e.g., 2-8).
- Maximized GPU utilization: VRAM becomes the primary constraint for batch size.
- Efficient workflow: preprocessing is a one-time cost that benefits all subsequent training runs.
- Data size: preprocessed data for the 5000 samples was <100 MB (from 1.4 GB of raw audio).

This optimization significantly speeds up fine-tuning, especially when working with datasets structured as audio files with corresponding text files. I'm happy to provide the scripts and discuss further.
Code: see the posts below.
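For orientation, here is a minimal conceptual sketch of the two-stage idea (illustrative only; the real scripts below also handle tokenization, the voice-encoder embedding, augmentation and error handling, and the tensor shapes and file names here are placeholders):

# concept sketch: precompute once, load cheaply during training
from pathlib import Path
import torch
from torch.utils.data import Dataset

def preprocess_one(audio_path: Path, text: str, out_dir: Path) -> None:
    # Stage 1 (offline, CPU): do the expensive work once and cache the tensors.
    sample = {
        "text_tokens": torch.zeros(300, dtype=torch.long),   # placeholder shapes
        "speech_tokens": torch.zeros(750, dtype=torch.long),
        "t3_cond_speaker_emb": torch.zeros(256),
    }
    torch.save(sample, out_dir / f"{audio_path.stem}.pt")

class PreprocessedDataset(Dataset):
    # Stage 2 (training): __getitem__ only deserializes a small .pt file,
    # so dataloader workers stay cheap and the GPU becomes the bottleneck.
    def __init__(self, pt_dir: Path):
        self.files = sorted(pt_dir.glob("*.pt"))
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        return torch.load(self.files[idx])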
How would we run inference (using the Gradio app)? I mean, how do we point it to the new model we trained?
I asked Gemini:
Okay, here's a shorter way to explain how to use the fine-tuned model for inference, for example, in a Gradio app:
"To use your fine-tuned model for inference (e.g., with Gradio):
Locate your fine-tuned model directory. This is the --output_dir from your training (e.g., ../checkpoints/chatterbox_10k_DE_.../). It should contain: t3_cfg.safetensors (your T3 weights), ve.safetensors, s3gen.safetensors, and tokenizer.json.
In your Gradio script (or any inference script), load the model using ChatterboxTTS.from_local():

from chatterbox.tts import ChatterboxTTS
import torch

MODEL_DIR = "path/to/your/finetuned_model_dir"  # <-- CHANGE THIS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tts_model = ChatterboxTTS.from_local(ckpt_dir=MODEL_DIR, device=DEVICE)
# Now use tts_model.generate(...) in your Gradio functions.

No need to edit tts.py from the library. The from_local method is designed to load models saved this way. Just make sure the directory you point to has all the necessary files. If you want to use a specific intermediate checkpoint (like checkpoint-8000), you'd first need to manually prepare a directory that contains all these components (extract the T3 weights from the Trainer's save and copy the VE, S3Gen, and tokenizer files into it), and then point MODEL_DIR to that prepared directory."
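Continuing that snippet, a hedged usage sketch (the exact generate() signature and the sample-rate attribute can differ between chatterbox versions, and torchaudio is just one way to write the output):

import torchaudio

text = "Hallo Welt, dies ist ein Test."
wav = tts_model.generate(text)                  # waveform tensor from the fine-tuned model
torchaudio.save("sample_de.wav", wav, tts_model.sr)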
I don't use Gradio in my project, but I hope it helps. I managed to get a much better German model by using 10,000 German samples created with ChatGPT and ElevenLabs. I just ran the fine-tuning for around 12 hours on my RTX 5090.
@havok2-htwo Thanks for providing the script! Are you able to share your trained weights for German chatterbox? I would like to try it out! Did you observe the loss or anything while training?
Hi! I've updated the scripts and will upload them here in the next few days (currently at work and not at home). I've optimized several parts and ran multiple training sessions with my 5090 over the past few days. The German output is getting pretty good - not perfect, but much more understandable now.
I ran into a few issues during training. Some of the training data was too short, which caused strange results. My current training data script now only creates samples if the input is long enough. I also cleaned most of the samples by hand. I have generated 12,000 samples with ElevenLabs so far.
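A minimal sketch of such a length filter, assuming a 1.5 s cutoff (the actual threshold used here isn't stated) and the wav/txt pair layout used by the preprocessing script; the "dataset" path is a placeholder:

from pathlib import Path
import librosa

MIN_DURATION_S = 1.5  # assumed cutoff; tune for your data

def long_enough(wav_path: Path) -> bool:
    # librosa >= 0.10 uses path=; older versions use filename=
    return librosa.get_duration(path=str(wav_path)) >= MIN_DURATION_S

kept = [p for p in Path("dataset").rglob("*.wav")
        if long_enough(p) and p.with_suffix(".txt").exists()]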
I also added a lot of examples with tricky pronunciations like "Ziel," "Generation," and similar words to help reduce mispronunciations.
Right now, I'm still having trouble with the start and end of the generated audio: my fine-tuned model often skips the first word and doesn't stop speaking when the text ends. I asked Gemini about it, and it suggested that this might be due to missing silence at the beginning and end of the audio files. So I'm now adding 300 ms of silence to both ends of each file.
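The padding step itself is simple; a small numpy sketch of what this looks like (300 ms matches the --silence_padding_ms default in the preprocessing script below, assuming 16 kHz mono float32 audio):

import numpy as np

def pad_with_silence(wav: np.ndarray, sr: int, pad_ms: int = 300) -> np.ndarray:
    # prepend and append pad_ms of digital silence
    pad = np.zeros(int(sr * pad_ms / 1000.0), dtype=wav.dtype)
    return np.concatenate([pad, wav, pad])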
This is actually my first time fine-tuning a model, aside from playing with deepfakes back in the day.
By the way, any suggestions on where best to upload the trained weights? Would GitHub be suitable, or should I use something like Hugging Face or Google Drive? I will try to retrain my model with the added silence and then share it. Give me a few days.
@rotatorotator I have updated the code above completely.
Thank you, I will try these out. Your previous code seemed to work well for English. I was wondering if you know how we might approach training a non-Latin-script language such as Arabic; is there a cut-and-dried way to do this without too much wrangling? P.S. Hugging Face is good for models.
Thanks, but I did not mean the code but the weights you have trained, e.g. the T3 file. As mentioned, Hugging Face is a good option for sharing these. FYI, there is someone else training a German chatterbox: https://huggingface.co/SebastianBodza/Kartoffelbox-v0.1 If you need GPU resources for training, I would be happy to contribute to that.
Great, @havok2-htwo @rotatorotator! Can I use it for Arabic? Can you guide me on the data and steps?
Sure, I believe it should work if you have enough speech samples. I'm currently redesigning a lot of my code based on Gemini's suggestions. I used the training code from this repository as a base, but I found a few issues in the tts.py file and also while reading through cond_enc.py. Gemini provided me with the correct values for this model and the proper tokens for the start and stop points. I hope this will help me fix the issues and run a proper fine-tuning session to fully leverage the potential of this base model.
I also realized that I trained for 128 epochs, but that was way too much. I actually achieved the best results at epoch 5, after which the training suffered from overfitting. So I will address that as well. I've generated a sample where you can hear the result, but there are some issues at the end that still need improvement.
(https://jmp.sh/s/71S206JyeoDuZa0YK2ve)
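Since the training script below builds on the Hugging Face Trainer (and already exposes an early_stopping_patience argument), one way to avoid over-training like this is to cap the epochs and let the Trainer restore the best checkpoint; a hedged sketch with illustrative values (older transformers versions spell eval_strategy as evaluation_strategy):

from transformers import TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="checkpoints/chatterbox_de",
    num_train_epochs=10,                 # far fewer than 128
    eval_strategy="steps",               # run eval so the best checkpoint can be tracked
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,         # roll back to the best (e.g. ~epoch 5) checkpoint
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
early_stop = EarlyStoppingCallback(early_stopping_patience=3)
# then: Trainer(..., args=training_args, callbacks=[early_stop])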
Does that mean the fine-tuning time (12 h on your 5090) was for all 128 epochs, and that with training for about 5 epochs we could train in just a few minutes? Probably with a bigger dataset we would need to train for more epochs.
Yes, the bottleneck was definitely the CPU, and the fact that each audio file had to be converted into a .pt file before the GPU could even start working. I just preprocessed all the WAV files (basically extracted the features) and then fed everything to the GPU. Now it's running steadily at 600 W under full load.
But I noticed that during training, the sequence length was limited to just 300 tokens, even though the original model was trained with 2048 tokens. When I try to match that original token length, my 32 GB of VRAM isn't sufficient, especially if I enable evaluation. Even with a batch size of 1, I'm already hitting 28 GB of VRAM.
Still, I think it's important not to go below the original sequence length; that might actually be the reason why the model keeps rambling at the end of sentences...
But with batch size 1 it's slower than with 5.
How much does batch size influence VRAM usage? I am thinking about just renting an H100 with 80 GB of VRAM or something like that for training, if we have a good enough dataset.
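Activation memory scales roughly linearly with both batch size and sequence length, which is why long sequences blow past 32 GB so quickly even at batch 1 or 2. A common workaround, sketched here with illustrative values on top of the same Hugging Face TrainingArguments used below, is gradient accumulation plus gradient checkpointing: per-step VRAM stays at batch-size-1 levels while the effective batch size is larger.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="checkpoints/chatterbox_de_longseq",
    per_device_train_batch_size=1,    # what actually fits at long sequence lengths
    gradient_accumulation_steps=8,    # effective batch size 8 without extra activation memory
    gradient_checkpointing=True,      # trade recompute time for activation memory
    bf16=True,                        # roughly halves activation memory on GPUs that support it
)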
Thanks. Could you share a step-by-step guide to your fine-tuning code?
Here comes my new code. preprocess_data.py: this precomputes everything for the GPU, with some augmented variants.
# preprocess_data.py
import argparse
import logging
from pathlib import Path
import os
from typing import Any, Dict, List, Tuple, Optional, Union
import functools
import gc # for garbage collection
import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
# Try to import audiotsm
try:
from audiotsm import wsola
from audiotsm.io.array import ArrayReader, ArrayWriter
AUDIOTSM_AVAILABLE = True
logger_init = logging.getLogger(__name__) # temporary logger for import status
logger_init.info("audiotsm found; it will be used for pitch shifting (combined with resampling).")
except ImportError:
AUDIOTSM_AVAILABLE = False
logger_init = logging.getLogger(__name__)
logger_init.warning("audiotsm not found. Falling back to librosa.effects.pitch_shift. For audiotsm-based pitching, install it with 'pip install audiotsm'.")
# Chatterbox specific imports
try:
from chatterbox.tts import ChatterboxTTS, punc_norm
from chatterbox.models.t3.modules.t3_config import T3Config
from chatterbox.models.s3tokenizer import S3_SR, S3Tokenizer
except ImportError:
print("Stelle sicher, dass das Chatterbox-Paket korrekt installiert ist und im PYTHONPATH liegt.")
raise
logger = logging.getLogger(__name__) # main logger
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")
# --- Augmentation functions ---
def change_speed(audio_array: np.ndarray, sr: int, speed_factor: float) -> Optional[np.ndarray]:
if speed_factor == 1.0: return audio_array.copy()
if not (0.5 <= speed_factor <= 2.0):
logger.debug(f"Speed factor {speed_factor} out of range [0.5, 2.0]. Skipping.")
return None
try:
audio_array_copy = np.ascontiguousarray(audio_array)
# If audiotsm is available, it could be used here as well for better quality
if AUDIOTSM_AVAILABLE and audio_array_copy.ndim == 1: # audiotsm's ArrayReader expects 1D for mono
try:
# audiotsm expects samples as int16 or float32, but ArrayReader/Writer take care of that
# when channels=1. librosa delivers float32.
reader = ArrayReader(np.array([audio_array_copy.T], dtype=np.float32)) # must be [channels, samples]
writer = ArrayWriter(channels=1) # result is flattened back to mono below
tsm = wsola(channels=1, speed=speed_factor)
tsm.run(reader, writer)
stretched_audio = writer.data.T.flatten() # back to 1D
return stretched_audio if len(stretched_audio) > 0 else None
except Exception as e_tsm_speed:
logger.warning(f"audiotsm speed change failed for factor {speed_factor}: {e_tsm_speed}. Falling back to librosa.")
stretched_audio = librosa.effects.time_stretch(y=audio_array_copy, rate=speed_factor)
else:
stretched_audio = librosa.effects.time_stretch(y=audio_array_copy, rate=speed_factor)
return stretched_audio if len(stretched_audio) > 0 else None
except Exception as e:
logger.warning(f"change_speed (librosa) failed for factor {speed_factor}: {e}")
return None
def change_pitch(audio_array: np.ndarray, sr: int, n_steps: float) -> Optional[np.ndarray]:
if n_steps == 0.0: return audio_array.copy()
if not (-12.0 <= n_steps <= 12.0):
logger.debug(f"Pitch steps {n_steps} out of range [-12, 12]. Skipping.")
return None
audio_array_copy = np.ascontiguousarray(audio_array)
if AUDIOTSM_AVAILABLE and audio_array_copy.ndim == 1: # audiotsm's ArrayReader expects 1D for mono
try:
semitone_ratio = 2.0**(1.0/12.0)
pitch_shift_factor = semitone_ratio**n_steps
# 1. First attempt: time-stretch with audiotsm (wsola), using the pitch factor as the speed.
# If pitch_shift_factor > 1 (higher pitch) the audio is compressed (sounds faster);
# if pitch_shift_factor < 1 (lower pitch) it is stretched (sounds slower).
reader = ArrayReader(np.array([audio_array_copy.T], dtype=np.float32))
writer = ArrayWriter(channels=1)
tsm = wsola(channels=1, speed=pitch_shift_factor) # speed = pitch factor
tsm.run(reader, writer)
time_stretched_audio = writer.data.T.flatten()
if len(time_stretched_audio) == 0:
logger.warning("audiotsm pitch_shift: time_stretched_audio is empty.")
return None
# 2. The stretched audio would then have to be resampled so that the tempo change turns into
# a pitch change at the original duration. Working through the bookkeeping shows the TSM
# speed actually has to be 1/pitch_shift_factor (the audio gets longer when the pitch should
# go up, shorter when it should go down), followed by resampling from sr to
# sr * pitch_shift_factor. The corrected time-stretch is computed below.
tsm_speed = 1.0 / pitch_shift_factor
reader_corr = ArrayReader(np.array([audio_array_copy.T], dtype=np.float32))
writer_corr = ArrayWriter(channels=1)
tsm_corr = wsola(channels=1, speed=tsm_speed)
tsm_corr.run(reader_corr, writer_corr)
time_stretched_audio_corr = writer_corr.data.T.flatten()
if len(time_stretched_audio_corr) == 0:
logger.warning("audiotsm pitch_shift (korrigiert): time_stretched_audio_corr is empty.")
return None
# In the end, correctly combining TSM and resampling for a pure pitch shift turned out to be
# more complex than it looks (duration bookkeeping, quality trade-offs), and this is
# conceptually what librosa.effects.pitch_shift does internally anyway. So the result computed
# above is discarded and we fall back to librosa, which is the more pragmatic choice here
# when pyrubberband is not available.
logger.debug("AUDIOTSM pitch-shift implementation simplified to a librosa fallback, since correctly combining TSM and resampling for pure pitch shifting is complex and the quality can vary.")
pitched_audio_librosa = librosa.effects.pitch_shift(y=audio_array_copy, sr=sr, n_steps=n_steps, bins_per_octave=24) # higher bins_per_octave for better quality
return pitched_audio_librosa if len(pitched_audio_librosa) > 0 else None
except Exception as e_audiotsm:
logger.warning(f"audiotsm-basierter Pitch Shift Versuch fehlgeschlagen: {e_audiotsm}. Fallback auf librosa.")
try:
pitched_audio_librosa = librosa.effects.pitch_shift(y=audio_array_copy, sr=sr, n_steps=n_steps, bins_per_octave=24)
return pitched_audio_librosa if len(pitched_audio_librosa) > 0 else None
except Exception as e_librosa:
logger.error(f"librosa.effects.pitch_shift (Fallback) failed for steps {n_steps}: {e_librosa}")
return None
else: # audiotsm not available or audio not 1D; use librosa only
try:
# higher bins_per_octave for potentially better quality with librosa
pitched_audio = librosa.effects.pitch_shift(y=audio_array_copy, sr=sr, n_steps=n_steps, bins_per_octave=24)
return pitched_audio if len(pitched_audio) > 0 else None
except Exception as e:
logger.error(f"librosa.effects.pitch_shift failed for steps {n_steps}: {e}")
return None
def change_volume(audio_array: np.ndarray, gain_db: float) -> Optional[np.ndarray]:
if gain_db == 0.0: return audio_array.copy()
if not (-20.0 <= gain_db <= 20.0):
logger.debug(f"Volume gain_db {gain_db} out of range [-20, 20]. Skipping.")
return None
try:
gain_factor = np.power(10, gain_db / 20.0)
vol_changed_audio = audio_array * gain_factor
return vol_changed_audio
except Exception as e:
logger.warning(f"change_volume failed for gain_db {gain_db}: {e}")
return None
# --- End of augmentation functions ---
def preprocess_sample_wrapper(
task_input: Tuple[Dict[str, Any], int, int],
text_tokenizer: Any,
speech_tokenizer: S3Tokenizer,
voice_encoder: Any,
t3_config_from_model: T3Config,
max_text_len_arg: int,
max_speech_len_arg: int,
audio_prompt_duration_s: float,
output_dir: Path,
silence_padding_ms: int,
do_augment: bool,
arg_speeds: List[float],
arg_pitches: List[float],
arg_volumes_db: List[float]
) -> List[Tuple[int, bool, str]]:
original_pair, original_file_index, global_sample_offset = task_input
audio_fpath = original_pair["audio"]
text_content = original_pair["text"]
augmentation_results = []
processed_sample_counter_for_this_file = 0
param_combinations: List[Tuple[float, float, float]] = []
param_combinations.append((1.0, 0.0, 0.0))
if do_augment:
for speed_val in arg_speeds: param_combinations.append((speed_val, 0.0, 0.0))
for pitch_val in arg_pitches: param_combinations.append((1.0, pitch_val, 0.0))
for volume_val in arg_volumes_db: param_combinations.append((1.0, 0.0, volume_val))
param_combinations = sorted(list(set(param_combinations)))
logger.debug(f"OrigIdx {original_file_index} ({audio_fpath.name}): {len(param_combinations)} individual augmentations planned (incl. original).")
wav_array_orig_sr_float32: Optional[np.ndarray] = None
sr_orig_raw: Optional[int] = None
try:
wav_array_orig_sr_float32, sr_orig_raw_loaded = librosa.load(str(audio_fpath), sr=None, mono=True, dtype=np.float32)
if wav_array_orig_sr_float32 is None or len(wav_array_orig_sr_float32) == 0:
raise ValueError("Geladene Audiodatei ist leer oder ungĂŒltig.")
sr_orig_raw = sr_orig_raw_loaded
except Exception as e_load:
logger.error(f"Error loading original audio {audio_fpath}: {e_load}", exc_info=True)
num_expected_variants_for_this_file = 1
if do_augment:
num_expected_variants_for_this_file += len(arg_speeds) + len(arg_pitches) + len(arg_volumes_db)
for i in range(num_expected_variants_for_this_file):
augmentation_results.append((global_sample_offset + i, False, f"{audio_fpath.name}_original_load_exception"))
return augmentation_results
test_audio_output_dir = None
if original_file_index == 0:
test_audio_output_dir = output_dir / "test_audio_samples_debug"
test_audio_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving test audios for {audio_fpath.name} (OriginalIndex 0) in {test_audio_output_dir}")
for speed_factor, pitch_steps, gain_db in param_combinations:
current_global_sample_index = global_sample_offset + processed_sample_counter_for_this_file
is_original_params = (speed_factor == 1.0 and pitch_steps == 0.0 and gain_db == 0.0)
aug_name_parts = []
if not is_original_params:
if speed_factor != 1.0: aug_name_parts.append(f"s{str(speed_factor).replace('.', 'p')}")
if pitch_steps != 0.0: aug_name_parts.append(f"p{str(pitch_steps).replace('.', 'p').replace('-', 'm')}")
if gain_db != 0.0: aug_name_parts.append(f"v{str(gain_db).replace('.', 'p').replace('-', 'm')}")
aug_name = "_".join(aug_name_parts)
else:
aug_name = "original"
if not aug_name: aug_name = f"s{str(speed_factor).replace('.', 'p')}_p{str(pitch_steps).replace('.', 'p').replace('-', 'm')}_v{str(gain_db).replace('.', 'p').replace('-', 'm')}"
try:
if sr_orig_raw is None or wav_array_orig_sr_float32 is None:
raise RuntimeError("Original sample rate or audio data not loaded correctly.")
current_audio_to_process = wav_array_orig_sr_float32.copy()
current_sr = sr_orig_raw
if speed_factor != 1.0:
augmented_audio = change_speed(current_audio_to_process, current_sr, speed_factor)
if augmented_audio is None: raise ValueError(f"Speed change (factor {speed_factor}) failed.")
current_audio_to_process = augmented_audio
if pitch_steps != 0.0:
augmented_audio = change_pitch(current_audio_to_process, current_sr, pitch_steps)
if augmented_audio is None: raise ValueError(f"Pitch change ({pitch_steps} steps) failed.")
current_audio_to_process = augmented_audio
if gain_db != 0.0:
augmented_audio = change_volume(current_audio_to_process, gain_db)
if augmented_audio is None: raise ValueError(f"Volume change ({gain_db}dB) failed.")
current_audio_to_process = augmented_audio
if current_sr != S3_SR:
final_audio_for_preprocess = librosa.resample(y=current_audio_to_process, orig_sr=current_sr, target_sr=S3_SR)
else:
final_audio_for_preprocess = current_audio_to_process
if len(final_audio_for_preprocess) == 0:
raise ValueError("Audio nach Augmentierung/Resampling leer.")
if test_audio_output_dir and original_file_index == 0:
test_audio_filename = f"{audio_fpath.stem}_AUG_{aug_name}_origidx{original_file_index}_gidx{current_global_sample_index}_SR{S3_SR}.wav"
try:
sf.write(str(test_audio_output_dir / test_audio_filename), final_audio_for_preprocess, S3_SR, subtype='FLOAT')
except Exception as e_sf:
logger.warning(f"Could not save test audio {test_audio_filename}: {e_sf}")
processed_data = preprocess_sample(
audio_array_input=final_audio_for_preprocess,
original_sr_input_after_augmentation=S3_SR,
audio_fpath_for_logging=audio_fpath, text=text_content,
text_tokenizer=text_tokenizer, speech_tokenizer=speech_tokenizer,
voice_encoder=voice_encoder, t3_config_from_model=t3_config_from_model,
max_text_len_arg=max_text_len_arg, max_speech_len_arg=max_speech_len_arg,
audio_prompt_duration_s=audio_prompt_duration_s, silence_padding_ms=silence_padding_ms,
)
if processed_data:
base_filename = audio_fpath.stem
safe_base_filename = "".join(c if c.isalnum() or c in ['-', '_'] else '_' for c in base_filename)
output_filename_pt = f"{safe_base_filename}_aug_{aug_name}_{current_global_sample_index:07d}.pt"
torch.save(processed_data, output_dir / output_filename_pt)
augmentation_results.append((current_global_sample_index, True, output_filename_pt))
else:
err_msg = f"{audio_fpath.name}_aug_{aug_name}_preprocess_failed"
augmentation_results.append((current_global_sample_index, False, err_msg))
except ValueError as ve:
logger.warning(f"Augmentation variant '{aug_name}' for {audio_fpath.name} (GlobalIdx {current_global_sample_index}) failed: {ve}. Skipping.")
augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_value_error"))
except RuntimeError as rterr:
logger.error(f"Runtime error for {audio_fpath.name} (Aug: {aug_name}, GlobalIdx {current_global_sample_index}): {rterr}", exc_info=False)
augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_runtime_error"))
except Exception as e_outer:
logger.error(f"Unexpected error for {audio_fpath.name} (Aug: {aug_name}, GlobalIdx {current_global_sample_index}): {e_outer}", exc_info=True)
augmentation_results.append((current_global_sample_index, False, f"{audio_fpath.name}_aug_{aug_name}_exception"))
finally:
if 'current_audio_to_process' in locals() and current_audio_to_process is not None: del current_audio_to_process
if 'final_audio_for_preprocess' in locals() and final_audio_for_preprocess is not None: del final_audio_for_preprocess
if 'augmented_audio' in locals() and augmented_audio is not None: del augmented_audio
processed_sample_counter_for_this_file += 1
if processed_sample_counter_for_this_file >= 150:
logger.warning(f"Max augmentation limit (150) for {audio_fpath.name} reached.")
remaining_planned_slots = len(param_combinations) - processed_sample_counter_for_this_file
for i in range(remaining_planned_slots):
augmentation_results.append((global_sample_offset + processed_sample_counter_for_this_file + i, False, f"{audio_fpath.name}_max_aug_limit_reached"))
break
expected_results_count = len(param_combinations)
while len(augmentation_results) < expected_results_count:
missing_idx = len(augmentation_results)
augmentation_results.append((global_sample_offset + missing_idx, False, f"{audio_fpath.name}_processing_loop_incomplete_at_idx_{missing_idx}"))
# logger.debug(f"Padded missing result for {audio_fpath.name} at local variant index {missing_idx}")
if wav_array_orig_sr_float32 is not None: del wav_array_orig_sr_float32
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return augmentation_results
def preprocess_sample(
audio_array_input: np.ndarray,
original_sr_input_after_augmentation: int,
audio_fpath_for_logging: Path,
text: str,
text_tokenizer: Any,
speech_tokenizer: S3Tokenizer,
voice_encoder: Any,
t3_config_from_model: T3Config,
max_text_len_arg: int,
max_speech_len_arg: int,
audio_prompt_duration_s: float,
silence_padding_ms: int,
) -> Optional[Dict[str, torch.Tensor]]:
try:
if len(audio_array_input) == 0: return None
if original_sr_input_after_augmentation != S3_SR:
wav_16k_final = librosa.resample(y=audio_array_input, orig_sr=original_sr_input_after_augmentation, target_sr=S3_SR)
else:
wav_16k_final = audio_array_input.copy()
if wav_16k_final.ndim > 1: wav_16k_final = librosa.to_mono(wav_16k_final)
if wav_16k_final.dtype != np.float32: wav_16k_final = wav_16k_final.astype(np.float32)
if len(wav_16k_final) == 0: return None
if silence_padding_ms > 0:
padding_samples = int((silence_padding_ms / 1000.0) * S3_SR)
silence_array = np.zeros(padding_samples, dtype=np.float32)
wav_16k_padded_silence = np.concatenate((silence_array, wav_16k_final, silence_array))
else:
wav_16k_padded_silence = wav_16k_final
if len(wav_16k_padded_silence) == 0: return None
MAX_AUDIO_SAMPLES_FOR_S3TOKENIZER_INPUT = int(18.0 * S3_SR)
wav_16k_for_tokenizer_input = wav_16k_padded_silence[:MAX_AUDIO_SAMPLES_FOR_S3TOKENIZER_INPUT]
MIN_VE_AUDIO_LEN_SAMPLES = S3_SR // 10
if len(wav_16k_padded_silence) < MIN_VE_AUDIO_LEN_SAMPLES:
padding_needed = MIN_VE_AUDIO_LEN_SAMPLES - len(wav_16k_padded_silence)
wav_for_ve = np.pad(wav_16k_padded_silence, (0, padding_needed), 'constant')
else:
wav_for_ve = wav_16k_padded_silence
speaker_emb_np = voice_encoder.embeds_from_wavs([wav_for_ve], sample_rate=S3_SR)
speaker_emb = torch.from_numpy(speaker_emb_np[0]).cpu()
start_text_token_id = t3_config_from_model.start_text_token
stop_text_token_id = t3_config_from_model.stop_text_token
start_speech_token_id = t3_config_from_model.start_speech_token
stop_speech_token_id = t3_config_from_model.stop_speech_token
speech_prompt_len_from_config = t3_config_from_model.speech_cond_prompt_len
MIN_TOKENIZER_AUDIO_LEN_SAMPLES = S3_SR // 20
if len(wav_16k_for_tokenizer_input) < MIN_TOKENIZER_AUDIO_LEN_SAMPLES: return None
raw_speech_tokens_batch, speech_token_lengths_batch = speech_tokenizer.forward([wav_16k_for_tokenizer_input])
if raw_speech_tokens_batch is None or speech_token_lengths_batch is None or speech_token_lengths_batch.squeeze(0).item() == 0: return None
raw_speech_tokens = raw_speech_tokens_batch.cpu().squeeze(0)[:speech_token_lengths_batch.cpu().squeeze(0).item()]
speech_tokens_unpadded_full_sequence = F.pad(raw_speech_tokens, (1, 0), value=start_speech_token_id)
speech_tokens_unpadded_full_sequence = F.pad(speech_tokens_unpadded_full_sequence, (0, 1), value=stop_speech_token_id)
current_total_speech_len_incl_start_stop = len(speech_tokens_unpadded_full_sequence)
MIN_PREDICTABLE_CONTENT_TOKENS = 1
if current_total_speech_len_incl_start_stop < (1 + MIN_PREDICTABLE_CONTENT_TOKENS + 1): return None
speech_token_actual_len = torch.tensor(current_total_speech_len_incl_start_stop, dtype=torch.long).cpu()
padding_token_id_speech = stop_speech_token_id
if current_total_speech_len_incl_start_stop > max_speech_len_arg:
padded_speech_tokens = speech_tokens_unpadded_full_sequence[:max_speech_len_arg-1]
padded_speech_tokens = torch.cat([padded_speech_tokens, torch.tensor([stop_speech_token_id],dtype=torch.long).cpu()])
speech_token_actual_len = torch.tensor(max_speech_len_arg, dtype=torch.long).cpu()
else:
padded_speech_tokens = F.pad(speech_tokens_unpadded_full_sequence, (0, max_speech_len_arg - current_total_speech_len_incl_start_stop), value=padding_token_id_speech)
if padded_speech_tokens.size(0) != max_speech_len_arg: return None
enc_cond_audio_len_samples_prompt = int(audio_prompt_duration_s * S3_SR)
cond_audio_segment_for_prompt_np = wav_16k_padded_silence[:enc_cond_audio_len_samples_prompt]
target_prompt_tensor_len = speech_prompt_len_from_config
cond_prompt_speech_tokens_padded: torch.Tensor
if len(cond_audio_segment_for_prompt_np) < MIN_TOKENIZER_AUDIO_LEN_SAMPLES:
cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()
else:
try:
cond_prompt_tokens_batch, cond_prompt_lens_batch = speech_tokenizer.forward([cond_audio_segment_for_prompt_np], max_len=None)
if cond_prompt_tokens_batch is not None and cond_prompt_lens_batch is not None and cond_prompt_lens_batch.squeeze(0).item() > 0:
cond_prompt_unpadded_temp = cond_prompt_tokens_batch.cpu().squeeze(0)[:cond_prompt_lens_batch.cpu().squeeze(0).item()]
else:
cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()
except Exception:
cond_prompt_unpadded_temp = torch.full((1,), start_speech_token_id, dtype=torch.long).cpu()
current_prompt_tokens_len = cond_prompt_unpadded_temp.size(0)
if current_prompt_tokens_len > target_prompt_tensor_len:
cond_prompt_speech_tokens_padded = cond_prompt_unpadded_temp[:target_prompt_tensor_len]
elif current_prompt_tokens_len < target_prompt_tensor_len:
cond_prompt_speech_tokens_padded = F.pad(cond_prompt_unpadded_temp, (0, target_prompt_tensor_len - current_prompt_tokens_len), value=start_speech_token_id)
else:
cond_prompt_speech_tokens_padded = cond_prompt_unpadded_temp
if cond_prompt_speech_tokens_padded.size(0) != target_prompt_tensor_len: return None
normalized_text = punc_norm(text)
if not normalized_text.strip(): return None
raw_text_tokens_for_text = text_tokenizer.text_to_tokens(normalized_text).cpu().squeeze(0)
text_tokens_unpadded_full_text = F.pad(raw_text_tokens_for_text, (1, 0), value=start_text_token_id)
text_tokens_unpadded_full_text = F.pad(text_tokens_unpadded_full_text, (0, 1), value=stop_text_token_id)
current_text_len_unpadded = len(text_tokens_unpadded_full_text)
text_token_actual_len = torch.tensor(current_text_len_unpadded, dtype=torch.long).cpu()
padding_token_id_text = stop_text_token_id
if current_text_len_unpadded > max_text_len_arg:
padded_text_tokens = text_tokens_unpadded_full_text[:max_text_len_arg-1]
padded_text_tokens = torch.cat([padded_text_tokens, torch.tensor([stop_text_token_id], dtype=torch.long).cpu()])
text_token_actual_len = torch.tensor(max_text_len_arg, dtype=torch.long).cpu()
else:
padded_text_tokens = F.pad(text_tokens_unpadded_full_text, (0, max_text_len_arg - current_text_len_unpadded), value=padding_token_id_text)
if padded_text_tokens.size(0) != max_text_len_arg: return None
emotion_adv_scalar_tensor = torch.tensor(0.5, dtype=torch.float).cpu()
return {
"text_tokens": padded_text_tokens.long().cpu(),
"text_token_lens": text_token_actual_len.long().cpu(),
"speech_tokens": padded_speech_tokens.long().cpu(),
"speech_token_lens": speech_token_actual_len.long().cpu(),
"t3_cond_speaker_emb": speaker_emb.float().cpu(),
"t3_cond_prompt_speech_tokens": cond_prompt_speech_tokens_padded.long().cpu(),
"t3_cond_emotion_adv": emotion_adv_scalar_tensor.cpu()
}
except Exception as e:
logger.error(f"Critical error in preprocess_sample for {audio_fpath_for_logging} (Text: '{text[:30]}...'): {e}", exc_info=False)
return None
def main():
parser = argparse.ArgumentParser(description="Preprocess TTS data with augmentation.")
# ... (arguments unchanged) ...
parser.add_argument("--model_load_path", type=str, required=True)
parser.add_argument("--is_hub_model", action="store_true")
parser.add_argument("--dataset_dir", type=str)
parser.add_argument("--metadata_file", type=str)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--max_text_len", type=int, default=2048)
parser.add_argument("--max_speech_len", type=int, default=4096)
parser.add_argument("--audio_prompt_duration_s", type=float, default=3.0)
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--audio_ext", type=str, default="wav")
parser.add_argument("--text_ext", type=str, default="txt")
parser.add_argument("--silence_padding_ms", type=int, default=300, help="Silence in ms to add at start/end of audio *after* augmentations and resampling, before tokenization.")
parser.add_argument("--augment_data", action="store_true", help="Enable audio data augmentation.")
parser.add_argument("--aug_speeds", type=float, nargs='*', default=[], help="Speed factors (e.g., 0.9 1.1). Neutral 1.0 ignored.")
parser.add_argument("--aug_pitches", type=float, nargs='*', default=[], help="Pitch shifts in semitones (e.g., -1 1). Neutral 0.0 ignored.")
parser.add_argument("--aug_volumes_db", type=float, nargs='*', default=[], help="Volume changes dB (e.g., -3 3). Neutral 0.0 ignored.")
args = parser.parse_args()
output_dir_path_obj = Path(args.output_dir)
output_dir_path_obj.mkdir(parents=True, exist_ok=True)
valid_speeds = sorted(list(set(s for s in args.aug_speeds if 0.5 <= s <= 2.0 and s != 1.0)))
valid_pitches = sorted(list(set(p for p in args.aug_pitches if -12.0 <= p <= 12.0 and p != 0.0)))
valid_volumes_db = sorted(list(set(v for v in args.aug_volumes_db if -20.0 <= v <= 20.0 and v != 0.0)))
augmentation_params_for_wrapper = {
"speeds": valid_speeds,
"pitches": valid_pitches,
"volumes_db": valid_volumes_db,
}
pitch_method_used = "librosa.effects.pitch_shift (bins_per_octave=24)"
if AUDIOTSM_AVAILABLE:
# Note: even when audiotsm is available, change_pitch has been modified so that it currently
# falls back to librosa, because the audiotsm variant for pure pitch shifting turned out to be
# more complex than expected and its quality can vary. If you want a dedicated audiotsm
# pitch-shift solution, change_pitch has to be adapted; change_pitch decides internally,
# and the log message below reflects that.
logger.info("audiotsm is available. Pitch shifting will internally use librosa.effects.pitch_shift (bins_per_octave=24), since the audiotsm pitch-shift logic has been simplified.")
if args.augment_data:
logger.info(f"Data augmentation ENABLED. Pitch shifting uses: {pitch_method_used}.")
# ... (remaining log output unchanged) ...
logger.info(f" Individual augmentations (plus original):")
if augmentation_params_for_wrapper["speeds"]: logger.info(f" Speeds: {augmentation_params_for_wrapper['speeds']}")
if augmentation_params_for_wrapper["pitches"]: logger.info(f" Pitches (semitones): {augmentation_params_for_wrapper['pitches']}")
if augmentation_params_for_wrapper["volumes_db"]: logger.info(f" Volumes (dB): {augmentation_params_for_wrapper['volumes_db']}")
if not any(augmentation_params_for_wrapper.values()):
logger.warning("WARNING: --augment_data set, but no valid non-neutral augmentation parameters provided. Only originals processed.")
else:
logger.info("Data augmentation DEACTIVATED. Only original samples processed.")
augmentation_params_for_wrapper = {"speeds": [], "pitches": [], "volumes_db": []}
# ... (rest of main() unchanged through to the end) ...
logger.info(f"Loading ChatterboxTTS model from: {args.model_load_path} ...")
model_source_path: str
if args.is_hub_model:
from huggingface_hub import snapshot_download
try: model_source_path = snapshot_download(repo_id=args.model_load_path, cache_dir=args.cache_dir, allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json", "*.model"])
except Exception as e: logger.error(f"Hub Download Error: {e}"); return
else:
model_source_path = args.model_load_path
if not Path(model_source_path).is_dir(): logger.error(f"Local path not found: {model_source_path}"); return
chatterbox_model: ChatterboxTTS
try:
chatterbox_model = ChatterboxTTS.from_local(
ckpt_dir_for_model_weights=model_source_path,
base_model_dir_for_static_components=model_source_path,
device="cpu",
target_t3_config_for_model=None
)
chatterbox_model.t3.to("cpu")
chatterbox_model.s3gen.to("cpu")
chatterbox_model.ve.to("cpu")
if hasattr(chatterbox_model, 'tokenizer') and hasattr(chatterbox_model.tokenizer, 'to'):
chatterbox_model.tokenizer.to("cpu")
config_used = chatterbox_model.t3.hp
logger.info(f"Using T3-Config (Base): PromptLen={getattr(config_used, 'speech_cond_prompt_len', 'N/A')}")
logger.info(f"ChatterboxTTS model and components loaded to CPU.")
except Exception as e: logger.error(f"Model loading error: {e}", exc_info=True); return
text_tokenizer_instance = chatterbox_model.tokenizer
speech_tokenizer_instance = chatterbox_model.s3gen.tokenizer
voice_encoder_instance = chatterbox_model.ve
t3_config_from_loaded_model = chatterbox_model.t3.hp
file_pairs: List[Dict[str, Any]] = []
if args.metadata_file:
metadata_p = Path(args.metadata_file).resolve(); dataset_root = metadata_p.parent
logger.info(f"Reading metadata from: {metadata_p}")
try:
with open(metadata_p, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
parts = line.strip().split('|'); text_content_meta = ""
if len(parts) < 2: parts = line.strip().split('\t')
if len(parts) >= 1: audio_file_rel = parts[0]
if len(parts) >= 2: text_content_meta = "|".join(parts[1:])
else: logger.debug(f"Line {i+1} in meta has no text. Skip."); continue
audio_p_candidate1 = dataset_root / audio_file_rel
audio_p_candidate2 = Path(audio_file_rel)
audio_p_resolved = None
if audio_p_candidate1.exists() and audio_p_candidate1.is_file():
audio_p_resolved = audio_p_candidate1.resolve()
elif audio_p_candidate2.exists() and audio_p_candidate2.is_file() and audio_p_candidate2.is_absolute():
audio_p_resolved = audio_p_candidate2.resolve()
if audio_p_resolved: file_pairs.append({"audio": audio_p_resolved, "text": text_content_meta.strip()})
else: logger.debug(f"Audio not found for meta entry '{audio_file_rel}'. Tried {audio_p_candidate1}, {audio_p_candidate2}. Skip.")
except Exception as e_meta: logger.error(f"Meta loading error: {e_meta}", exc_info=True); return
elif args.dataset_dir:
dataset_p = Path(args.dataset_dir).resolve(); logger.info(f"Searching for pairs in: {dataset_p}")
if not dataset_p.is_dir(): logger.error(f"Dataset directory not found: {dataset_p}"); return
found_audio_files = list(dataset_p.rglob(f"*.{args.audio_ext}"))
for audio_path_obj in tqdm(found_audio_files, desc="Checking audio/text pairs"):
text_path = audio_path_obj.with_suffix(f".{args.text_ext}")
if text_path.exists():
try:
text_content = text_path.read_text(encoding='utf-8').strip()
if text_content: file_pairs.append({"audio": audio_path_obj.resolve(), "text": text_content})
else: logger.debug(f"Text file {text_path} is empty. Skip.")
except Exception as e_read_text:
logger.warning(f"Could not read text file {text_path}: {e_read_text}. Skip.")
else: logger.error("Neither --dataset_dir nor --metadata_file provided. Aborting."); return
if not file_pairs: logger.error("No audio/text pairs found. Aborting."); return
logger.info(f"Found {len(file_pairs)} original audio-text pairs. Starting preprocessing...")
num_workers_requested = args.num_workers if args.num_workers is not None and args.num_workers > 0 else os.cpu_count()
if num_workers_requested is None: num_workers_requested = 1
num_aug_variants_per_file = 1
if args.augment_data:
num_aug_variants_per_file += len(augmentation_params_for_wrapper["speeds"])
num_aug_variants_per_file += len(augmentation_params_for_wrapper["pitches"])
num_aug_variants_per_file += len(augmentation_params_for_wrapper["volumes_db"])
if num_aug_variants_per_file == 0 : num_aug_variants_per_file = 1
num_workers = min(num_workers_requested, len(file_pairs), 60 if os.name == 'nt' else (os.cpu_count() or 1) * 5)
if num_workers <= 0: num_workers = 1
logger.info(f"Using {num_workers} worker processes. Each original file generates up to {num_aug_variants_per_file} samples.")
tasks_for_executor = []
current_global_offset_for_filenames = 0
for i, pair in enumerate(file_pairs):
tasks_for_executor.append((pair, i, current_global_offset_for_filenames))
current_global_offset_for_filenames += num_aug_variants_per_file
logger.info(f"Total expected samples to attempt (incl. augmentations): {current_global_offset_for_filenames}")
process_fn_partial = functools.partial(
preprocess_sample_wrapper,
text_tokenizer=text_tokenizer_instance,
speech_tokenizer=speech_tokenizer_instance,
voice_encoder=voice_encoder_instance,
t3_config_from_model=t3_config_from_loaded_model,
max_text_len_arg=args.max_text_len,
max_speech_len_arg=args.max_speech_len,
audio_prompt_duration_s=args.audio_prompt_duration_s,
output_dir=output_dir_path_obj,
silence_padding_ms=args.silence_padding_ms,
do_augment=args.augment_data,
arg_speeds=augmentation_params_for_wrapper["speeds"],
arg_pitches=augmentation_params_for_wrapper["pitches"],
arg_volumes_db=augmentation_params_for_wrapper["volumes_db"]
)
total_successful_pt_files = 0
total_failed_or_skipped_variants = 0
total_original_files_submitted_to_workers = 0
executor = None
try:
executor = ProcessPoolExecutor(max_workers=num_workers)
futures = []
for task_input_tuple in tasks_for_executor:
futures.append(executor.submit(process_fn_partial, task_input_tuple))
total_original_files_submitted_to_workers +=1
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing original files (incl. augmentations)"):
try:
augmentation_batch_results = future.result()
for _, success, _ in augmentation_batch_results:
if success:
total_successful_pt_files += 1
else:
total_failed_or_skipped_variants += 1
except Exception as e_future:
logger.error(f"A worker process failed critically: {e_future}", exc_info=True)
total_failed_or_skipped_variants += num_aug_variants_per_file
except Exception as e_pool:
logger.error(f"Error occurred in ProcessPoolExecutor: {e_pool}", exc_info=True)
finally:
if executor:
logger.info("Shutting down worker processes...")
executor.shutdown(wait=True)
del chatterbox_model, text_tokenizer_instance, speech_tokenizer_instance, voice_encoder_instance
del t3_config_from_loaded_model, file_pairs, tasks_for_executor, process_fn_partial
gc.collect()
logger.info("Main process cleanup done.")
logger.info(f"Preprocessing COMPLETE.")
logger.info(f" {total_original_files_submitted_to_workers} original files submitted to workers.")
logger.info(f" Successfully created {total_successful_pt_files} .pt files.")
if total_failed_or_skipped_variants > 0:
logger.warning(f" {total_failed_or_skipped_variants} variants failed or were skipped.")
logger.info(f" Saved .pt files in: '{args.output_dir}'.")
test_audio_dir_check = output_dir_path_obj / "test_audio_samples_debug"
if test_audio_dir_check.exists() and any(test_audio_dir_check.iterdir()):
logger.info(f" Test audio samples saved in: '{test_audio_dir_check}'.")
if __name__ == "__main__":
main()
finetune_t3_preprocessed.py: the GPU magic starts here and heats up your room.
# finetune_t3_preprocessed.py
import argparse
import logging
import os
import json
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Any
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import librosa
import numpy as np
import soundfile as sf
from torch.utils.tensorboard import SummaryWriter
import gc # added for memory cleanup
# --- BEGIN MODIFICATION FOR PICKLE ERROR ---
try:
import numpy as pynumpy_for_pickle # changed alias
import torch.serialization
safe_globals_to_add = [
pynumpy_for_pickle.core.multiarray._reconstruct,
pynumpy_for_pickle.dtype,
pynumpy_for_pickle.ndarray,
]
try:
safe_globals_to_add.append(pynumpy_for_pickle.dtypes.UInt32DType)
except AttributeError:
logging.warning("numpy.dtypes.UInt32DType not found, not adding to safe globals.")
torch.serialization.add_safe_globals(safe_globals_to_add)
logging.info(
f"Attempted to add NumPy types to torch.serialization.get_safe_globals() for checkpoint loading: "
f"{[g.__name__ if hasattr(g, '__name__') else str(g) for g in safe_globals_to_add]}"
)
except ImportError:
logging.warning("Could not import numpy or torch.serialization to add safe globals.")
except AttributeError as e:
logging.warning(f"General AttributeError while trying to add safe globals: {e}.")
# --- END MODIFICATION FOR PICKLE ERROR ---
from transformers import (
HfArgumentParser,
EarlyStoppingCallback,
set_seed,
Trainer,
PretrainedConfig,
TrainerCallback, # added
TrainerState, # added
TrainerControl # added
)
from transformers import TrainingArguments as HfTrainingArguments
from datasets import load_dataset, DatasetDict, VerificationMode, Audio
import datasets
from chatterbox.tts import ChatterboxTTS, Conditionals, punc_norm, REPO_ID
from chatterbox.models.t3.t3 import T3, T3Cond
from chatterbox.models.t3.modules.t3_config import T3Config
from chatterbox.models.s3tokenizer import S3_SR, S3Tokenizer
logger = logging.getLogger(__name__)
# --- Detailed Memory Clearing Callback ---
logger_callback = logging.getLogger(__name__ + ".DetailedMemoryClearingCallback")
class DetailedMemoryClearingCallback(TrainerCallback):
def _log_memory_usage(self, stage: str, device: Union[str, torch.device]):
if torch.cuda.is_available():
torch.cuda.synchronize(device) # important for accurate measurements
allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
# max_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 2) # reset for the next phase
# max_reserved = torch.cuda.max_memory_reserved(device) / (1024 ** 2) # reset for the next phase
# torch.cuda.reset_peak_memory_stats(device) # reset the max values for the next phase
logger_callback.info(
f"CUDA Memory ({stage}) on device {device}: \n"
f" Allocated: {allocated:.2f} MB\n"
f" Reserved: {reserved:.2f} MB"
# f" Max Allocated (since last reset): {max_allocated:.2f} MB\n"
# f" Max Reserved (since last reset): {max_reserved:.2f} MB"
)
# More detailed info if needed:
# logger_callback.debug(torch.cuda.memory_summary(device=device, abbreviated=False))
def _clear_memory(self, stage: str, device: Union[str, torch.device]):
logger_callback.info(f"Attempting to clear memory ({stage}) on device {device}...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger_callback.info(f"Memory clearing attempt ({stage}) finished.")
def on_train_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
logger_callback.info("Training started. Initial memory state.")
self._log_memory_usage("on_train_begin", args.device)
def on_epoch_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
self._log_memory_usage(f"Epoch {state.epoch:.0f} begin, before cleanup", args.device)
self._clear_memory(f"Epoch {state.epoch:.0f} begin", args.device)
self._log_memory_usage(f"Epoch {state.epoch:.0f} begin, after cleanup", args.device)
def on_step_begin(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# Only log when a logging step is reached, to avoid flooding the logs
if state.is_world_process_zero and state.global_step > 0 and \
args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
self._log_memory_usage(f"Step {state.global_step} begin", args.device)
def on_evaluate(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
self._log_memory_usage("Before on_evaluate cleanup", args.device)
self._clear_memory("on_evaluate", args.device)
self._log_memory_usage("After on_evaluate cleanup", args.device)
def on_log(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
if state.is_world_process_zero and logs is not None:
is_eval_log = any(k.startswith("eval_") for k in logs)
current_stage = "on_log (post-eval)" if is_eval_log else f"on_log (step {state.global_step})"
self._log_memory_usage(f"Before {current_stage} cleanup", args.device)
self._clear_memory(current_stage, args.device)
self._log_memory_usage(f"After {current_stage} cleanup", args.device)
def on_save(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
self._log_memory_usage("Before on_save cleanup", args.device)
self._clear_memory("on_save", args.device)
self._log_memory_usage("After on_save cleanup", args.device)
def on_train_end(self, args: HfTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
self._log_memory_usage("Before on_train_end cleanup", args.device)
self._clear_memory("on_train_end", args.device)
self._log_memory_usage("After on_train_end cleanup", args.device)
# --- Argument classes ---
@dataclass
class CustomTrainingArguments(HfTrainingArguments):
early_stopping_patience: Optional[int] = field(
default=None, metadata={"help": "Enable early stopping with specified patience. Default: None (disabled)."}
)
generate_samples_on_finish: bool = field(
default=True, metadata={"help": "Generate audio samples after training/evaluation."}
)
num_samples_to_generate: int = field(
default=3, metadata={"help": "Number of audio samples to generate PER CHECKPOINT."}
)
sample_texts: Optional[List[str]] = field(
default_factory=lambda: ["Hallo Welt, dies ist ein Test.", "Wie klingt meine Stimme nach dem Training?", "Die künstliche Intelligenz macht Fortschritte."],
metadata={"help": "List of texts to use for generating audio samples."}
)
generate_from_all_checkpoints: bool = field(
default=False, metadata={"help": "If true, iterate through all checkpoint subfolders to generate samples, not just the final/best model."}
)
# HfTrainingArguments already defines per_device_eval_batch_size.
# If we define it here with a different default, we override it.
# The default in HfTrainingArguments is 8; for TTS, 1 or 2 is often safer.
per_device_eval_batch_size: int = field(
default=1, metadata={"help": "Batch size for evaluation per device. Lower to save VRAM."}
)
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"})
local_model_dir: Optional[str] = field(default=None, metadata={"help": "Path to local directory containing ve.safetensors, t3_cfg.safetensors, etc. Overrides model_name_or_path for loading."})
cache_dir: Optional[str] = field(default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"})
freeze_voice_encoder: bool = field(default=True, metadata={"help": "Freeze the Voice Encoder."})
freeze_s3gen: bool = field(default=True, metadata={"help": "Freeze the S3Gen model (speech token to waveform)."})
@dataclass
class DataArguments:
preprocessed_data_dir: str = field(metadata={"help": "Path to the directory containing preprocessed .pt files."})
eval_split_size: float = field(default=0.05, metadata={"help": "Fraction of preprocessed data to use for evaluation if splitting."})
max_text_len: int = field(default=300, metadata={"help": "Max text len the model was trained with / expects."})
max_speech_len: int = field(default=750, metadata={"help": "Max speech len the model was trained with / expects."})
original_dataset_dir: Optional[str] = field(default=None, metadata={"help": "Path to the original raw audio dataset_dir (for sample generation conditioning)."})
original_metadata_file: Optional[str] = field(default=None, metadata={"help": "Path to original raw metadata_file (for sample generation conditioning)."})
original_dataset_name: Optional[str] = field(default=None, metadata={"help": "Original HF dataset name (for sample generation conditioning)."})
original_dataset_config_name: Optional[str] = field(default=None, metadata={"help": "Original HF dataset config name (for sample generation conditioning)."})
original_audio_column_name: str = field(default="audio", metadata={"help": "Audio column in original HF dataset (for sample generation conditioning)."})
original_text_column_name: str = field(default="text", metadata={"help": "Text column in original HF dataset."})
# --- Dataset and collator ---
class PreprocessedSpeechDataset(Dataset):
def __init__(self, preprocessed_files: List[Path]):
self.preprocessed_files = preprocessed_files
if not self.preprocessed_files:
raise ValueError("No preprocessed files provided to PreprocessedSpeechDataset.")
logger.info(f"Initialized PreprocessedSpeechDataset with {len(self.preprocessed_files)} files.")
def __len__(self):
return len(self.preprocessed_files)
def __getitem__(self, idx) -> Optional[Dict[str, torch.Tensor]]:
file_path = self.preprocessed_files[idx]
try:
data_dict = torch.load(file_path, weights_only=None) # weights_only=None fĂŒr Ă€ltere PyTorch Versionen oder wenn nicht-Tensor Daten enthalten sind
required_keys = ["text_tokens", "text_token_lens", "speech_tokens", "speech_token_lens",
"t3_cond_speaker_emb", "t3_cond_prompt_speech_tokens", "t3_cond_emotion_adv"]
for key in required_keys:
if key not in data_dict:
logger.error(f"Missing key '{key}' in preprocessed file {file_path}. Skipping.")
return None
return data_dict
except Exception as e:
logger.error(f"Error loading or validating preprocessed file {file_path}: {e}. Skipping.")
return None
@dataclass
class SpeechDataCollator:
t3_config: T3Config
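# The per-sample tensors are expected to already share fixed shapes (padded during preprocessing),
# so a plain torch.stack suffices below; __call__ also builds the next-token labels for text and speech.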
def __call__(self, features: List[Optional[Dict[str, Any]]]) -> Dict[str, Any]:
valid_features = [f for f in features if f is not None]
if not valid_features:
dummy_device = torch.device("cpu") # Oder das Device des Modells, wenn bekannt
speaker_emb_dim = self.t3_config.llama_hidden_size
prompt_len_collator = self.t3_config.speech_cond_prompt_len
max_text_len_collator = self.t3_config.max_text_tokens
max_speech_len_collator = self.t3_config.max_speech_tokens
logger.warning("COLLATOR_DEBUG: SpeechDataCollator received no valid features. Returning empty batch structure.")
return {
"text_tokens": torch.empty(0, max_text_len_collator, dtype=torch.long, device=dummy_device),
"text_token_lens": torch.empty(0, dtype=torch.long, device=dummy_device),
"speech_tokens": torch.empty(0, max_speech_len_collator, dtype=torch.long, device=dummy_device),
"speech_token_lens": torch.empty(0, dtype=torch.long, device=dummy_device),
"t3_cond_speaker_emb": torch.empty(0, speaker_emb_dim, dtype=torch.float, device=dummy_device),
"t3_cond_prompt_speech_tokens": torch.empty(0, prompt_len_collator, dtype=torch.long, device=dummy_device),
"t3_cond_emotion_adv": torch.empty(0, 1, 1, dtype=torch.float, device=dummy_device),
"labels_text": torch.empty(0, max_text_len_collator -1 if max_text_len_collator > 0 else 0, dtype=torch.long, device=dummy_device),
"labels_speech": torch.empty(0, max_speech_len_collator -1 if max_speech_len_collator > 0 else 0, dtype=torch.long, device=dummy_device),
}
features = valid_features
first_tensor_device = features[0]["text_tokens"].device # Behalte Device vom ersten Feature
padded_text_tokens = torch.stack([f["text_tokens"] for f in features]).to(first_tensor_device)
text_token_lens = torch.stack([f["text_token_lens"] for f in features]).to(first_tensor_device)
padded_speech_tokens = torch.stack([f["speech_tokens"] for f in features]).to(first_tensor_device)
speech_token_lens = torch.stack([f["speech_token_lens"] for f in features]).to(first_tensor_device)
t3_cond_speaker_emb = torch.stack([f["t3_cond_speaker_emb"] for f in features]).to(first_tensor_device)
t3_cond_prompt_speech_tokens = torch.stack([f["t3_cond_prompt_speech_tokens"] for f in features]).to(first_tensor_device)
emotion_adv_scalars = torch.stack([f["t3_cond_emotion_adv"] for f in features]).to(first_tensor_device)
t3_cond_emotion_adv = emotion_adv_scalars.view(len(features), 1, 1) # Ensure [B, 1, 1]
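# Build shifted next-token labels: positions beyond the true sequence length get IGNORE_ID (-100) so the
# loss skips them, and for speech the conditioning prompt region is masked out as well.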
IGNORE_ID = -100
prompt_len = self.t3_config.speech_cond_prompt_len
shifted_text = padded_text_tokens[:, 1:].contiguous()
T_text = shifted_text.size(1)
text_lens_minus_one = (text_token_lens - 1).clamp(min=0)
arange_text = torch.arange(T_text, device=shifted_text.device)
mask_pad_text = arange_text[None, :] >= text_lens_minus_one[:, None]
labels_text = shifted_text.clone()
labels_text[mask_pad_text] = IGNORE_ID
shifted_speech = padded_speech_tokens[:, 1:].contiguous()
T_speech = shifted_speech.size(1)
speech_lens_minus_one = (speech_token_lens - 1).clamp(min=0)
arange_speech = torch.arange(T_speech, device=shifted_speech.device)
mask_pad_speech = arange_speech[None, :] >= speech_lens_minus_one[:, None]
mask_prompt = arange_speech[None, :] < prompt_len
mask_prompt = mask_prompt.expand(len(features), T_speech)
mask_speech_total = mask_pad_speech | mask_prompt
labels_speech = shifted_speech.clone()
labels_speech[mask_speech_total] = IGNORE_ID
num_valid_labels_speech = (labels_speech != IGNORE_ID).sum().item()
if padded_text_tokens.size(0) > 0 and num_valid_labels_speech == 0:
logger.warning(f"COLLATOR_WARNUNG: Batch (Size {padded_text_tokens.size(0)}) enthÀlt KEINE validen Sprach-Labels! "
f"Speech Lens (vor Padding): {speech_token_lens.tolist()}, Prompt Len (t3_config): {prompt_len}. ")
return {"text_tokens": padded_text_tokens, "text_token_lens": text_token_lens,
"speech_tokens": padded_speech_tokens, "speech_token_lens": speech_token_lens,
"t3_cond_speaker_emb": t3_cond_speaker_emb,
"t3_cond_prompt_speech_tokens": t3_cond_prompt_speech_tokens,
"t3_cond_emotion_adv": t3_cond_emotion_adv,
"labels_text": labels_text, "labels_speech": labels_speech}
# --- Model wrapper for the Hugging Face Trainer ---
class T3ForFineTuning(torch.nn.Module):
def __init__(self, t3_model: T3, chatterbox_t3_config: T3Config):
super().__init__()
self.t3 = t3_model
self.chatterbox_t3_config = chatterbox_t3_config # Sollte die Ziel-Config sein
class HFCompatibleConfig(PretrainedConfig): # Minimal config for HF Trainer compatibility
model_type = "chatterbox_t3_finetune"
def __init__(self, **kwargs):
self.llama_config_name = kwargs.pop("llama_config_name", None)
self.text_tokens_dict_size = kwargs.pop("text_tokens_dict_size", None)
self.speech_tokens_dict_size = kwargs.pop("speech_tokens_dict_size", None)
self.max_text_tokens = kwargs.pop("max_text_tokens", 256)
self.max_speech_tokens = kwargs.pop("max_speech_tokens", 800)
self.speech_cond_prompt_len = kwargs.pop("speech_cond_prompt_len", 100)
self.start_text_token = kwargs.pop("start_text_token", 0)
self.stop_text_token = kwargs.pop("stop_text_token", 1)
self.start_speech_token = kwargs.pop("start_speech_token", 0)
self.stop_speech_token = kwargs.pop("stop_speech_token", 1)
# Further fields from T3Config that might be relevant for the HF integration
self.llama_hidden_size = kwargs.pop("llama_hidden_size", 4096) # Beispiel
super().__init__(**kwargs)
self.config = HFCompatibleConfig(**chatterbox_t3_config.__dict__)
def forward(self, text_tokens, text_token_lens, speech_tokens, speech_token_lens,
t3_cond_speaker_emb, t3_cond_prompt_speech_tokens, t3_cond_emotion_adv,
labels_text = None, labels_speech=None):
target_device = text_tokens.device # Alle Inputs sollten auf dem gleichen Device sein
current_t3_cond = T3Cond(
speaker_emb=t3_cond_speaker_emb.to(target_device), # Sicherstellen, dass Konditionierungen auf dem richtigen GerÀt sind
cond_prompt_speech_tokens=t3_cond_prompt_speech_tokens.to(target_device),
cond_prompt_speech_emb=None, # Wird im Training typischerweise nicht ĂŒbergeben, sondern intern erzeugt, falls nötig
emotion_adv=t3_cond_emotion_adv.to(target_device) # Shape [B, 1, 1]
).to(device=target_device) # Sicherstellen, dass das T3Cond Objekt auf dem richtigen GerÀt ist
loss_text, loss_speech, speech_logits = self.t3.loss(
t3_cond=current_t3_cond, text_tokens=text_tokens, text_token_lens=text_token_lens,
speech_tokens=speech_tokens, speech_token_lens=speech_token_lens,
labels_text=labels_text, labels_speech=labels_speech
)
if loss_text is None or loss_speech is None: # Sollte nicht passieren, wenn Labels da sind
zero_loss_device = speech_logits.device if speech_logits is not None else target_device
zero_loss = torch.tensor(0.0, device=zero_loss_device, requires_grad=self.training)
if loss_text is None: loss_text = zero_loss.clone()
if loss_speech is None: loss_speech = zero_loss.clone()
total_loss = loss_text + loss_speech
# The Trainer expects a tuple whose first element is the loss.
# Further elements (e.g. logits) can be returned for metric computation etc.
return (total_loss, speech_logits) if speech_logits is not None else (total_loss,)
# --- Functions for sample generation ---
def find_checkpoint_folders_for_generation(base_dir: str) -> list[Path]:
base_path = Path(base_dir).resolve()
if not base_path.is_dir():
logger.error(f"Basis-Checkpoint-Verzeichnis fĂŒr Sample-Generierung nicht gefunden: {base_path}")
return []
checkpoint_folders = [d for d in base_path.iterdir() if d.is_dir() and d.name.startswith("checkpoint-")]
checkpoint_folders.append(base_path) # FĂŒr das "beste" Modell im Basisordner
def sort_key(p: Path):
if p == base_path: return -1 # Ensure base_path (best model) is processed, perhaps first or last based on need
try: return int(p.name.split('-')[-1])
except (ValueError, IndexError): return float('inf') # Put malformed names at the end
return sorted(list(set(checkpoint_folders)), key=sort_key) # Use set to remove duplicates if base_path was also a checkpoint
def set_model_components_to_eval_mode(model_instance: ChatterboxTTS):
if model_instance is None: return
logger.info(" Setze Modellkomponenten in den Eval-Modus...")
if hasattr(model_instance, 't3') and model_instance.t3 is not None: model_instance.t3.eval()
if hasattr(model_instance, 's3gen') and model_instance.s3gen is not None: model_instance.s3gen.eval()
if hasattr(model_instance, 've') and model_instance.ve is not None: model_instance.ve.eval()
def get_display_path(path_obj: Path, base_path_for_display: Path) -> str:
try:
common = Path(os.path.commonpath([str(path_obj.resolve()), str(base_path_for_display.resolve())]))
if common == base_path_for_display.resolve():
return str(path_obj.relative_to(base_path_for_display))
return path_obj.name
except ValueError:
return path_obj.name
def generate_and_save_samples_from_checkpoints(
output_dir_str: str,
data_args_original: DataArguments,
training_args_custom: CustomTrainingArguments,
base_model_path_for_static_components_str: str
):
if not training_args_custom.generate_samples_on_finish:
logger.info("Ăberspringe Sample-Generierung gemÀà Konfiguration.")
return
main_training_output_dir = Path(output_dir_str).resolve()
paths_to_generate_from: List[Path]
if training_args_custom.generate_from_all_checkpoints:
logger.info("Starte erweiterte Sample-Generierung fĂŒr verschiedene Checkpoints...")
paths_to_generate_from = find_checkpoint_folders_for_generation(str(main_training_output_dir))
else:
paths_to_generate_from = [main_training_output_dir]
logger.info(f"Generiere Samples nur vom finalen/besten Modell in: {main_training_output_dir}")
if not paths_to_generate_from:
logger.warning(f"Keine Checkpoint-Pfade in {main_training_output_dir} gefunden zum Testen.")
return
overall_sample_output_dir = main_training_output_dir / "generated_samples_all_checkpoints_audio"
overall_sample_output_dir.mkdir(parents=True, exist_ok=True)
script_dir_parent_for_display = Path(__file__).parent.resolve().parent
logger.info(f"Gefundene Checkpoint-Pfade zum Testen (relativ zu {script_dir_parent_for_display}):")
for p in paths_to_generate_from:
logger.info(f" - {get_display_path(p, script_dir_parent_for_display)}")
generation_device = training_args_custom.device
selected_conditioning_audio = None
if data_args_original.original_dataset_dir:
orig_dir = Path(data_args_original.original_dataset_dir)
if not orig_dir.is_absolute():
script_execution_dir = Path.cwd()
orig_dir = (script_execution_dir / orig_dir).resolve()
if orig_dir.is_dir():
try:
audio_files = list(orig_dir.rglob("*.wav")) + list(orig_dir.rglob("*.mp3")) + list(orig_dir.rglob("*.flac"))
if audio_files:
selected_conditioning_audio = str(np.random.choice(audio_files).resolve())
logger.info(f"ZufÀllige Konditionierungs-Audiodatei ausgewÀhlt: {selected_conditioning_audio}")
else: logger.warning(f"Keine .wav/.mp3/.flac Dateien in {orig_dir} fĂŒr Konditionierung gefunden.")
except Exception as e: logger.warning(f"Fehler beim Suchen nach Konditionierungs-Audio in {orig_dir}: {e}")
if not selected_conditioning_audio and training_args_custom.sample_texts:
logger.warning("Kein Konditionierungs-Audio gefunden. Versuche mit Default-Konditionierung des Modells (falls vorhanden).")
texts_to_generate = training_args_custom.sample_texts
if not texts_to_generate: texts_to_generate = ["Hallo, dies ist eine Standard-Sprachausgabe."]
num_to_gen_per_ckpt = min(training_args_custom.num_samples_to_generate, len(texts_to_generate))
for ckpt_path_for_t3 in paths_to_generate_from:
ckpt_name_for_folder = ckpt_path_for_t3.name
if ckpt_path_for_t3 == main_training_output_dir and not ckpt_path_for_t3.name.startswith("checkpoint-"):
ckpt_name_for_folder = "final_best_model"
logger.info(f"\n--- Teste Checkpoint: {ckpt_name_for_folder} (T3-Gewichte von: {get_display_path(ckpt_path_for_t3, script_dir_parent_for_display)}) ---")
current_model_instance = None
try:
current_model_instance = ChatterboxTTS.from_local(
ckpt_dir_for_model_weights=str(ckpt_path_for_t3),
base_model_dir_for_static_components=base_model_path_for_static_components_str,
device=generation_device
)
set_model_components_to_eval_mode(current_model_instance)
logger.info(f" Modell erfolgreich fĂŒr Checkpoint '{ckpt_name_for_folder}' geladen.")
if selected_conditioning_audio:
try:
logger.info(f" Bereite Konditionierung mit Audio-Prompt vor: {Path(selected_conditioning_audio).name}")
current_model_instance.prepare_conditionals(audio_prompt_path=selected_conditioning_audio)
except TypeError as te:
logger.error(f" TypeError bei Konditionierung fĂŒr {ckpt_name_for_folder}: {te}. PrĂŒfen Sie Signatur von prepare_conditionals.", exc_info=True)
except Exception as e_cond:
logger.error(f" Anderer Fehler bei Konditionierung fĂŒr {ckpt_name_for_folder}: {e_cond}. Versuche mit Default.", exc_info=True)
elif hasattr(current_model_instance, 'conds') and current_model_instance.conds is not None and current_model_instance.conds.is_prepared:
logger.info(" Verwende Default/geladene Konditionierung aus dem Modell.")
else:
logger.warning(f" Kein Audio-Prompt & keine Default-Konditionierung fĂŒr {ckpt_name_for_folder}. Es wird versucht, ohne explizite Konditionierung zu generieren oder die prepare_conditionals-Methode des Modells wird Default-Verhalten haben.")
ckpt_specific_sample_dir = overall_sample_output_dir / ckpt_name_for_folder
ckpt_specific_sample_dir.mkdir(parents=True, exist_ok=True)
for i in range(num_to_gen_per_ckpt):
text_to_synth = texts_to_generate[i]
logger.info(f" Generiere Audio fĂŒr Text: \"{text_to_synth}\"")
try:
wav_out_tensor, sr_out = current_model_instance.generate(text_to_synth)
wav_out_numpy = wav_out_tensor.squeeze().cpu().numpy()
filename_safe_text = "".join(c if c.isalnum() else "_" for c in text_to_synth[:30]).strip("_")
output_path = ckpt_specific_sample_dir / f"sample_{i+1}_{filename_safe_text}.wav"
sf.write(str(output_path), wav_out_numpy, sr_out)
logger.info(f" Audio gespeichert unter: {output_path.relative_to(Path.cwd())}")
except ValueError as ve:
logger.error(f" ValueError bei Generierung fĂŒr '{text_to_synth}' mit {ckpt_name_for_folder}: {ve}.", exc_info=True)
except Exception as e_gen:
logger.error(f" Fehler bei Generierung fĂŒr '{text_to_synth}' mit {ckpt_name_for_folder}: {e_gen}", exc_info=True)
except FileNotFoundError as e_load_fnf:
logger.error(f"FileNotFoundError beim Laden von Checkpoint {ckpt_name_for_folder}: {e_load_fnf}.")
except Exception as e_load:
logger.error(f"Allgemeiner Fehler beim Laden/Verarbeiten von {ckpt_name_for_folder}: {e_load}", exc_info=True)
finally:
if current_model_instance is not None: del current_model_instance
if isinstance(generation_device, torch.device) and generation_device.type == "cuda":
torch.cuda.empty_cache()
elif isinstance(generation_device, str) and "cuda" in generation_device:
torch.cuda.empty_cache()
logger.info(f" Verarbeitung fĂŒr Checkpoint {ckpt_name_for_folder} abgeschlossen.")
logger.info(f"\nAlle Sample-Generierungen abgeschlossen. Samples in: {overall_sample_output_dir.relative_to(Path.cwd())}")
# --- Main function main() ---
def main():
if os.name == 'nt':
try:
import torch.multiprocessing as mp
mp.set_sharing_strategy('file_system')
logger.info("Set multiprocessing sharing strategy to 'file_system' for Windows.")
except Exception as e:
logger.warning(f"Could not set multiprocessing sharing strategy to 'file_system': {e}.")
parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
handlers=[logging.StreamHandler()] # Sicherstellen, dass Logs auf stdout/stderr gehen
)
# Also set the root logger level, in case other libraries log below it
logging.getLogger().setLevel(logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN)
datasets.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
set_seed(training_args.seed)
# Important note on evaluation and early stopping:
if training_args.early_stopping_patience is not None and training_args.early_stopping_patience > 0:
if training_args.evaluation_strategy == "no" or \
(training_args.evaluation_strategy == "steps" and training_args.eval_steps >= training_args.max_steps and training_args.num_train_epochs > 0) or \
(training_args.evaluation_strategy == "steps" and training_args.eval_steps > 100000) : # Heuristik fĂŒr sehr hohe eval_steps
logger.warning(
f"EarlyStoppingCallback is enabled (patience={training_args.early_stopping_patience}), "
f"but evaluation_strategy is '{training_args.evaluation_strategy}' "
f"with eval_steps={training_args.eval_steps}. "
"Early stopping requires regular evaluations to function correctly. "
"Consider setting evaluation_strategy to 'steps' and eval_steps to a reasonable "
"value (e.g., equal to save_steps or logging_steps)."
)
base_model_path_for_static_and_initial_t3_str: Optional[str] = None
original_model_dir_for_copy_logic: Optional[Path] = None
script_dir = Path(__file__).parent.resolve()
if model_args.local_model_dir:
lm_path = Path(model_args.local_model_dir)
if not lm_path.is_absolute(): lm_path = (script_dir / lm_path).resolve()
logger.info(f"Verwende local_model_dir als Basis: {lm_path}")
if not lm_path.is_dir():
logger.error(f"local_model_dir {lm_path} nicht gefunden! Abbruch.")
return
base_model_path_for_static_and_initial_t3_str = str(lm_path)
original_model_dir_for_copy_logic = lm_path
else:
repo_to_download = model_args.model_name_or_path or REPO_ID
logger.info(f"Verwende Modell von Hugging Face Hub ({repo_to_download}) als Basis.")
download_base_parent_dir = script_dir.parent / "model_hub_cache"
repo_name_for_dir = repo_to_download.replace("/", "_").replace("\\", "_")
snapshot_dir_path = download_base_parent_dir / repo_name_for_dir
snapshot_dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Stelle sicher, dass das Originalmodell in {snapshot_dir_path} vorhanden ist.")
from huggingface_hub import snapshot_download
try:
actual_snapshot_dir = snapshot_download(
repo_id=repo_to_download, cache_dir=model_args.cache_dir,
local_dir=str(snapshot_dir_path), local_dir_use_symlinks=False, # local_dir muss str sein
allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json"],
)
base_model_path_for_static_and_initial_t3_str = actual_snapshot_dir
original_model_dir_for_copy_logic = Path(actual_snapshot_dir)
logger.info(f"Originalmodell-Komponenten sind in {actual_snapshot_dir} verfĂŒgbar.")
except Exception as e:
logger.error(f"Fehler beim Herunterladen des Originalmodells von {repo_to_download}: {e}", exc_info=True)
return
if not base_model_path_for_static_and_initial_t3_str:
logger.error("Konnte keinen Pfad fĂŒr Basismodellkomponenten bestimmen. Abbruch.")
return
target_t3_config = T3Config()
logger.info(f"T3Config Standardwerte: max_text={target_t3_config.max_text_tokens}, max_speech={target_t3_config.max_speech_tokens}")
target_t3_config.max_text_tokens = data_args.max_text_len
target_t3_config.max_speech_tokens = data_args.max_speech_len
logger.info(f"Ziel-T3Config fĂŒr Training und Initialisierung: "
f"max_text_tokens={target_t3_config.max_text_tokens}, "
f"max_speech_tokens={target_t3_config.max_speech_tokens}, "
f"start_text_token={target_t3_config.start_text_token}, "
f"speech_cond_prompt_len={target_t3_config.speech_cond_prompt_len}")
logger.info(f"Lade initiales T3-Modell und passe Gewichte an Ziel-Config an (Quelle: {base_model_path_for_static_and_initial_t3_str})")
try:
temp_full_model_for_t3_init = ChatterboxTTS.from_local(
ckpt_dir_for_model_weights=base_model_path_for_static_and_initial_t3_str,
base_model_dir_for_static_components=base_model_path_for_static_and_initial_t3_str,
device="cpu",
target_t3_config_for_model=target_t3_config
)
t3_model_to_finetune = temp_full_model_for_t3_init.t3
t3_config_for_trainer = t3_model_to_finetune.hp
if t3_config_for_trainer.max_text_tokens != data_args.max_text_len or \
t3_config_for_trainer.max_speech_tokens != data_args.max_speech_len:
logger.error(
f"FATAL: Mismatch in T3 Konfiguration nach dem Laden! "
f"Modell max_text: {t3_config_for_trainer.max_text_tokens} vs DataArgs: {data_args.max_text_len}. "
f"Modell max_speech: {t3_config_for_trainer.max_speech_tokens} vs DataArgs: {data_args.max_speech_len}."
)
return
del temp_full_model_for_t3_init
logger.info("Initiales T3-Modell und Konfiguration erfolgreich geladen und an Ziel-LĂ€ngen angepasst.")
except TypeError as te:
logger.error(f"TypeError beim initialen Laden von ChatterboxTTS.from_local: {te}.", exc_info=True)
return
except Exception as e_load_init:
logger.error(f"Allgemeiner Fehler beim initialen Laden des T3-Modells: {e_load_init}", exc_info=True)
return
for param in t3_model_to_finetune.parameters(): param.requires_grad = True
if model_args.freeze_voice_encoder: # Diese sollten von ChatterboxTTS selbst gehandhabt werden, wenn T3 trainiert wird
logger.warning("freeze_voice_encoder/freeze_s3gen sind in diesem Skript nicht direkt implementiert, "
"da nur das T3-Modul trainiert wird. Die ChatterboxTTS-Klasse sollte das Einfrieren "
"der anderen Komponenten bei Bedarf intern handhaben.")
preprocessed_data_path = Path(data_args.preprocessed_data_dir)
if not preprocessed_data_path.is_absolute():
preprocessed_data_path = (script_dir / preprocessed_data_path).resolve()
logger.info(f"Lade vorverarbeitete Daten von: {preprocessed_data_path}")
if not preprocessed_data_path.is_dir():
raise FileNotFoundError(f"Vorverarbeitetes Datenverzeichnis nicht gefunden: {preprocessed_data_path}")
all_pt_files = sorted(list(preprocessed_data_path.glob("*.pt")))
if not all_pt_files:
raise FileNotFoundError(f"Keine .pt Dateien in {preprocessed_data_path} gefunden")
rng = np.random.default_rng(training_args.seed)
rng.shuffle(all_pt_files)
train_pt_files: List[Path]
eval_pt_files: Optional[List[Path]] = None
if training_args.do_eval and data_args.eval_split_size > 0 and len(all_pt_files) > 1 :
num_eval_samples = int(len(all_pt_files) * data_args.eval_split_size)
if num_eval_samples == 0 and len(all_pt_files) > 1 : num_eval_samples = 1
if num_eval_samples >= len(all_pt_files): num_eval_samples = max(0, len(all_pt_files) -1)
if num_eval_samples > 0 and (len(all_pt_files) - num_eval_samples) > 0 :
train_pt_files = all_pt_files[:-num_eval_samples]
eval_pt_files = all_pt_files[-num_eval_samples:]
logger.info(f"Dateien aufgeteilt: {len(train_pt_files)} Training, {len(eval_pt_files)} Evaluation.")
else:
train_pt_files = all_pt_files
logger.warning("Split nicht möglich oder resultiert in 0 Trainings-/Evaluationssamples. Verwende alle fĂŒr Training.")
else:
train_pt_files = all_pt_files
logger.info(f"Verwende alle {len(train_pt_files)} Dateien fĂŒr Training (kein Eval-Split).")
train_dataset = PreprocessedSpeechDataset(train_pt_files) if train_pt_files else None
eval_dataset = PreprocessedSpeechDataset(eval_pt_files) if eval_pt_files and training_args.do_eval else None
if not train_dataset:
logger.error("Keine Trainingsdaten. Abbruch.")
return
data_collator = SpeechDataCollator(t3_config_for_trainer)
hf_trainable_model = T3ForFineTuning(t3_model_to_finetune, t3_config_for_trainer)
callbacks_list = []
if training_args.early_stopping_patience is not None and training_args.early_stopping_patience > 0:
callbacks_list.append(EarlyStoppingCallback(early_stopping_patience=training_args.early_stopping_patience))
# Add the DetailedMemoryClearingCallback
if training_args.do_train or training_args.do_eval: # Callback ist nĂŒtzlich fĂŒr beides
callbacks_list.append(DetailedMemoryClearingCallback())
trainer_instance = Trainer(
model=hf_trainable_model, args=training_args,
train_dataset=train_dataset, eval_dataset=eval_dataset,
data_collator=data_collator, callbacks=callbacks_list if callbacks_list else None,
)
if training_args.do_train:
logger.info("*** Training des T3-Modells ***")
resume_from_ckt = training_args.resume_from_checkpoint
if isinstance(resume_from_ckt, str):
resume_path = Path(resume_from_ckt)
if not resume_path.is_absolute(): resume_path = (script_dir / resume_path).resolve()
if not resume_path.exists():
logger.warning(f"resume_from_checkpoint Pfad {resume_path} existiert nicht. Starte neu.")
resume_from_ckt = None
else: resume_from_ckt = str(resume_path)
train_result = trainer_instance.train(resume_from_checkpoint=resume_from_ckt)
# Manually save the bare T3 model, since Trainer.save_model() saves the wrapper
trainer_instance.save_model() # Speichert den HF-Wrapper (T3ForFineTuning)
logger.info(f"HF Trainer Modell (Wrapper) gespeichert in: {training_args.output_dir}")
logger.info("Speichere reine T3-Gewichte (t3_cfg.safetensors)...")
model_to_save_t3 = trainer_instance.model.t3 if hasattr(trainer_instance.model, 't3') else \
(trainer_instance.model.module.t3 if hasattr(trainer_instance.model, 'module') and hasattr(trainer_instance.model.module, 't3') else None)
if model_to_save_t3:
finetuned_t3_state_dict = model_to_save_t3.state_dict()
output_t3_safetensor_path = Path(training_args.output_dir).resolve() / "t3_cfg.safetensors"
from safetensors.torch import save_file as save_safetensors_file
save_safetensors_file(finetuned_t3_state_dict, str(output_t3_safetensor_path))
logger.info(f"Reine T3-Gewichte gespeichert: {output_t3_safetensor_path}")
# Also save the T3Config of the trained model
output_t3_config_path = Path(training_args.output_dir).resolve() / "t3_model_config.json"
with open(output_t3_config_path, 'w') as f:
json.dump(model_to_save_t3.hp.__dict__, f, indent=4) # Speichere T3Config (hp)
logger.info(f"T3-Modellkonfiguration (aus t3.hp) gespeichert: {output_t3_config_path}")
else: logger.error("Konnte T3-Submodul nicht extrahieren zum separaten Speichern!")
if original_model_dir_for_copy_logic and original_model_dir_for_copy_logic.is_dir():
logger.info(f"Kopiere statische Komponenten von {original_model_dir_for_copy_logic} nach {training_args.output_dir}")
final_output_dir = Path(training_args.output_dir).resolve()
files_to_copy = ["ve.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]
for f_name in files_to_copy:
src_path = original_model_dir_for_copy_logic / f_name
dest_path = final_output_dir / f_name
if src_path.exists():
try:
shutil.copy2(src_path, dest_path)
logger.info(f" '{f_name}' kopiert nach {dest_path}")
except Exception as e_copy: logger.warning(f" Fehler beim Kopieren von {src_path} nach {dest_path}: {e_copy}")
elif f_name != "conds.pt": # conds.pt ist optional
logger.warning(f" Kritische Quelldatei {src_path} nicht gefunden zum Kopieren.")
else:
logger.info(f" Optionale Quelldatei {src_path} (conds.pt) nicht gefunden, wird nicht kopiert.")
logger.info(f"Modellkomponenten strukturiert in {final_output_dir}")
else: logger.warning(f"Kein gĂŒltiger Pfad fĂŒr original_model_dir_for_copy_logic ({original_model_dir_for_copy_logic}). Statische Komponenten nicht kopiert.")
metrics = train_result.metrics
trainer_instance.log_metrics("train", metrics)
trainer_instance.save_metrics("train", metrics)
trainer_instance.save_state()
if training_args.do_eval and eval_dataset:
logger.info("*** Evaluiere T3-Modell ***")
metrics = trainer_instance.evaluate()
trainer_instance.log_metrics("eval", metrics)
trainer_instance.save_metrics("eval", metrics)
if training_args.generate_samples_on_finish:
can_generate_samples = False
final_output_dir_path = Path(training_args.output_dir).resolve()
# Check whether the model was trained or a checkpoint exists that can be loaded from,
# and whether the files needed for sample generation are present.
model_files_exist = (final_output_dir_path / "t3_cfg.safetensors").exists() or \
(final_output_dir_path / "model.safetensors").exists() # HF Trainer speichert model.safetensors
if (training_args.do_train or training_args.resume_from_checkpoint or training_args.do_eval) and \
final_output_dir_path.is_dir() and model_files_exist:
can_generate_samples = True
logger.info("Bedingungen fĂŒr Sample-Generierung nach Training/Evaluation/Resume erfĂŒllt.")
else:
logger.warning(f"Kein valides Modell/Output-Verzeichnis ({final_output_dir_path}, Files exist: {model_files_exist}) fĂŒr Sample-Generierung gefunden oder kein Training/Eval/Resume durchgefĂŒhrt.")
if can_generate_samples:
logger.info("Starte Sample-Generierung...")
# For sample generation we use the training output folder,
# since all components (incl. the copied static ones) should be located there.
path_for_statics_for_generation = str(final_output_dir_path)
# If we only evaluated without training and without resuming from a *training* checkpoint,
# static components could be missing unless they were explicitly copied there.
# But the logic above should copy the static parts into final_output_dir_path.
if not (final_output_dir_path / "ve.safetensors").exists():
logger.warning(f"ve.safetensors nicht in {final_output_dir_path} gefunden. Sample-Generierung könnte fehlschlagen oder auf Basismodell zurĂŒckgreifen.")
# Fallback in case the copy logic did not cover everything or do_train was False
if base_model_path_for_static_and_initial_t3_str and \
Path(base_model_path_for_static_and_initial_t3_str, "ve.safetensors").exists():
path_for_statics_for_generation = base_model_path_for_static_and_initial_t3_str
logger.info(f" Verwende statische Komponenten von: {path_for_statics_for_generation} fĂŒr Sample Generierung")
try:
generate_and_save_samples_from_checkpoints(
output_dir_str=str(final_output_dir_path), # Wo die Checkpoints des aktuellen Laufs liegen
data_args_original=data_args,
training_args_custom=training_args,
base_model_path_for_static_components_str=path_for_statics_for_generation # Wo VE, S3Gen etc. zu finden sind
)
except ValueError as ve_path:
logger.error(f"Fehler bei der Pfad-Logik in generate_and_save_samples_from_checkpoints: {ve_path}", exc_info=True)
overall_sample_dir_name = "generated_samples_all_checkpoints_audio" # Aus der Funktion bekannt
logger.info(f" Absoluter Pfad fĂŒr generierte Samples wĂ€re: {final_output_dir_path / overall_sample_dir_name}")
except Exception as e_gen_samples:
logger.error(f"Genereller Fehler wÀhrend der Sample-Generierung: {e_gen_samples}", exc_info=True)
else:
logger.info("Ăberspringe Sample-Generierung: Bedingungen nicht erfĂŒllt.")
logger.info("Finetuning-Skript beendet.")
if __name__ == "__main__":
main()
How to start preprocessing:
python preprocess_data.py --model_load_path "ResembleAI/chatterbox" --is_hub_model --dataset_dir ./audio_data --output_dir ./processed_tts_data_AUGMENTED_WITH_PARAMS_2048_4096 --max_text_len 2048 --max_speech_len 4096 --silence_padding_ms 500 --num_workers 14 --augment_data --aug_speeds 0.9 --aug_volumes_db -2 2
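As a quick sanity check after preprocessing (just a minimal sketch, not part of the scripts; the directory below is the --output_dir from the command above), you can open a single .pt file and confirm it contains the keys that PreprocessedSpeechDataset expects:
import torch
from pathlib import Path

sample_file = next(Path("./processed_tts_data_AUGMENTED_WITH_PARAMS_2048_4096").glob("*.pt"))
data = torch.load(sample_file, weights_only=None)
required = ["text_tokens", "text_token_lens", "speech_tokens", "speech_token_lens",
            "t3_cond_speaker_emb", "t3_cond_prompt_speech_tokens", "t3_cond_emotion_adv"]
for key in required:
    value = data[key]
    print(key, tuple(value.shape) if torch.is_tensor(value) else value)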
How to start training on an RTX 5090 (uses ca. 29 GB of the 32 GB VRAM):
python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5
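With per_device_train_batch_size 1 and gradient_accumulation_steps 32 this gives an effective batch size of 32 per optimizer step. Since --report_to "tensorboard" is set, training can be watched live; the Trainer writes its event files under the output directory, so something like this should work:
tensorboard --logdir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD"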
How I save an intermediate model to test it (note the --max_steps value; it is just slightly above the last checkpoint):
python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5 --resume_from_checkpoint "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD/checkpoint-1000" --max_steps 1010
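To actually listen to such an intermediate checkpoint, the modified ChatterboxTTS.from_local shown further below can load the T3 weights straight from the checkpoint folder while taking the static components (ve, s3gen, tokenizer) from the base model directory. A rough sketch (paths and the reference wav are placeholders; adjust them to your setup):
import torch
import soundfile as sf
from chatterbox.tts import ChatterboxTTS
from chatterbox.models.t3.modules.t3_config import T3Config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Match the lengths used for fine-tuning so the enlarged positional embeddings are kept
cfg = T3Config()
cfg.max_text_tokens = 2048
cfg.max_speech_tokens = 4096

model = ChatterboxTTS.from_local(
    ckpt_dir_for_model_weights="../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD/checkpoint-1000",
    base_model_dir_for_static_components="../model_hub_cache/ResembleAI_chatterbox",
    device=DEVICE,
    target_t3_config_for_model=cfg,
)
wav, sr = model.generate(
    "Wie klingt meine Stimme nach dem Training?",
    audio_prompt_path="./audio_data/reference.wav",  # placeholder conditioning audio
)
sf.write("checkpoint_1000_sample.wav", wav.squeeze().cpu().numpy(), sr)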
How to continue (from the last checkpoint):
python finetune_t3_preprocessed.py --output_dir "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD" --local_model_dir "../model_hub_cache/ResembleAI_chatterbox" --preprocessed_data_dir "X:\KI\anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\processed_tts_data_5VARIANTS_2048_4096" --original_dataset_dir "./audio_data" --eval_split_size 0.05 --num_train_epochs 20 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 32 --learning_rate 5e-6 --warmup_steps 200 --logging_strategy "steps" --logging_steps 50 --evaluation_strategy "steps" --eval_steps 200 --save_strategy "steps" --save_steps 200 --save_total_limit 15 --fp16 True --report_to "tensorboard" --dataloader_num_workers 4 --dataloader_pin_memory True --do_train --do_eval --max_text_len 2048 --max_speech_len 4096 --save_safetensors True --load_best_model_at_end True --generate_samples_on_finish True --early_stopping_patience 5 --resume_from_checkpoint "../checkpoints/chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD/checkpoint-1000"
I also edited tts.py:
# tts.py
from dataclasses import dataclass, field
from pathlib import Path
import logging
from typing import Optional, Tuple, Dict, Any, Union
import inspect # FĂŒr Debugging
import librosa
import numpy as np
import torch
import perth
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from .models.t3 import T3
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config # Importiere T3Config
logger = logging.getLogger(__name__)
REPO_ID = "ResembleAI/chatterbox" # Standard-Repo, kann ĂŒberschrieben werden
def punc_norm(text: str) -> str:
if not text:
logger.warning("punc_norm erhielt leeren Text. Standardtext wird zurĂŒckgegeben.")
return "You need to add some text for me to talk."
if text[0].islower():
text = text[0].upper() + text[1:]
text = " ".join(text.split())
punc_to_replace = [
("...", ", "), ("âŠ", ", "), (":", ","), (" - ", ", "), (";", ", "),
("â", "-"), ("â", "-"), (" ,", ","), ("â", "\""), ("â", "\""),
("â", "'"), ("â", "'"),
]
for old, new in punc_to_replace:
text = text.replace(old, new)
text = text.rstrip(" ")
if not any(text.endswith(p) for p in {".", "!", "?", "-", ","}):
text += "."
return text
@dataclass
class Conditionals:
t3: T3Cond
gen: Dict[str, Any]
def to(self, device: Union[torch.device, str]) -> 'Conditionals':
target_device = device if isinstance(device, torch.device) else torch.device(device)
# Use T3Cond's own .to() method
self.t3 = self.t3.to(device=target_device)
for k, v in self.gen.items():
if torch.is_tensor(v):
self.gen[k] = v.to(target_device)
return self
def save(self, fpath: Path):
# T3Cond has its own .save(), but that stores __dict__.
# We also store __dict__ here for consistency with T3Cond.load(**kwargs)
arg_dict = dict(t3=self.t3.__dict__, gen=self.gen)
torch.save(arg_dict, fpath)
logger.info(f"Conditionals gespeichert unter: {fpath}")
@classmethod
def load(cls, fpath: Path, map_location: str = "cpu") -> 'Conditionals':
map_location_device = torch.device(map_location)
loaded_data = torch.load(fpath, map_location=map_location_device, weights_only=None)
logger.info(f"Lade Conditionals von: {fpath} auf {map_location_device}")
t3_cond_attributes_dict = loaded_data['t3']
if not isinstance(t3_cond_attributes_dict, dict):
raise TypeError(f"Unerwarteter Typ fĂŒr 't3' Daten in Conditionals: {type(t3_cond_attributes_dict)}. Erwartet dict.")
processed_t3_cond_attrs = {}
for key, value in t3_cond_attributes_dict.items():
if isinstance(value, torch.Tensor):
processed_t3_cond_attrs[key] = value.to(map_location_device)
else:
processed_t3_cond_attrs[key] = value
try:
t3_instance = T3Cond(**processed_t3_cond_attrs)
except TypeError as e:
logger.error(f"Fehler beim Instanziieren von T3Cond mit **kwargs: {e}")
logger.error(f"VerfĂŒgbare SchlĂŒssel im verarbeiteten T3Cond Daten-Dictionary: {list(processed_t3_cond_attrs.keys())}")
try:
sig = inspect.signature(T3Cond)
logger.error(f"Erwartete Parameter fĂŒr T3Cond: {list(sig.parameters.keys())}")
except Exception: pass
raise
gen_data_dict = loaded_data['gen']
processed_gen_data_dict = {
k: v.to(map_location_device) if isinstance(v, torch.Tensor) else v
for k, v in gen_data_dict.items()
}
return cls(t3_instance, processed_gen_data_dict)
class ChatterboxTTS:
ENC_COND_LEN = 6 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
def __init__(
self,
t3: T3,
s3gen: S3Gen,
ve: VoiceEncoder,
tokenizer: EnTokenizer,
device: torch.device,
conds: Optional[Conditionals] = None,
s3_tokenizer_sample_rate: int = S3_SR
):
self.sr = S3GEN_SR
self.t3 = t3 # Sollte bereits die korrekte T3Config (hp Attribut) haben
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
self.watermarker = perth.PerthImplicitWatermarker()
self.s3_tokenizer_sample_rate = s3_tokenizer_sample_rate
logger.info(f"ChatterboxTTS initialisiert auf GerÀt: {self.device}. T3 Config Max Text: {getattr(self.t3.hp, 'max_text_tokens', 'N/A')}, Max Speech: {getattr(self.t3.hp, 'max_speech_tokens', 'N/A')}")
def to(self, device: Union[torch.device, str]) -> 'ChatterboxTTS':
target_device = device if isinstance(device, torch.device) else torch.device(device)
if self.device == target_device: # Bereits auf dem ZielgerÀt
return self
logger.info(f"Verschiebe ChatterboxTTS von {self.device} auf GerÀt: {target_device}")
self.device = target_device
self.t3.to(target_device)
self.s3gen.to(target_device)
self.ve.to(target_device)
if self.conds:
self.conds.to(target_device)
return self
@classmethod
def from_local(cls,
ckpt_dir_for_model_weights: str,
base_model_dir_for_static_components: str,
device: str, # Kommt als String vom Aufrufer
target_t3_config_for_model: Optional[T3Config] = None # GeÀndert: target_t3_config_for_MODEL
) -> 'ChatterboxTTS':
logger.info(f"ChatterboxTTS.from_local:")
logger.info(f" Lade T3-spezifische Gewichte von: {ckpt_dir_for_model_weights}")
logger.info(f" Lade statische Komponenten von: {base_model_dir_for_static_components}")
if target_t3_config_for_model:
logger.info(f" Verwende ĂŒbergebene target_t3_config fĂŒr T3-Initialisierung: "
f"max_text={target_t3_config_for_model.max_text_tokens}, "
f"max_speech={target_t3_config_for_model.max_speech_tokens}")
else:
logger.info(" Keine target_t3_config ĂŒbergeben, T3 wird mit Default-Config initialisiert.")
target_device_obj = torch.device(device)
if target_device_obj.type == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA angefordert, aber nicht verfĂŒgbar. Fallback auf CPU.")
target_device_obj = torch.device("cpu")
elif target_device_obj.type == "mps" and not torch.backends.mps.is_available(): # Korrigiert zu target_device_obj.type
logger.warning("MPS angefordert, aber nicht verfĂŒgbar. Fallback auf CPU.")
target_device_obj = torch.device("cpu")
path_for_t3_specific_weights = Path(ckpt_dir_for_model_weights).resolve()
path_for_static_files = Path(base_model_dir_for_static_components).resolve()
# Validate the paths
if not path_for_static_files.is_dir():
raise FileNotFoundError(f"Basis-Modellpfad fĂŒr statische Komponenten nicht gefunden: {path_for_static_files}")
if not path_for_t3_specific_weights.is_dir() and not path_for_t3_specific_weights.is_file() : # Kann auch eine Datei sein
# Check whether it is a file, in case ckpt_dir_for_model_weights points directly at t3_cfg.safetensors
if not (path_for_t3_specific_weights.is_file() and \
any(str(path_for_t3_specific_weights).endswith(x) for x in [".safetensors", ".bin"])):
raise FileNotFoundError(f"Pfad fĂŒr T3-Modellgewichte weder Verzeichnis noch gĂŒltige Datei: {path_for_t3_specific_weights}")
load_device_cpu_str = "cpu" # FĂŒr safetensors device-Argument
# Load static components (VE, S3Gen, tokenizer) - always from base_model_dir...
ve = VoiceEncoder()
ve.load_state_dict(load_file(path_for_static_files / "ve.safetensors", device=load_device_cpu_str))
ve.eval()
s3gen = S3Gen()
s3gen.load_state_dict(load_file(path_for_static_files / "s3gen.safetensors", device=load_device_cpu_str), strict=False)
s3gen.eval()
tokenizer = EnTokenizer(str(path_for_static_files / "tokenizer.json"))
# Load conditionals - always from base_model_dir...
conds = None
conds_path_static = path_for_static_files / "conds.pt"
if conds_path_static.exists():
try:
conds = Conditionals.load(conds_path_static, map_location=load_device_cpu_str)
except Exception as e_conds:
logger.warning(f"Konnte conds.pt nicht laden von {conds_path_static}: {e_conds}", exc_info=True)
else:
logger.info(f"Keine conds.pt in {path_for_static_files} (Basis fĂŒr statische Komponenten) gefunden.")
# --- T3 initialization and weight loading ---
# Initialize T3 with the target configuration if given, otherwise with the default from T3Config.py.
# This instance (final_t3_instance) is the one that will be returned.
if target_t3_config_for_model:
final_t3_instance = T3(hp=target_t3_config_for_model)
else:
final_t3_instance = T3() # Verwendet Default-Config aus T3Config.py
# Find the T3 weights file
t3_weights_file_to_load: Optional[Path] = None
if path_for_t3_specific_weights.is_file() and any(str(path_for_t3_specific_weights).endswith(x) for x in [".safetensors", ".bin"]):
t3_weights_file_to_load = path_for_t3_specific_weights
logger.info(f" T3-Gewichtsdatei direkt als Pfad angegeben: {t3_weights_file_to_load.name}")
elif path_for_t3_specific_weights.is_dir():
potential_ckpt_names = ["t3_cfg.safetensors", "model.safetensors", "pytorch_model.bin"]
for name in potential_ckpt_names:
if (path_for_t3_specific_weights / name).exists():
t3_weights_file_to_load = path_for_t3_specific_weights / name
logger.info(f" T3-Gewichtsdatei in '{path_for_t3_specific_weights.name}' gefunden: {t3_weights_file_to_load.name}")
break
# Fallback when ckpt_dir == base_dir and a t3_cfg.safetensors is present there (typical for the base model)
if not t3_weights_file_to_load and path_for_t3_specific_weights == path_for_static_files:
if (path_for_static_files / "t3_cfg.safetensors").exists():
t3_weights_file_to_load = path_for_static_files / "t3_cfg.safetensors"
logger.info(f" Keine spezifische T3-Datei in '{path_for_t3_specific_weights.name}', verwende 't3_cfg.safetensors' aus Basis: {path_for_static_files.name}")
if not t3_weights_file_to_load:
raise FileNotFoundError(f"Keine passende T3-Gewichtsdatei in {path_for_t3_specific_weights} oder als Fallback in {path_for_static_files} gefunden.")
logger.info(f" Lade T3-Gewichte von: {t3_weights_file_to_load} auf CPU.")
if str(t3_weights_file_to_load).endswith(".safetensors"):
t3_state_raw = load_file(t3_weights_file_to_load, device=load_device_cpu_str)
elif str(t3_weights_file_to_load).endswith(".bin"):
t3_state_raw = torch.load(t3_weights_file_to_load, map_location=load_device_cpu_str)
else:
raise ValueError(f"Unbekanntes Dateiformat fĂŒr T3-Gewichte: {t3_weights_file_to_load}")
# PrĂ€fix-Bereinigung fĂŒr HF Trainer Modelle
cleaned_t3_state = {}
if t3_weights_file_to_load.name in ["model.safetensors", "pytorch_model.bin"]: # Namen, die typisch fĂŒr HF Trainer sind
if any(key.startswith("t3.") for key in t3_state_raw.keys()):
logger.info(" Entferne 't3.' PrĂ€fix aus T3 State Dict SchlĂŒsseln.")
for k, v_tensor in t3_state_raw.items():
if k.startswith("t3."): cleaned_t3_state[k[3:]] = v_tensor
if not cleaned_t3_state: # Fallback, falls keine SchlĂŒssel passten
logger.warning(" PrĂ€fix-Entfernung ergab leeres State Dict fĂŒr T3, verwende rohes State Dict.")
cleaned_t3_state = t3_state_raw
else: # Keine "t3." PrÀfixe, aber es ist eine HF Trainer Datei
logger.info(f" Verwende State Dict aus '{t3_weights_file_to_load.name}' direkt fĂŒr T3 (kein 't3.' PrĂ€fix gefunden).")
cleaned_t3_state = t3_state_raw
else: # z.B. t3_cfg.safetensors, sollte bereits korrekte SchlĂŒssel haben
logger.info(f" Verwende T3 State Dict aus '{t3_weights_file_to_load.name}' direkt.")
cleaned_t3_state = t3_state_raw
if not cleaned_t3_state:
raise ValueError(f"T3 State Dict ist leer nach Lade- und Bereinigungsversuchen fĂŒr {t3_weights_file_to_load}.")
# --- ADJUSTMENT FOR POSITIONAL EMBEDDINGS ---
# `final_t3_instance` was initialized with the target config (or the default),
# so its embedding layers already have the target shapes.
pos_emb_map = {
"text_pos_emb.emb.weight": final_t3_instance.text_pos_emb.emb if hasattr(final_t3_instance, "text_pos_emb") else None,
"speech_pos_emb.emb.weight": final_t3_instance.speech_pos_emb.emb if hasattr(final_t3_instance, "speech_pos_emb") else None
}
for emb_key_in_checkpoint, target_embedding_layer in pos_emb_map.items():
if target_embedding_layer is None:
logger.debug(f"Ziel-Embedding-Layer fĂŒr '{emb_key_in_checkpoint}' nicht in final_t3_instance gefunden. Ăberspringe Anpassung.")
continue
if emb_key_in_checkpoint in cleaned_t3_state:
ckpt_emb_weights = cleaned_t3_state[emb_key_in_checkpoint]
target_emb_weights_tensor = target_embedding_layer.weight # nn.Embedding.weight
ckpt_shape = ckpt_emb_weights.shape
target_shape = target_emb_weights_tensor.shape
if ckpt_shape[0] != target_shape[0]: # Unterschied in der SequenzlÀngen-Dimension
logger.warning(
f"Passe GröĂe von Positional Embedding '{emb_key_in_checkpoint}' an: "
f"Checkpoint-Form {ckpt_shape} -> Modell-Form {target_shape}."
)
# Create a copy of the target model's initialized weights;
# they are then partially overwritten with the checkpoint weights.
new_weights_for_model = target_emb_weights_tensor.data.clone()
len_to_copy = min(ckpt_shape[0], target_shape[0])
# The embedding dimension (dim=1) should match.
if ckpt_shape[1] != target_shape[1]:
logger.error(f"FATAL: Embedding-Dimension Mismatch fĂŒr '{emb_key_in_checkpoint}'. "
f"Checkpoint: {ckpt_shape[1]}, Modell: {target_shape[1]}. Laden wird wahrscheinlich fehlschlagen.")
# One could also abort here, or only copy when the dims are equal.
# For now we continue and rely on load_state_dict to catch the error if the dims do not match.
dim_to_copy = min(ckpt_shape[1], target_shape[1]) # Nur zur Sicherheit
new_weights_for_model[:len_to_copy, :dim_to_copy] = ckpt_emb_weights[:len_to_copy, :dim_to_copy]
cleaned_t3_state[emb_key_in_checkpoint] = new_weights_for_model
logger.info(f" '{emb_key_in_checkpoint}' angepasst auf LĂ€nge {len_to_copy}.")
# If the shapes already match, no action is needed.
else:
logger.warning(f"Positions-Embedding '{emb_key_in_checkpoint}' nicht im Checkpoint-State-Dict gefunden. "
"Wird im Modell initialisiert bleiben.")
# --- END OF ADJUSTMENT ---
missing_keys, unexpected_keys = final_t3_instance.load_state_dict(cleaned_t3_state, strict=False)
if missing_keys: logger.warning(f" Fehlende SchlĂŒssel beim Laden des T3 State Dict: {missing_keys}")
if unexpected_keys: logger.warning(f" Unerwartete SchlĂŒssel beim Laden des T3 State Dict: {unexpected_keys}")
final_t3_instance.eval() # T3 ist jetzt auf CPU mit (hoffentlich) korrekten Gewichten
# Create the ChatterboxTTS instance with all components on CPU
model_instance = cls(final_t3_instance, s3gen, ve, tokenizer, torch.device(load_device_cpu_str), conds=conds)
# Finally move the ENTIRE model to the target device
return model_instance.to(target_device_obj)
@classmethod
def from_pretrained(cls, device: str, target_t3_config_for_model: Optional[T3Config] = None) -> 'ChatterboxTTS': # target_config hier auch
logger.info(f"Lade vortrainiertes Modell '{REPO_ID}' vom Hugging Face Hub.")
try:
snapshot_dir = snapshot_download(
repo_id=REPO_ID,
allow_patterns=["*.safetensors", "*.json", "*.pt", "config.json"] # config.json fĂŒr T3Config
)
logger.info(f"Modell-Snapshot heruntergeladen nach: {snapshot_dir}")
except Exception as e:
logger.error(f"Fehler beim Herunterladen des Modells vom Hub {REPO_ID}: {e}", exc_info=True)
raise
# If target_t3_config_for_model was not passed explicitly,
# we try to load the config from the downloaded snapshot,
# in case the user has not already done so.
# If target_t3_config_for_model is None here, T3 is initialized with the default config.
# That is fine if the user adjusts the lengths afterwards in the finetune script.
# The pos-emb adjustment in from_local will then work with the default T3Config.
# Ideally, the caller of from_pretrained passes the target config.
if target_t3_config_for_model is None:
logger.info("Keine explizite target_t3_config fĂŒr from_pretrained ĂŒbergeben. "
"T3 wird mit Default-Config initialisiert, es sei denn, der Basis-Snapshot enthÀlt eine ladbare Config "
"(was from_local nicht automatisch tut, es sei denn, target_t3_config wird ĂŒbergeben).")
return cls.from_local(
ckpt_dir_for_model_weights=snapshot_dir, # T3-Gewichte kommen auch von hier
base_model_dir_for_static_components=snapshot_dir, # Statische Teile auch
device=device,
target_t3_config_for_model=target_t3_config_for_model # Weitergeben
)
def prepare_conditionals(self, wav_fpath: str, exaggeration: float = 0.5) -> Optional[Conditionals]:
logger.info(f"Bereite Konditionierung mit Audio-Prompt vor: {wav_fpath}, Exaggeration: {exaggeration}")
try:
s3gen_ref_wav, sr_orig = librosa.load(wav_fpath, sr=None)
if sr_orig != self.sr: # self.sr ist S3GEN_SR
s3gen_ref_wav = librosa.resample(s3gen_ref_wav, orig_sr=sr_orig, target_sr=self.sr)
s3gen_ref_wav_tensor = torch.from_numpy(s3gen_ref_wav)
if s3gen_ref_wav_tensor.shape[0] < self.DEC_COND_LEN:
logger.warning(f"Audio-Prompt {wav_fpath} kĂŒrzer ({s3gen_ref_wav_tensor.shape[0]}) als DEC_COND_LEN ({self.DEC_COND_LEN}). Padding.")
s3gen_ref_wav_cut_tensor = F.pad(s3gen_ref_wav_tensor, (0, self.DEC_COND_LEN - s3gen_ref_wav_tensor.shape[0]))
else:
s3gen_ref_wav_cut_tensor = s3gen_ref_wav_tensor[:self.DEC_COND_LEN]
s3gen_ref_wav_cut_np = s3gen_ref_wav_cut_tensor.numpy()
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav_cut_np, self.sr, device=self.device)
if self.s3_tokenizer_sample_rate != self.sr:
ref_16k_wav_np = librosa.resample(s3gen_ref_wav, orig_sr=self.sr, target_sr=self.s3_tokenizer_sample_rate)
else:
ref_16k_wav_np = s3gen_ref_wav
t3_cond_prompt_tokens = None
# Use the config of the current T3 instance (self.t3.hp)
t3_config_current = self.t3.hp # this should be the correctly initialized T3Config
plen = getattr(t3_config_current, 'speech_cond_prompt_len', 0)
if plen > 0:
s3_tokzr = self.s3gen.tokenizer
ref_16k_wav_tensor_prompt = torch.from_numpy(ref_16k_wav_np)
if ref_16k_wav_tensor_prompt.shape[0] < self.ENC_COND_LEN:
prompt_audio_segment_tensor = F.pad(ref_16k_wav_tensor_prompt, (0, self.ENC_COND_LEN - ref_16k_wav_tensor_prompt.shape[0]))
else:
prompt_audio_segment_tensor = ref_16k_wav_tensor_prompt[:self.ENC_COND_LEN]
prompt_audio_segment_np = prompt_audio_segment_tensor.numpy()
tokens_from_16k_batch, _ = s3_tokzr.forward([prompt_audio_segment_np], max_len=plen)
if tokens_from_16k_batch is not None and tokens_from_16k_batch.numel() > 0:
t3_cond_prompt_tokens = torch.atleast_2d(tokens_from_16k_batch).to(self.device)
else:
t3_cond_prompt_tokens = torch.zeros((1, plen), dtype=torch.long, device=self.device)
else:
logger.info("T3 speech_cond_prompt_len ist 0. Keine Sprachprompt-Tokens fĂŒr T3 Konditionierung.")
ve_embed_np = self.ve.embeds_from_wavs([ref_16k_wav_np], sample_rate=self.s3_tokenizer_sample_rate)
ve_embed = torch.from_numpy(ve_embed_np).mean(axis=0, keepdim=True).to(self.device)
emotion_tensor = exaggeration * torch.ones(1, 1, 1, device=self.device)
t3_cond_instance = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
emotion_adv=emotion_tensor
# clap_emb and cond_prompt_speech_emb stay None (default from T3Cond)
).to(device=self.device) # Sicherstellen, dass die T3Cond Instanz auf dem GerÀt ist
processed_s3gen_ref_dict = {
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in s3gen_ref_dict.items()
}
new_conds = Conditionals(t3_cond_instance, processed_s3gen_ref_dict)
self.conds = new_conds # Wichtig: self.conds aktualisieren
logger.info(f"Konditionierung erfolgreich vorbereitet und in Instanz gespeichert.")
return new_conds
except Exception as e:
logger.error(f"Fehler in prepare_conditionals fĂŒr {wav_fpath}: {e}", exc_info=True)
self.conds = None
return None
def generate(
self,
text: str,
audio_prompt_path: Optional[str] = None,
exaggeration: Optional[float] = None,
cfg_weight: float = 0.5,
temperature: float = 0.8,
) -> Tuple[torch.Tensor, int]:
if exaggeration is None: exaggeration = 0.5
current_conds_to_use = self.conds
if audio_prompt_path:
prepared_conds = self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
if prepared_conds is None:
raise RuntimeError(f"Konnte Konditionierung nicht mit {audio_prompt_path} vorbereiten.")
current_conds_to_use = prepared_conds
elif not current_conds_to_use:
raise ValueError("Keine Konditionierung vorhanden. Bitte audio_prompt_path angeben oder Default laden.")
elif current_conds_to_use.t3.emotion_adv is None or exaggeration != current_conds_to_use.t3.emotion_adv.item():
logger.info(f"Aktualisiere Exaggeration in bestehenden Konditionierungen auf: {exaggeration}")
_cond_t3_old: T3Cond = current_conds_to_use.t3
new_emotion_adv = exaggeration * torch.ones(1, 1, 1, device=self.device)
current_conds_to_use.t3 = T3Cond( # Create a new instance, since T3Cond is a dataclass
speaker_emb=_cond_t3_old.speaker_emb,
clap_emb=_cond_t3_old.clap_emb,
cond_prompt_speech_tokens=_cond_t3_old.cond_prompt_speech_tokens,
cond_prompt_speech_emb=_cond_t3_old.cond_prompt_speech_emb,
emotion_adv=new_emotion_adv # Updated
).to(device=self.device)
current_conds_to_use = current_conds_to_use.to(self.device) # Make sure it is on the device
text_normalized = punc_norm(text)
text_tokens_single = self.tokenizer.text_to_tokens(text_normalized).to(self.device)
text_tokens_batch = text_tokens_single.unsqueeze(0)
if cfg_weight > 0.0:
text_tokens_batch = torch.cat([text_tokens_batch, text_tokens_batch], dim=0)
# Use the config of the current T3 instance
t3_config_current = self.t3.hp
sot = getattr(t3_config_current, 'start_text_token', 0)
eot = getattr(t3_config_current, 'stop_text_token', 1)
if not (hasattr(t3_config_current, 'start_text_token')): # Simpler check
logger.warning("Using default SOT/EOT for text because they were not found in T3Config.")
text_tokens_padded = F.pad(text_tokens_batch, (1, 0), value=sot)
text_tokens_padded = F.pad(text_tokens_padded, (0, 1), value=eot)
with torch.inference_mode():
t3_cond_for_inference: T3Cond = current_conds_to_use.t3
current_batch_size = text_tokens_padded.shape[0]
# Adjust batch sizes of the conditioning tensors
for attr_name in ["speaker_emb", "clap_emb", "cond_prompt_speech_tokens", "cond_prompt_speech_emb", "emotion_adv"]:
attr_val = getattr(t3_cond_for_inference, attr_name)
if isinstance(attr_val, torch.Tensor) and attr_val.shape[0] == 1 and current_batch_size > 1:
# If the batch dimension is 1 but current_batch_size > 1 (e.g. for CFG), repeat.
num_repeats = [current_batch_size] + [1] * (attr_val.dim() - 1)
setattr(t3_cond_for_inference, attr_name, attr_val.repeat(*num_repeats))
elif isinstance(attr_val, torch.Tensor) and attr_val.shape[0] != current_batch_size:
# Ideally this case should not occur if the prompt was already prepared for CFG.
logger.warning(f"Unexpected batch-size mismatch for {attr_name} in T3Cond. "
f"Tensor shape: {attr_val.shape}, expected batch: {current_batch_size}")
plen_from_hp = getattr(t3_config_current, 'speech_cond_prompt_len', 0)
if t3_cond_for_inference.cond_prompt_speech_tokens is None and plen_from_hp > 0:
dummy_prompt = torch.zeros((current_batch_size, plen_from_hp), dtype=torch.long, device=self.device)
t3_cond_for_inference.cond_prompt_speech_tokens = dummy_prompt
max_speech_tokens = getattr(t3_config_current, 'max_speech_tokens', 1000)
max_new_tokens = max_speech_tokens - (plen_from_hp if t3_cond_for_inference.cond_prompt_speech_tokens is not None else 0)
max_new_tokens = max(50, max_new_tokens)
speech_tokens_batched = self.t3.inference(
t3_cond=t3_cond_for_inference,
text_tokens=text_tokens_padded,
max_new_tokens=max_new_tokens,
temperature=temperature,
cfg_weight=cfg_weight,
)
speech_tokens_single = speech_tokens_batched[0] # First (conditional) part for CFG
speech_tokens_clean = drop_invalid_tokens(speech_tokens_single.squeeze()).to(self.device)
if speech_tokens_clean.numel() == 0:
wav_np = np.zeros(int(self.sr * 0.5), dtype=np.float32) # Short silence
else:
s3gen_conds = current_conds_to_use.gen
# Make sure s3gen_conds are on the right device
for k_gen, v_gen in s3gen_conds.items():
if isinstance(v_gen, torch.Tensor): s3gen_conds[k_gen] = v_gen.to(self.device)
wav_batch, _ = self.s3gen.inference(
speech_tokens=speech_tokens_clean.unsqueeze(0),
ref_dict=s3gen_conds,
)
wav_np = wav_batch.squeeze(0).detach().cpu().numpy()
watermarked_wav_np = self.watermarker.apply_watermark(wav_np, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav_np).unsqueeze(0), self.sr
and t3_config.py:
python
from ..llama_configs import LLAMA_CONFIGS
class T3Config:
start_text_token = 255
stop_text_token = 0
text_tokens_dict_size = 704
max_text_tokens = 2048
start_speech_token = 6561
stop_speech_token = 6562
speech_tokens_dict_size = 8194
max_speech_tokens = 4096
llama_config_name = "Llama_520M"
input_pos_emb = "learned"
speech_cond_prompt_len = 10 #was 150, allows shorter input samples
# For T3CondEnc
encoder_type = "voice_encoder"
speaker_embed_size = 256
use_perceiver_resampler = True
emotion_adv = True
@property
def n_channels(self):
return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
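For reference, here is a minimal usage sketch of the modified generate() method above. It assumes tts_model is an already-loaded instance of this class (loaded from your fine-tuned checkpoint directory, as discussed earlier in the thread); the file paths are placeholders.
python
# Minimal sketch, assuming tts_model is an already-loaded instance of the modified class;
# "reference_speaker.wav" and "output.wav" are placeholder paths.
import torchaudio

wav, sr = tts_model.generate(
    "Guten Morgen, wie geht es dir?",
    audio_prompt_path="reference_speaker.wav",  # optional: re-conditions on this speaker
    exaggeration=0.5,
    cfg_weight=0.5,
    temperature=0.8,
)
torchaudio.save("output.wav", wav, sr)  # wav is a (1, num_samples) float tensor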
My training is currently running. I don't know if my code helps you here, but I added a little bit of memory management, so it's more efficient. You may have to reduce the sample and token lengths from 2048/4096 to lower values by editing the function call parameters to get it running on 24 GB or 16 GB.
How much does batch size influence VRAM usage? I am thinking about just renting an H100 with 80 GB VRAM or something like that to train, if we have a dataset that is good enough.
I use around 28 GB of VRAM with a batch size of 1. I think with 80 GB you are ready to go with higher values, mate.
Great @havok2-htwo @rotatorotator! Can I use it for Arabic? Can you guide me on the data and steps?
Sure, I believe it should work if you have enough speech samples. I'm currently redesigning a lot of my code based on Gemini's suggestions. I used the training code as a base from this repository, but I found a few issues in the tts.py file and also while reading through the cond_enc.py. Gemini provided me with the correct values for this model and the proper tokens for the start and stop points. I hope this will help me fix the issues and run a proper fine-tuning session to fully leverage the potential of this base model. I also realized that I trained for 128 epochs, but that was way too much. I actually achieved the best results at epoch 5, after which the training suffered from overfitting. So I will address that as well. I've generated a sample where you can hear the result, but there are some issues at the end that still need improvement. (https://jmp.sh/s/71S206JyeoDuZa0YK2ve)
Thanks. Could you share a step-by-step guide to your finetuning code?
I hope so :D First I cloned this repo here, then I edited and added some code. You may have to install some additional modules via pip and so on; I think I installed two or three things. Then you need an ...anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\src\audio_data folder with wav and txt pairs, i.e. the audio file and the transcription of the speech. They need to share the same base name. I generated approx. 12k samples with the ElevenLabs API, filtered them by hand, and deleted unwanted audio/txt pairs.
Then you run preprocess_data.py as I described above. You may have to adjust some parameters to match your folders and so on. This tool creates a ton of .pt files. It also adds 300 ms (or whatever you set; I used 500 ms) of silence at the start and end of the audio files before converting them, and it creates variations: louder, quieter, slower, faster, whatever you enter. Pitch up/down and faster did not sound good, so I left them out.
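To make that description concrete, here is a rough sketch of the per-sample work such a preprocessing script does. The function and constant names are my own assumptions, not the actual preprocess_data.py; the real script additionally runs speech tokenization, the voice encoder, and the augmentation variants, while this only shows the load/resample/pad/save skeleton.
python
# Hypothetical sketch of one preprocessing step; names and the target sample rate
# are assumptions, not the actual preprocess_data.py implementation.
import torch
import torch.nn.functional as F
import torchaudio

TARGET_SR = 24000   # assumed model sample rate; adjust to whatever the model expects
SILENCE_MS = 500    # silence padded at the start and end, as described above

def process_sample(wav_path: str, txt_path: str, out_path: str) -> None:
    wav, sr = torchaudio.load(wav_path)              # (channels, num_samples)
    wav = wav.mean(dim=0, keepdim=True)              # force mono
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    pad = int(TARGET_SR * SILENCE_MS / 1000)
    wav = F.pad(wav, (pad, pad))                     # silence in front and at the end
    with open(txt_path, encoding="utf-8") as f:
        text = f.read().strip()
    # The real script would also compute speech tokens and the VE embedding here
    # and store those tensors; this sketch only saves audio and text.
    torch.save({"audio": wav, "sr": TARGET_SR, "text": text}, out_path)

process_sample("audio_data/sample_0001.wav", "audio_data/sample_0001.txt",
               "processed/sample_0001.pt")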
Your CPU creates these files once, which is much more efficient; otherwise, in every epoch your CPU would have to do this work again and again while your GPU idles a lot.
After that you can run the training, also as described above, by starting the finetune_t3_preprocessed.py file. In the end you will find t3_cfg.safetensors and the other files somewhere here: ...anaconda3\envs\chatterbox_train_env\MyGenesisTTSProject\checkpoints\chatterbox_GER_finetune_5VARIANTS_2048_4096_CMD. I copied them into my other project, where I use chatterbox_stream to access this model. It won't work out of the box, though; I also had to edit some files in the model. But maybe it's a good idea to just ask Gemini, it always helps me create or edit code. Just drop in the files and say what you need :D
Maybe this helps. I am not a pro, I just have access to all the stuff, from AI Pro tools to a 5090 and a tiny cluster of 4x3090s at work. Good luck!
@havok2-htwo thank you very much for your code. I created the new files
preprocess_data.py and finetune_t3_preprocessed.py,
and replaced tts.py and t3_config.py.
I successfully ran preprocess_data.py; I ran it with --metadata_file because I use metadata.csv.
Now when I run the other script, finetune_t3_preprocessed.py, I get:
File "C:\PythonApps\ChatterboxNew\chatterbox\src\finetune_t3_preprocessed.py", line 312, in forward
loss_text, loss_speech, speech_logits = self.t3.loss(
TypeError: T3.loss() got an unexpected keyword argument 'labels_text'
I also see that t3.py has no labels_text argument: https://github.com/resemble-ai/chatterbox/blob/eb90621fa748f341a5b768aed0c0c12fc561894b/src/chatterbox/models/t3/t3.py#L167
Can you please tell me how to fix this? Thank you very much.
Hey, this is mine:
python
# Copyright (c) 2025 Resemble AI
# MIT License
import logging
from typing import Union, Optional, List
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
from .modules.learned_pos_emb import LearnedPositionEmbeddings
from .modules.cond_enc import T3CondEnc, T3Cond
from .modules.t3_config import T3Config
from .llama_configs import LLAMA_CONFIGS
from .inference.t3_hf_backend import T3HuggingfaceBackend
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
logger = logging.getLogger(__name__)
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
B = text_tokens.size(0)
assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
class T3(nn.Module):
"""
Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
* tokenization, including start / stop tokens are always added externally to this class
* conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
* careful! this class assumes relative positional encoding -- with absolute PE, we would at
least want to reset the position to 0 when speech tokens begin, and optionally use a
different PE embedding space for speech.
"""
def __init__(self, hp=T3Config()):
super().__init__()
self.hp = hp
self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
self.tfmr = LlamaModel(self.cfg)
self.dim = self.cfg.hidden_size
self.deepspeed_patch_applied = False
# conditioning / embedding
self.cond_enc = T3CondEnc(hp)
self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
# custom position embedding
if hp.input_pos_emb == "learned":
max_text_seq_len = hp.max_text_tokens + 2
self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
max_mel_seq_len = hp.max_speech_tokens + 2 + 2
self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
# logit projection
self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
self.compiled = False
@property
def device(self):
return self.speech_head.weight.device
def prepare_conditioning(self, t3_cond: T3Cond):
"""
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
"""
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
return self.cond_enc(t3_cond) # (B, len_cond, dim)
def prepare_input_embeds(
self,
*,
t3_cond: T3Cond,
text_tokens: torch.LongTensor,
speech_tokens: torch.LongTensor,
cfg_weight: float = 0.0,
):
# prepare input embeddings (skip backbone transformer embeddings)
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
if cfg_weight > 0.0:
text_emb[1].zero_() # CFG uncond
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
if self.hp.input_pos_emb == "learned":
text_emb = text_emb + self.text_pos_emb(text_tokens)
speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
len_cond = cond_emb.size(1)
if cond_emb.size(0) != text_emb.size(0):
cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
# concat
embeds = torch.stack([
torch.cat((ce, te, se))
for ce, te, se in zip(cond_emb, text_emb, speech_emb)
]) # (B, length, dim)
return embeds, len_cond
def forward(
self,
*,
t3_cond: T3Cond,
text_tokens: torch.LongTensor,
text_token_lens: torch.LongTensor,
speech_tokens: torch.LongTensor,
speech_token_lens: torch.LongTensor,
training=False,
):
_ensure_BOT_EOT(text_tokens, self.hp)
# prepare custom input embeds
embeds, len_cond = self.prepare_input_embeds(
t3_cond=t3_cond,
text_tokens=text_tokens,
speech_tokens=speech_tokens,
)
# backbone transformer forward
tfmr_out = self.tfmr.forward(
input_ids=None,
# position_ids=position_ids, # TODO? ROPE should be fine?
inputs_embeds=embeds,
output_hidden_states=True,
return_dict=True,
use_cache=(not training),
)
hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim)
# post-processing: splice out text and speech parts of hidden states
len_text = text_tokens.size(1)
len_speech = speech_tokens.size(1)
B, _, dim = hidden_states.shape
device, dtype = hidden_states.device, hidden_states.dtype
text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
ttl, stl = text_token_lens, speech_token_lens
for i in range(B):
text_end = len_cond + ttl[i].item()
speech_start = len_cond + text_tokens.size(1)
speech_end = speech_start + stl[i].item()
text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
# logit projection
text_logits = self.text_head(text_latents)
speech_logits = self.speech_head(speech_latents)
return AttrDict(
text_logits=text_logits,
text_latents=text_latents,
speech_logits=speech_logits,
speech_latents=speech_latents,
hidden_states=hidden_states,
)
def loss_old(
self,
*,
t3_cond: T3Cond,
text_tokens: torch.LongTensor,
text_token_lens: torch.LongTensor,
speech_tokens: torch.LongTensor,
speech_token_lens: torch.LongTensor,
):
"training method"
len_text = text_tokens.size(1)
len_speech = speech_tokens.size(1)
assert len_text == text_token_lens.max()
assert len_speech == speech_token_lens.max()
out = self.forward(
t3_cond=t3_cond,
text_tokens=text_tokens,
text_token_lens=text_token_lens,
speech_tokens=speech_tokens,
speech_token_lens=speech_token_lens,
training=True,
) # (B, seq, vocab_size)
# Calc CCE losses
IGNORE_ID = -100
device = out.text_logits.device
mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text)
mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech)
masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID)
loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)
#loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
#loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)
return loss_text, loss_speech
def loss(
self,
*,
t3_cond: T3Cond,
text_tokens: torch.LongTensor, # (B, S_text_padded), includes BOS & EOS
text_token_lens: torch.LongTensor, # (B,), actual lengths including BOS & EOS
speech_tokens: torch.LongTensor, # (B, S_speech_padded), includes BOS & EOS
speech_token_lens: torch.LongTensor, # (B,), actual lengths including BOS & EOS
labels_text: torch.LongTensor, # (B, S_text_padded-1), already masked with -100
labels_speech: torch.LongTensor # (B, S_speech_padded-1), already masked with -100
):
"""
Compute text and speech cross-entropy using pre-masked labels from the collator.
Assumes:
- labels_text[t] corresponds to predicting text_tokens[:, 1:] with -100 where ignored
- labels_speech[t] corresponds to predicting speech_tokens[:, 1:] with -100 where ignored
"""
# 1) Run model to get logits
out = self.forward(
t3_cond=t3_cond,
text_tokens=text_tokens,
text_token_lens=text_token_lens,
speech_tokens=speech_tokens,
speech_token_lens=speech_token_lens,
training=True,
)
# out.text_logits: (B, S_text_padded, V_text)
# out.speech_logits: (B, S_speech_padded, V_speech)
device = out.text_logits.device
IGNORE_ID = -100
# --- Text Loss (use labels_text directly) ---
# Align logits: predict t1..EOS from inputs [BOS, t1..]
logits_for_text = out.text_logits[:, :-1, :].contiguous() # (B, S_text_padded-1, V_text)
# labels_text already has shape (B, S_text_padded-1) with -100 where masked
if logits_for_text.size(1) == 0:
loss_text = torch.tensor(0.0, device=device, requires_grad=self.training)
else:
loss_text = F.cross_entropy(
logits_for_text.transpose(1, 2), # (B, V_text, S_text_padded-1)
labels_text, # (B, S_text_padded-1), ignore_index=-100
ignore_index=IGNORE_ID
)
# --- Speech Loss (use labels_speech directly) ---
logits_for_speech = out.speech_logits[:, :-1, :].contiguous() # (B, S_speech_padded-1, V_speech)
# labels_speech already has shape (B, S_speech_padded-1) with -100 where masked
if logits_for_speech.size(1) == 0:
loss_speech = torch.tensor(0.0, device=device, requires_grad=self.training)
else:
loss_speech = F.cross_entropy(
logits_for_speech.transpose(1, 2), # (B, V_speech, S_speech_padded-1)
labels_speech, # (B, S_speech_padded-1), ignore_index=-100
ignore_index=IGNORE_ID
)
return loss_text, loss_speech, out.speech_logits
@torch.inference_mode()
def inference(
self,
*,
t3_cond: T3Cond,
text_tokens: Tensor,
initial_speech_tokens: Optional[Tensor]=None,
# misc conditioning
prepend_prompt_speech_tokens: Optional[Tensor]=None,
# HF generate args
num_return_sequences=1,
max_new_tokens=None,
stop_on_eos=True,
do_sample=True,
temperature=0.8,
top_p=0.8,
length_penalty=1.0,
repetition_penalty=2.0,
cfg_weight=0,
):
"""
Args:
text_tokens: a 1D (unbatched) or 2D (batched) tensor.
"""
# Validate / sanitize inputs
assert prepend_prompt_speech_tokens is None, "not implemented"
_ensure_BOT_EOT(text_tokens, self.hp)
text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
# Default initial speech to a single start-of-speech token
if initial_speech_tokens is None:
initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
# Prepare custom input embeds
embeds, len_cond = self.prepare_input_embeds(
t3_cond=t3_cond,
text_tokens=text_tokens,
speech_tokens=initial_speech_tokens,
cfg_weight=cfg_weight,
)
# In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
# Note the llama-specific logic. Other tfmr types can be added later.
self.compiled = False
# TODO? synchronize the expensive compile function
# with self.compile_lock:
if not self.compiled:
alignment_stream_analyzer = AlignmentStreamAnalyzer(
self.tfmr,
None,
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
alignment_layer_idx=9, # TODO: hparam or something?
eos_idx=self.hp.stop_speech_token,
)
patched_model = T3HuggingfaceBackend(
config=self.cfg,
llama=self.tfmr,
speech_enc=self.speech_emb,
speech_head=self.speech_head,
alignment_stream_analyzer=alignment_stream_analyzer,
)
self.patched_model = patched_model
self.compiled = True
# # Run normal generate method, which calls our custom extended methods
# return self.patched_model.generate(
# inputs=initial_speech_tokens,
# decoder_cond=embeds,
# bos_token_id=self.hp.start_speech_token,
# eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
# pad_token_id=self.hp.stop_speech_token,
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
# num_return_sequences=num_return_sequences,
# temperature=temperature,
# top_p=top_p,
# length_penalty=length_penalty,
# repetition_penalty=repetition_penalty,
# do_sample=do_sample,
# # cache_implementation=None if not self.compiled else "static",
# )
device = embeds.device
bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
# batch_size=2 for CFG
bos_embed = torch.cat([bos_embed, bos_embed])
# Combine condition and BOS token for the initial input if cfg_weight > 0
if cfg_weight > 0:
inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
else:
inputs_embeds = embeds
# Track generated token ids; start with the BOS token.
generated_ids = bos_token.clone()
predicted = [] # To store the predicted tokens
# Instantiate the logits processors.
top_p_warper = TopPLogitsWarper(top_p=top_p)
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
# ---- Initial Forward Pass (no kv_cache yet) ----
output = self.patched_model(
inputs_embeds=inputs_embeds,
past_key_values=None,
use_cache=True,
output_attentions=True,
output_hidden_states=True,
return_dict=True,
)
# Initialize kv_cache with the full context.
past = output.past_key_values
# ---- Generation Loop using kv_cache ----
for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
logits = output.logits[:, -1, :]
# CFG
if cfg_weight > 0.0:
logits_cond = logits[0:1]
logits_uncond = logits[1:2]
logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
logits = logits.squeeze(1)
# Apply temperature scaling.
if temperature != 1.0:
logits = logits / temperature
# Apply repetition penalty and top-p filtering.
logits = repetition_penalty_processor(generated_ids, logits)
logits = top_p_warper(None, logits)
# Convert logits to probabilities and sample the next token.
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
predicted.append(next_token)
generated_ids = torch.cat([generated_ids, next_token], dim=1)
# Check for EOS token.
if next_token.view(-1) == self.hp.stop_speech_token:
break
# Get embedding for the new token.
next_token_embed = self.speech_emb(next_token)
next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
# For CFG
if cfg_weight > 0.0:
next_token_embed = torch.cat([next_token_embed, next_token_embed])
# Forward pass with only the new token and the cached past.
output = self.patched_model(
inputs_embeds=next_token_embed,
past_key_values=past,
output_attentions=True,
output_hidden_states=True,
return_dict=True,
)
# Update the kv_cache.
past = output.past_key_values
# Concatenate all predicted tokens along the sequence dimension.
predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
return predicted_tokens
Hope this helps.
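One note on the new loss() signature: it expects the collator to hand over labels_text / labels_speech that are already shifted by one position and masked with -100. The collator in finetune_t3_preprocessed.py may do this differently, but a minimal sketch of the idea looks like this (the 255/0 example values are the start/stop text tokens from the t3_config.py above):
python
# Minimal sketch (not the actual collator) of building pre-masked labels for T3.loss().
import torch

IGNORE_ID = -100

def build_labels(tokens: torch.LongTensor, lens: torch.LongTensor) -> torch.LongTensor:
    """tokens: (B, S) padded, incl. BOS/EOS; lens: (B,) true lengths incl. BOS/EOS."""
    labels = tokens[:, 1:].clone()                        # target at step t is token t+1
    positions = torch.arange(labels.size(1))[None, :]     # (1, S-1)
    labels[positions >= (lens[:, None] - 1)] = IGNORE_ID  # mask everything past EOS
    return labels

# Example with the text token config above (BOS=255, EOS=0), one padded sample:
text_tokens = torch.tensor([[255, 5, 7, 0, 0]])
text_lens = torch.tensor([4])
print(build_labels(text_tokens, text_lens))  # tensor([[   5,    7,    0, -100]])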
My training has now been running for 72 h. It sounds very German now, but still has problems with the correct word order. I will continue the training and then share my weights.
Hello lpscr! You have analyzed the problem perfectly. The TypeError occurs because your new finetune_t3_preprocessed.py script is calling an updated loss function in the T3 model, but your local t3.py file still contains the old version of that function. The new finetuning script passes the arguments labels_text and labels_speech to make the process more efficient, but the original loss function doesn't recognize these arguments. The solution is exactly what havok2-htwo provided in their latest comment: you need to update your t3.py file as well.
How to fix the error:
- Find the correct file: navigate to the path you already identified, chatterbox/src/chatterbox/models/t3/t3.py.
- Copy the new code: go to havok2-htwo's last comment (the one that starts with # Copyright (c) 2025 Resemble AI).
- Replace the content: copy the entire code block from that comment and use it to completely replace the contents of your local t3.py file.
Why this fixes the error: if you look at the new code for t3.py, you will see that havok2-htwo has refactored the loss function. It now has the exact signature that your finetuning script expects:
In the new t3.py from havok2-htwo
python
def loss(
self,
*,
t3_cond: T3Cond,
text_tokens: torch.LongTensor,
text_token_lens: torch.LongTensor,
speech_tokens: torch.LongTensor,
speech_token_lens: torch.LongTensor,
labels_text: torch.LongTensor, # <--- Here is the expected argument!
labels_speech: torch.LongTensor # <--- And here is the second one!
):
# ... function logic
After you update your t3.py file with this new code, the T3.loss() function will correctly accept the labels_text and labels_speech arguments, and the TypeError will be resolved. Your training script should run without this error after making that change. Good luck. Dear Gemini
@havok2-htwo Thank you very much for checking this, very helpful!
You're welcome. My training runs well even at 36 °C room temperature :D Thank god for custom water cooling. I tried yesterday to merge it with Kartoffel's model and the result was pretty good. I will let the training run another day and then release the German fine-tuned model, as well as a model merged from Kartoffel's and mine.
@rotatorotator here I uploaded the trained model, weights and also a merged version with Kartoffelbox and mine: https://huggingface.co/havok2/Kartoffelbox-v0.1_0.65h2
have fun with it =)
Great @havok2-htwo, how can I do the same for Arabic? And is the merge of Kartoffelbox with Chatterbox valuable? I don't know why you mention this integration.
Hey, I described the workflow above and also shared the adjusted code. All you need is a lot of Arabic wav files with their txt pairs containing the transcriptions. Then you train on that.
My model was able to produce good German-sounding speech, but struggled with the right order of words. So my model was good at producing fake German but was not so good at German grammar. Kartoffel's model instead was good at German grammar, so I mixed it with mine and got a better model. I only had 12,000 German samples for training; Kartoffel used ~2,500,000 samples. So I think if you want to train Arabic, you have to generate a lot of data, with ElevenLabs for example. Or, if you have access to a big Arabic voice database, you can use a Whisper model to transcribe it all to txt files.
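For anyone curious what such a merge can look like: a simple linear blend of the two T3 state dicts is one way to do it. This is only a sketch under my own assumptions (the paths are placeholders, the 0.65 ratio is a guess based on the repo name, and both checkpoints must have identical keys and shapes); it is not necessarily how the released Kartoffelbox merge was produced.
python
# Sketch of a weighted state-dict merge of two t3_cfg.safetensors files; paths and
# ALPHA are assumptions, not the exact recipe used for the released merge.
from safetensors.torch import load_file, save_file

ALPHA = 0.65  # assumed share taken from model A

state_a = load_file("kartoffel/t3_cfg.safetensors")
state_b = load_file("my_finetune/t3_cfg.safetensors")

merged = {}
for key, tensor_a in state_a.items():
    tensor_b = state_b[key]                       # assumes identical keys and shapes
    merged[key] = ALPHA * tensor_a + (1.0 - ALPHA) * tensor_b

save_file(merged, "merged/t3_cfg.safetensors")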
@havok2-htwo Hi there,
I'm sorry to bother you; I've been having some trouble. I'm trying to train my model, but I keep running into an out-of-memory error. I have an RTX 4090 (24 GB VRAM). Could you please let me know which settings I should use to fit within the 24 GB limit?
Thank you very much for all your work and your time!
❓ Issue: Help Needed to Use finetune_t3_preprocessed.py Step by Step
Hi, I'm trying to fine-tune a model using your codebase and would appreciate step-by-step guidance. Here's what I'm doing and where I'm stuck:
🔧 My Setup and Steps So Far
- Python Version: Which version of Python do you recommend for this project? I'm currently using Python 3.11.
- Dataset Format: My data is formatted similarly to LJSpeech: a wavs/ directory containing audio files, and a metadata.csv file structured as id|transcription.
- Script Used: I'm trying to run the training using finetune_t3_preprocessed.py.
⚠️ Current Problem
When I execute the script, I get the following error:
/python3.11/site-packages/transformers/utils/import_utils.py", line 1780, in _get_module
raise RuntimeError(
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
operator torchvision::nms does not exist
❓ Questions
- What version of Python, PyTorch, and Transformers should I be using?
- Is there a specific preprocessing step I might be missing?
- How can I resolve the torchvision::nms does not exist error?
- Could you provide a step-by-step guide to fine-tune using finetune_t3_preprocessed.py on LJSpeech-style data?
Thanks in advance for your support @havok2-htwo !
Hey, to train the model with the parameters as I have written them there, you need at least 27 or 28 GB of graphics memory. You would have to reduce the maximum token length, i.e. set the 2048 to 1024, for example, both when preparing the material and during the training itself. But then the training samples have to be shorter: right now the limit should be around 18 seconds, so it would drop to about nine, if I'm interpreting it correctly. The values have to match during preparation and training. I'm already training with a batch size of one, and smaller is unfortunately not possible.
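Purely as an illustration of what "set the 2048 to 1024" means in practice; the actual parameter names in preprocess_data.py / finetune_t3_preprocessed.py may differ, and the key point is that the same reduced values must be used in both the preprocessing and the training run:
python
# Illustrative only; these constant names are assumptions, not the scripts' real flags.
MAX_TEXT_TOKENS = 1024    # halved from 2048
MAX_SPEECH_TOKENS = 2048  # halved from 4096 (this also halves the maximum sample duration)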
I prepared my data and ran preprocess_data.py, and it is taking a long time; I think it got stuck! @havok2-htwo I have 120 CPU cores and a V100.
2025-07-05 00:23:37,643 - INFO - __main__ - Found 5325 original audio-text pairs. Starting preprocessing...
2025-07-05 00:23:37,643 - INFO - __main__ - Using 100 worker processes. Each original file generates up to 4 samples.
2025-07-05 00:23:37,645 - INFO - __main__ - Total expected samples to attempt (incl. augmentations): 21300
Processing original files (incl. augmentations): 0%| | 0/5325 [00:00<?, ?it/s]
2025-07-05 00:24:00,512 - INFO - __main__ - Saving test audios for LJmine-1.wav (OriginalIndex 0) in processed_tts_data_AUGMENTED_WITH_PARAMS_2048_4096/test_audio_samples_debug
Another question about G2P: do I need to change anything for Arabic?
Hey, I am currently away from home, but as far as I can remember I have Python 3.10 and a Torch 2.8 nightly build; it was the only one I found that worked with my 5090 so far.
@cod3r0k Can you post one txt file as an example? How long are your wavs? Try only 1 worker; it sounds strange, but the original script uses only 1 under Windows… I installed something and then it worked with more than 1. Just to try. My 12k samples preprocessed in around 30 minutes on my 9800X3D with a worker count of 14, and I have 64 GB RAM. Maybe 100 cores ran out of memory? Maybe try without augmentation parameters?
I can give more help next week. Greetings