"""F5TTSWrapper: preprocess a reference voice once, then synthesize repeatedly with F5-TTS."""

import os
import tempfile
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torchaudio

from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files
from pydub import AudioSegment, silence

from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
    chunk_text,
    load_vocoder,
    transcribe,
    initialize_asr_pipeline,
)


class F5TTSWrapper:
    """
    A wrapper class for F5-TTS that preprocesses reference audio once
    and allows for repeated TTS generation.
    """

    def __init__(
        self,
        model_name: str = "F5TTS_v1_Base",
        ckpt_path: Optional[str] = None,
        vocab_file: Optional[str] = None,
        vocoder_name: str = "vocos",
        use_local_vocoder: bool = False,
        vocoder_path: Optional[str] = None,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
        target_sample_rate: int = 24000,
        n_mel_channels: int = 100,
        hop_length: int = 256,
        win_length: int = 1024,
        n_fft: int = 1024,
        ode_method: str = "euler",
        use_ema: bool = True,
    ):
""" |
|
Initialize the F5-TTS wrapper with model configuration. |
|
|
|
Args: |
|
model_name: Name of the F5-TTS model variant (e.g., "F5TTS_v1_Base") |
|
ckpt_path: Path to the model checkpoint file. If None, will use default path. |
|
vocab_file: Path to the vocab file. If None, will use default. |
|
vocoder_name: Name of the vocoder to use ("vocos" or "bigvgan") |
|
use_local_vocoder: Whether to use a local vocoder or download from HF |
|
vocoder_path: Path to the local vocoder. Only used if use_local_vocoder is True. |
|
device: Device to run the model on. If None, will automatically determine. |
|
hf_cache_dir: Directory to cache HuggingFace models |
|
target_sample_rate: Target sample rate for audio |
|
n_mel_channels: Number of mel channels |
|
hop_length: Hop length for the mel spectrogram |
|
win_length: Window length for the mel spectrogram |
|
n_fft: FFT size for the mel spectrogram |
|
ode_method: ODE method for sampling ("euler" or "midpoint") |
|
use_ema: Whether to use EMA weights from the checkpoint |
|
""" |
|
|
|
        # Pick the best available accelerator when no device is specified.
        if device is None:
            self.device = (
                "cuda" if torch.cuda.is_available()
                else "xpu" if torch.xpu.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
        else:
            self.device = device

        # Audio front-end / mel-spectrogram configuration.
        self.target_sample_rate = target_sample_rate
        self.n_mel_channels = n_mel_channels
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft
        self.mel_spec_type = vocoder_name

        self.ode_method = ode_method

        # Warm up the ASR pipeline used to auto-transcribe reference audio.
        initialize_asr_pipeline(device=self.device)

        # Resolve the checkpoint: download the published weights when no path is given.
        if ckpt_path is None:
            repo_name = "F5-TTS"
            ckpt_model_name = model_name  # only the download path may differ from model_name
            ckpt_step = 1250000
            ckpt_type = "safetensors"

            if model_name == "F5TTS_Base":
                if vocoder_name == "vocos":
                    ckpt_step = 1200000
                elif vocoder_name == "bigvgan":
                    # Use the BigVGAN checkpoint, but keep model_name untouched so the
                    # config lookup below still resolves to F5TTS_Base.yaml.
                    ckpt_model_name = "F5TTS_Base_bigvgan"
                    ckpt_type = "pt"
            elif model_name == "E2TTS_Base":
                repo_name = "E2-TTS"
                ckpt_step = 1200000

            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{ckpt_model_name}/model_{ckpt_step}.{ckpt_type}"))

        # Load the model architecture from the bundled YAML config.
        config_path = str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))
        model_cfg = OmegaConf.load(config_path)
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        # Build the character-to-index map from the vocab file.
        if vocab_file is None:
            vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
        tokenizer_type = "custom"
        self.vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer_type)

        # Instantiate the conditional flow-matching (CFM) model.
        self.model = CFM(
            transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
            mel_spec_kwargs=dict(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
                mel_spec_type=vocoder_name,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=self.vocab_char_map,
        ).to(self.device)

        # BigVGAN checkpoints are kept in float32; otherwise let _load_checkpoint decide.
        dtype = torch.float32 if vocoder_name == "bigvgan" else None
        self._load_checkpoint(self.model, ckpt_path, dtype=dtype, use_ema=use_ema)

        # Default local vocoder paths; only used when use_local_vocoder is True.
        if vocoder_path is None:
            if vocoder_name == "vocos":
                vocoder_path = "../checkpoints/vocos-mel-24khz"
            elif vocoder_name == "bigvgan":
                vocoder_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

        self.vocoder = load_vocoder(
            vocoder_name=vocoder_name,
            is_local=use_local_vocoder,
            local_path=vocoder_path,
            device=self.device,
            hf_cache_dir=hf_cache_dir,
        )

        # Reference audio state, filled in by preprocess_reference().
        self.ref_audio_processed = None
        self.ref_text = None
        self.ref_audio_len = None

        # Default generation parameters; each can be overridden per call to generate().
        self.target_rms = 0.1
        self.cross_fade_duration = 0.15
        self.nfe_step = 32
        self.cfg_strength = 2.0
        self.sway_sampling_coef = -1.0
        self.speed = 1.0
        self.fix_duration = None

    def _load_checkpoint(self, model, ckpt_path, dtype=None, use_ema=True):
        """
        Load a model checkpoint, handling both safetensors and PyTorch formats.

        Args:
            model: The model to load weights into
            ckpt_path: Path to the checkpoint file
            dtype: Data type for model weights. If None, chooses float16 on capable CUDA devices.
            use_ema: Whether to use EMA weights from the checkpoint

        Returns:
            Loaded model
        """
        # Prefer float16 on CUDA devices with compute capability >= 7, except under
        # ZLUDA, where half precision is unreliable.
        if dtype is None:
            dtype = (
                torch.float16
                if "cuda" in self.device
                and torch.cuda.get_device_properties(self.device).major >= 7
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
                else torch.float32
            )
        model = model.to(dtype)

        ckpt_type = ckpt_path.split(".")[-1]
        if ckpt_type == "safetensors":
            from safetensors.torch import load_file

            checkpoint = load_file(ckpt_path, device=self.device)
        else:
            checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=True)

        if use_ema:
            # Safetensors checkpoints store the EMA weights as a flat dict; wrap them
            # so both formats share the same key layout.
            if ckpt_type == "safetensors":
                checkpoint = {"ema_model_state_dict": checkpoint}
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }

            # Drop cached mel-spectrogram buffers; they are rebuilt from the current config.
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]

            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            if ckpt_type == "safetensors":
                checkpoint = {"model_state_dict": checkpoint}
            model.load_state_dict(checkpoint["model_state_dict"])

        del checkpoint
        torch.cuda.empty_cache()  # safe no-op when CUDA is not initialized

        return model.to(self.device)

    def preprocess_reference(self, ref_audio_path: str, ref_text: str = "", clip_short: bool = True):
        """
        Preprocess the reference audio and text, storing them for later use.

        Args:
            ref_audio_path: Path to the reference audio file
            ref_text: Text transcript of the reference audio. If empty, it is auto-transcribed.
            clip_short: Whether to clip reference audio longer than ~12 seconds

        Returns:
            Tuple of (processed audio tensor, reference text)
        """
        print("Converting audio...")

        aseg = AudioSegment.from_file(ref_audio_path)

        if clip_short:
            # First pass: split on long silences and keep segments until ~12 s is reached.
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                    print("Audio is over 12s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # Second pass: if still too long, retry with a more aggressive split.
            if len(non_silent_wave) > 12000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                        print("Audio is over 12s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # Last resort: hard-clip at 12 s.
            if len(aseg) > 12000:
                aseg = aseg[:12000]
                print("Audio is over 12s, clipping short. (3)")

        # Trim silence from the edges and append a short silent tail.
        aseg = self._remove_silence_edges(aseg) + AudioSegment.silent(duration=50)

        # Export to a temporary wav file for transcription and loading.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            aseg.export(tmp_file.name, format="wav")
            processed_audio_path = tmp_file.name

        # Transcribe the reference audio when no transcript was provided.
        if not ref_text.strip():
            print("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(processed_audio_path)
        else:
            print("Using custom reference text...")

        # Ensure the reference text ends with terminal punctuation and a trailing space.
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "
            else:
                ref_text += ". "

        print("\nReference text:", ref_text)

        # Load the processed audio and downmix to mono if necessary.
        audio, sr = torchaudio.load(processed_audio_path)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Boost quiet audio up to the target RMS.
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < self.target_rms:
            audio = audio * self.target_rms / rms

        # Resample to the model's sample rate.
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)

        audio = audio.to(self.device)

        # Cache the reference for subsequent generate() calls.
        self.ref_audio_processed = audio
        self.ref_text = ref_text
        self.ref_audio_len = audio.shape[-1] // self.hop_length

        # Clean up the temporary file.
        os.unlink(processed_audio_path)

        return audio, ref_text

    def _remove_silence_edges(self, audio, silence_threshold=-42):
        """
        Remove silence from the start and end of audio.

        Args:
            audio: AudioSegment to process
            silence_threshold: dBFS threshold below which audio counts as silence

        Returns:
            Trimmed AudioSegment
        """
        # Trim leading silence.
        non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
        audio = audio[non_silent_start_idx:]

        # Trim trailing silence by scanning 1 ms slices backwards from the end.
        non_silent_end_duration = audio.duration_seconds
        for ms in reversed(audio):
            if ms.dBFS > silence_threshold:
                break
            non_silent_end_duration -= 0.001
        trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

        return trimmed_audio

    def generate(
        self,
        text: str,
        output_path: Optional[str] = None,
        nfe_step: Optional[int] = None,
        cfg_strength: Optional[float] = None,
        sway_sampling_coef: Optional[float] = None,
        speed: Optional[float] = None,
        fix_duration: Optional[float] = None,
        cross_fade_duration: Optional[float] = None,
        return_numpy: bool = False,
        return_spectrogram: bool = False,
    ) -> Union[str, Tuple[np.ndarray, int], Tuple[np.ndarray, int, np.ndarray]]:
        """
        Generate speech for the given text using the stored reference audio.

        Args:
            text: Text to synthesize
            output_path: Path to save the generated audio. If None, nothing is saved.
            nfe_step: Number of function evaluations (ODE steps)
            cfg_strength: Classifier-free guidance strength
            sway_sampling_coef: Sway sampling coefficient
            speed: Speed of the generated audio (1.0 = normal)
            fix_duration: Fixed total duration (reference + generated) in seconds
            cross_fade_duration: Duration of the cross-fade between text batches, in seconds
            return_numpy: If True, return the audio as a numpy array
            return_spectrogram: If True, also return the mel spectrogram

        Returns:
            If output_path is given and return_numpy is False: path to the output file.
            Otherwise: tuple of (audio_array, sample_rate), plus the spectrogram when
            return_spectrogram is True.
        """
        if self.ref_audio_processed is None or self.ref_text is None:
            raise ValueError("Reference audio not preprocessed. Call preprocess_reference() first.")

        # Fall back to the instance defaults for any parameter not given.
        nfe_step = nfe_step if nfe_step is not None else self.nfe_step
        cfg_strength = cfg_strength if cfg_strength is not None else self.cfg_strength
        sway_sampling_coef = sway_sampling_coef if sway_sampling_coef is not None else self.sway_sampling_coef
        speed = speed if speed is not None else self.speed
        fix_duration = fix_duration if fix_duration is not None else self.fix_duration
        cross_fade_duration = cross_fade_duration if cross_fade_duration is not None else self.cross_fade_duration

        # Size the text batches so that reference + generated audio stays within
        # roughly 22 seconds, scaling by the reference's bytes-per-second rate.
        audio_len = self.ref_audio_processed.shape[-1] / self.target_sample_rate
        max_chars = int(len(self.ref_text.encode("utf-8")) / audio_len * (22 - audio_len))
        text_batches = chunk_text(text, max_chars=max_chars)

        for i, text_batch in enumerate(text_batches):
            print(f"Text batch {i}: {text_batch}")
        print()

        generated_waves = []
        spectrograms = []

        for text_batch in text_batches:
            # Slow down very short batches so they are not rushed.
            local_speed = speed
            if len(text_batch.encode("utf-8")) < 10:
                local_speed = 0.3

            # The model is conditioned on the reference text followed by the new text.
            text_list = [self.ref_text + text_batch]
            final_text_list = convert_char_to_pinyin(text_list)

            # Estimate the output length in mel frames from the byte-length ratio of
            # generated text to reference text, unless a fixed duration was given.
            if fix_duration is not None:
                duration = int(fix_duration * self.target_sample_rate / self.hop_length)
            else:
                ref_text_len = len(self.ref_text.encode("utf-8"))
                gen_text_len = len(text_batch.encode("utf-8"))
                duration = self.ref_audio_len + int(self.ref_audio_len / ref_text_len * gen_text_len / local_speed)

            # Sample a mel spectrogram with the flow-matching model.
            with torch.inference_mode():
                generated, _ = self.model.sample(
                    cond=self.ref_audio_processed,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )

            # Drop the reference prefix and reshape to (batch, mel, time) for the vocoder.
            generated = generated.to(torch.float32)
            generated = generated[:, self.ref_audio_len:, :]
            generated = generated.permute(0, 2, 1)

            # Vocode the mel spectrogram into a waveform.
            if self.mel_spec_type == "vocos":
                generated_wave = self.vocoder.decode(generated)
            elif self.mel_spec_type == "bigvgan":
                generated_wave = self.vocoder(generated)
            else:
                raise ValueError(f"Unsupported vocoder type: {self.mel_spec_type}")

            # Undo the RMS boost that was applied to quiet reference audio.
            rms = torch.sqrt(torch.mean(torch.square(self.ref_audio_processed)))
            if rms < self.target_rms:
                generated_wave = generated_wave * rms / self.target_rms

            generated_wave = generated_wave.squeeze().cpu().numpy()
            generated_waves.append(generated_wave)

            if return_spectrogram or output_path is not None:
                spectrograms.append(generated.squeeze().cpu().numpy())

        if not generated_waves:
            raise RuntimeError("No audio generated")

        if cross_fade_duration <= 0:
            # No cross-fade requested; concatenate the batches directly.
            final_wave = np.concatenate(generated_waves)
        else:
            # Stitch consecutive batches together with a linear cross-fade.
            final_wave = generated_waves[0]
            for i in range(1, len(generated_waves)):
                prev_wave = final_wave
                next_wave = generated_waves[i]

                # Bound the overlap by the available audio on both sides.
                cross_fade_samples = int(cross_fade_duration * self.target_sample_rate)
                cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                if cross_fade_samples <= 0:
                    # Not enough audio to overlap; fall back to concatenation.
                    final_wave = np.concatenate([prev_wave, next_wave])
                    continue

                prev_overlap = prev_wave[-cross_fade_samples:]
                next_overlap = next_wave[:cross_fade_samples]

                fade_out = np.linspace(1, 0, cross_fade_samples)
                fade_in = np.linspace(0, 1, cross_fade_samples)

                cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                final_wave = np.concatenate([
                    prev_wave[:-cross_fade_samples],
                    cross_faded_overlap,
                    next_wave[cross_fade_samples:],
                ])

        if return_spectrogram or output_path is not None:
            combined_spectrogram = np.concatenate(spectrograms, axis=1)

        if output_path is not None:
            output_dir = os.path.dirname(output_path)
            if output_dir and not os.path.exists(output_dir):
                os.makedirs(output_dir)

            torchaudio.save(
                output_path,
                torch.from_numpy(final_wave).unsqueeze(0),
                self.target_sample_rate,
            )

            if return_spectrogram:
                spectrogram_path = os.path.splitext(output_path)[0] + "_spec.png"
                self._save_spectrogram(combined_spectrogram, spectrogram_path)

            if not return_numpy:
                return output_path

        if return_spectrogram:
            return final_wave, self.target_sample_rate, combined_spectrogram
        return final_wave, self.target_sample_rate

    def _save_spectrogram(self, spectrogram, path):
        """Save the spectrogram as an image."""
        import matplotlib.pyplot as plt

        plt.figure(figsize=(12, 4))
        plt.imshow(spectrogram, origin="lower", aspect="auto")
        plt.colorbar()
        plt.savefig(path)
        plt.close()

    def get_current_audio_length(self):
        """Get the length of the reference audio in seconds."""
        if self.ref_audio_processed is None:
            return 0
        return self.ref_audio_processed.shape[-1] / self.target_sample_rate
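

# A minimal usage sketch, not part of the wrapper itself. It assumes the f5_tts
# package and its default checkpoints are available; the file paths and the
# sample sentence below are hypothetical placeholders.
if __name__ == "__main__":
    tts = F5TTSWrapper(model_name="F5TTS_v1_Base", vocoder_name="vocos")

    # Preprocess the reference voice once; an empty ref_text triggers
    # auto-transcription of the clip.
    tts.preprocess_reference(ref_audio_path="ref.wav", ref_text="")

    # Then generate repeatedly against the cached reference.
    tts.generate(
        "Hello! This is a quick test of the F5-TTS wrapper.",
        output_path="out.wav",
    )
    print("Reference length (s):", tts.get_current_audio_length())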