# EraX-Smile-Female-F5-V1.0 / f5tts_wrapper.py
import os
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional, Union, List, Tuple, Dict
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files
from pydub import AudioSegment, silence
from f5_tts.model import CFM
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
chunk_text,
load_vocoder,
transcribe,
initialize_asr_pipeline,
)
class F5TTSWrapper:
"""
A wrapper class for F5-TTS that preprocesses reference audio once
and allows for repeated TTS generation.
"""
def __init__(
self,
model_name: str = "F5TTS_v1_Base",
ckpt_path: Optional[str] = None,
vocab_file: Optional[str] = None,
vocoder_name: str = "vocos",
use_local_vocoder: bool = False,
vocoder_path: Optional[str] = None,
device: Optional[str] = None,
hf_cache_dir: Optional[str] = None,
target_sample_rate: int = 24000,
n_mel_channels: int = 100,
hop_length: int = 256,
win_length: int = 1024,
n_fft: int = 1024,
ode_method: str = "euler",
use_ema: bool = True,
):
"""
Initialize the F5-TTS wrapper with model configuration.
Args:
model_name: Name of the F5-TTS model variant (e.g., "F5TTS_v1_Base")
ckpt_path: Path to the model checkpoint file. If None, will use default path.
vocab_file: Path to the vocab file. If None, will use default.
vocoder_name: Name of the vocoder to use ("vocos" or "bigvgan")
use_local_vocoder: Whether to use a local vocoder or download from HF
vocoder_path: Path to the local vocoder. Only used if use_local_vocoder is True.
            device: Device to run the model on. If None, the best available device is selected automatically.
hf_cache_dir: Directory to cache HuggingFace models
target_sample_rate: Target sample rate for audio
n_mel_channels: Number of mel channels
hop_length: Hop length for the mel spectrogram
win_length: Window length for the mel spectrogram
n_fft: FFT size for the mel spectrogram
ode_method: ODE method for sampling ("euler" or "midpoint")
use_ema: Whether to use EMA weights from the checkpoint
"""
# Set device
if device is None:
self.device = (
"cuda" if torch.cuda.is_available()
else "xpu" if torch.xpu.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
else:
self.device = device
# Audio processing parameters
self.target_sample_rate = target_sample_rate
self.n_mel_channels = n_mel_channels
self.hop_length = hop_length
self.win_length = win_length
self.n_fft = n_fft
self.mel_spec_type = vocoder_name
# Sampling parameters
self.ode_method = ode_method
# Initialize ASR for transcription if needed
initialize_asr_pipeline(device=self.device)
        # Resolve the checkpoint path (download the default model from Hugging Face if none is given)
if ckpt_path is None:
repo_name = "F5-TTS"
ckpt_step = 1250000
ckpt_type = "safetensors"
# Adjust for previous models
if model_name == "F5TTS_Base":
if vocoder_name == "vocos":
ckpt_step = 1200000
elif vocoder_name == "bigvgan":
model_name = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model_name == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{model_name}/model_{ckpt_step}.{ckpt_type}"))
# Load model configuration
config_path = str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))
model_cfg = OmegaConf.load(config_path)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
# Load tokenizer
if vocab_file is None:
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
tokenizer_type = "custom"
self.vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer_type)
# Create model
self.model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=vocoder_name,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=self.vocab_char_map,
).to(self.device)
# Load checkpoint
dtype = torch.float32 if vocoder_name == "bigvgan" else None
self._load_checkpoint(self.model, ckpt_path, dtype=dtype, use_ema=use_ema)
# Load vocoder
if vocoder_path is None:
if vocoder_name == "vocos":
vocoder_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
vocoder_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
self.vocoder = load_vocoder(
vocoder_name=vocoder_name,
is_local=use_local_vocoder,
local_path=vocoder_path,
device=self.device,
hf_cache_dir=hf_cache_dir
)
# Storage for reference data
self.ref_audio_processed = None
self.ref_text = None
self.ref_audio_len = None
# Default inference parameters
self.target_rms = 0.1
self.cross_fade_duration = 0.15
self.nfe_step = 32
self.cfg_strength = 2.0
self.sway_sampling_coef = -1.0
self.speed = 1.0
self.fix_duration = None
def _load_checkpoint(self, model, ckpt_path, dtype=None, use_ema=True):
"""
Load model checkpoint with proper handling of different checkpoint formats.
Args:
model: The model to load weights into
ckpt_path: Path to the checkpoint file
dtype: Data type for model weights
use_ema: Whether to use EMA weights from the checkpoint
Returns:
Loaded model
"""
if dtype is None:
dtype = (
torch.float16
if "cuda" in self.device
and torch.cuda.get_device_properties(self.device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
model = model.to(dtype)
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path, device=self.device)
else:
checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=True)
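        # EMA checkpoints store weights under "ema_model."-prefixed keys; strip the prefix
        # (and the EMA bookkeeping entries "initted"/"step") so they load into the base model.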
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
# patch for backward compatibility
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["model_state_dict"]:
del checkpoint["model_state_dict"][key]
model.load_state_dict(checkpoint["model_state_dict"])
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model.load_state_dict(checkpoint["model_state_dict"])
del checkpoint
torch.cuda.empty_cache()
return model.to(self.device)
def preprocess_reference(self, ref_audio_path: str, ref_text: str = "", clip_short: bool = True):
"""
Preprocess the reference audio and text, storing them for later use.
Args:
ref_audio_path: Path to the reference audio file
ref_text: Text transcript of reference audio. If empty, will auto-transcribe.
clip_short: Whether to clip long audio to shorter segments
Returns:
Tuple of processed audio and text
"""
print("Converting audio...")
# Load audio file
aseg = AudioSegment.from_file(ref_audio_path)
if clip_short:
# 1. try to find long silence for clipping
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
print("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
print("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
print("Audio is over 12s, clipping short. (3)")
# Remove silence edges
aseg = self._remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
# Export to temporary file and load as tensor
import tempfile
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
aseg.export(tmp_file.name, format="wav")
processed_audio_path = tmp_file.name
# Transcribe if needed
if not ref_text.strip():
print("No reference text provided, transcribing reference audio...")
ref_text = transcribe(processed_audio_path)
else:
print("Using custom reference text...")
# Ensure ref_text ends with proper punctuation
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
if ref_text.endswith("."):
ref_text += " "
else:
ref_text += ". "
print("\nReference text:", ref_text)
# Load and process audio
audio, sr = torchaudio.load(processed_audio_path)
if audio.shape[0] > 1: # Convert stereo to mono
audio = torch.mean(audio, dim=0, keepdim=True)
# Normalize volume
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
# Resample if needed
if sr != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
audio = resampler(audio)
# Move to device
audio = audio.to(self.device)
# Store reference data
self.ref_audio_processed = audio
self.ref_text = ref_text
self.ref_audio_len = audio.shape[-1] // self.hop_length
# Remove temporary file
os.unlink(processed_audio_path)
return audio, ref_text
def _remove_silence_edges(self, audio, silence_threshold=-42):
"""
Remove silence from the start and end of audio.
Args:
audio: AudioSegment to process
silence_threshold: dB threshold to consider as silence
Returns:
Processed AudioSegment
"""
# Remove silence from the start
non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
audio = audio[non_silent_start_idx:]
# Remove silence from the end
non_silent_end_duration = audio.duration_seconds
for ms in reversed(audio):
if ms.dBFS > silence_threshold:
break
non_silent_end_duration -= 0.001
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
return trimmed_audio
def generate(
self,
text: str,
output_path: Optional[str] = None,
nfe_step: Optional[int] = None,
cfg_strength: Optional[float] = None,
sway_sampling_coef: Optional[float] = None,
speed: Optional[float] = None,
fix_duration: Optional[float] = None,
cross_fade_duration: Optional[float] = None,
return_numpy: bool = False,
return_spectrogram: bool = False,
) -> Union[str, Tuple[np.ndarray, int], Tuple[np.ndarray, int, np.ndarray]]:
"""
Generate speech for the given text using the stored reference audio.
Args:
text: Text to synthesize
output_path: Path to save the generated audio. If None, won't save.
nfe_step: Number of function evaluation steps
cfg_strength: Classifier-free guidance strength
sway_sampling_coef: Sway sampling coefficient
speed: Speed of generated audio
fix_duration: Fixed duration in seconds
cross_fade_duration: Duration of cross-fade between segments
return_numpy: If True, returns the audio as a numpy array
return_spectrogram: If True, also returns the spectrogram
        Returns:
            If output_path is provided and return_numpy is False: path to the output file.
            Otherwise: a tuple of (audio_array, sample_rate), or
            (audio_array, sample_rate, spectrogram) when return_spectrogram is True.
"""
if self.ref_audio_processed is None or self.ref_text is None:
raise ValueError("Reference audio not preprocessed. Call preprocess_reference() first.")
# Use default values if not specified
nfe_step = nfe_step if nfe_step is not None else self.nfe_step
cfg_strength = cfg_strength if cfg_strength is not None else self.cfg_strength
sway_sampling_coef = sway_sampling_coef if sway_sampling_coef is not None else self.sway_sampling_coef
speed = speed if speed is not None else self.speed
fix_duration = fix_duration if fix_duration is not None else self.fix_duration
cross_fade_duration = cross_fade_duration if cross_fade_duration is not None else self.cross_fade_duration
# Split the input text into batches
audio_len = self.ref_audio_processed.shape[-1] / self.target_sample_rate
max_chars = int(len(self.ref_text.encode("utf-8")) / audio_len * (22 - audio_len))
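        # Illustrative numbers: a 6 s reference whose transcript is 90 UTF-8 bytes yields
        # max_chars = int(90 / 6 * (22 - 6)) = 240, i.e. each batch targets roughly the amount
        # of text that fits in the remaining ~16 s of a 22 s generation budget.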
text_batches = chunk_text(text, max_chars=max_chars)
for i, text_batch in enumerate(text_batches):
print(f"Text batch {i}: {text_batch}")
print("\n")
# Generate audio for each batch
generated_waves = []
spectrograms = []
for text_batch in text_batches:
# Adjust speed for very short texts
local_speed = speed
if len(text_batch.encode("utf-8")) < 10:
local_speed = 0.3
# Prepare the text
text_list = [self.ref_text + text_batch]
final_text_list = convert_char_to_pinyin(text_list)
# Calculate duration
if fix_duration is not None:
duration = int(fix_duration * self.target_sample_rate / self.hop_length)
else:
# Calculate duration based on text length
ref_text_len = len(self.ref_text.encode("utf-8"))
gen_text_len = len(text_batch.encode("utf-8"))
duration = self.ref_audio_len + int(self.ref_audio_len / ref_text_len * gen_text_len / local_speed)
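                # Illustrative numbers: a 6 s reference at 24 kHz with hop_length 256 gives
                # ref_audio_len = 144000 // 256 = 562 frames; a 45-byte batch against a 90-byte
                # reference transcript at speed 1.0 yields duration = 562 + int(562 / 90 * 45) = 843
                # frames, i.e. roughly 9 s of total (reference + generated) mel length.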
# Generate audio
with torch.inference_mode():
generated, _ = self.model.sample(
cond=self.ref_audio_processed,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
# Process the generated mel spectrogram
generated = generated.to(torch.float32)
generated = generated[:, self.ref_audio_len:, :]
generated = generated.permute(0, 2, 1)
# Convert to audio
if self.mel_spec_type == "vocos":
generated_wave = self.vocoder.decode(generated)
elif self.mel_spec_type == "bigvgan":
generated_wave = self.vocoder(generated)
# Normalize volume if needed
rms = torch.sqrt(torch.mean(torch.square(self.ref_audio_processed)))
if rms < self.target_rms:
generated_wave = generated_wave * rms / self.target_rms
# Convert to numpy and append to list
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
# Store spectrogram if needed
if return_spectrogram or output_path is not None:
spectrograms.append(generated.squeeze().cpu().numpy())
# Combine all segments
if generated_waves:
if cross_fade_duration <= 0:
# Simply concatenate
final_wave = np.concatenate(generated_waves)
else:
# Cross-fade between segments
final_wave = generated_waves[0]
for i in range(1, len(generated_waves)):
prev_wave = final_wave
next_wave = generated_waves[i]
# Calculate cross-fade samples
cross_fade_samples = int(cross_fade_duration * self.target_sample_rate)
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
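                    # With the default cross_fade_duration of 0.15 s at 24 kHz, up to
                    # 0.15 * 24000 = 3600 samples are blended (capped by the shorter segment).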
if cross_fade_samples <= 0:
# No overlap possible, concatenate
final_wave = np.concatenate([prev_wave, next_wave])
continue
# Create cross-fade
prev_overlap = prev_wave[-cross_fade_samples:]
next_overlap = next_wave[:cross_fade_samples]
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
final_wave = np.concatenate([
prev_wave[:-cross_fade_samples],
cross_faded_overlap,
next_wave[cross_fade_samples:]
])
# Combine spectrograms if needed
if return_spectrogram or output_path is not None:
combined_spectrogram = np.concatenate(spectrograms, axis=1)
# Save to file if path provided
if output_path is not None:
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# Save audio
torchaudio.save(output_path,
torch.tensor(final_wave).unsqueeze(0),
self.target_sample_rate)
# Save spectrogram if needed
if return_spectrogram:
spectrogram_path = os.path.splitext(output_path)[0] + '_spec.png'
self._save_spectrogram(combined_spectrogram, spectrogram_path)
if not return_numpy:
return output_path
# Return as requested
if return_spectrogram:
return final_wave, self.target_sample_rate, combined_spectrogram
else:
return final_wave, self.target_sample_rate
else:
raise RuntimeError("No audio generated")
def _save_spectrogram(self, spectrogram, path):
"""Save spectrogram as image"""
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.imshow(spectrogram, origin="lower", aspect="auto")
plt.colorbar()
plt.savefig(path)
plt.close()
def get_current_audio_length(self):
"""Get the length of the reference audio in seconds"""
if self.ref_audio_processed is None:
return 0
return self.ref_audio_processed.shape[-1] / self.target_sample_rate
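

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The paths below ("model/model.safetensors",
# "model/vocab.txt", "ref.wav", "output/generated.wav") and the sample sentences are
# placeholders rather than files guaranteed to exist in this repository; replace them
# with your own checkpoint, vocab, reference clip, and output location.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Point ckpt_path/vocab_file at the fine-tuned checkpoint, or omit them to pull
    # the default F5-TTS base model from Hugging Face.
    tts = F5TTSWrapper(
        model_name="F5TTS_v1_Base",
        ckpt_path="model/model.safetensors",   # placeholder path
        vocab_file="model/vocab.txt",          # placeholder path
        vocoder_name="vocos",
    )

    # Preprocess the reference voice once; an empty ref_text triggers auto-transcription.
    tts.preprocess_reference(
        ref_audio_path="ref.wav",              # placeholder path
        ref_text="",
        clip_short=True,
    )

    # Reuse the cached reference for repeated generation; this call returns the output
    # path because output_path is set and return_numpy is False.
    wav_path = tts.generate(
        text="Xin chào, đây là một ví dụ tổng hợp giọng nói.",
        output_path="output/generated.wav",    # placeholder path
        nfe_step=32,
        cfg_strength=2.0,
    )
    print("Saved to:", wav_path)

    # Or get the waveform back as a numpy array instead of writing a file.
    wave, sr = tts.generate(
        text="Một câu ví dụ khác.",
        return_numpy=True,
    )
    print(f"Generated {len(wave) / sr:.2f} s of audio at {sr} Hz")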