import random
import sys

import soundfile as sf
import torch
import tqdm
from cached_path import cached_path

from model import DiT, UNetT
from model.utils import save_spectrogram, seed_everything
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
					
						
class F5TTS:
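    # Thin wrapper around the repo's inference utilities: loads the vocoder and the
    # chosen TTS backbone once, then exposes infer() plus helpers for exporting the
    # generated audio and spectrogram.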
					
						
    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
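        # Fixed audio / mel-spectrogram settings (24 kHz output, 100 mel channels,
        # hop length 256) and bookkeeping for the seed used on the last infer() call.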
					
						
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1

        # Prefer CUDA, then Apple MPS, then CPU when no device is given.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
					
						
    def load_vocoder_model(self, local_path):
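        # Load the vocoder (Vocos); use the local checkpoint when a path is supplied,
        # otherwise fall back to the default download.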
					
						
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
					
						
    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
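        # Select the backbone and config for the requested model type; if no checkpoint
        # is given, download the corresponding base model from Hugging Face.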
					
						
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
					
						
    def export_wav(self, wav, file_wave, remove_silence=False):
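        # Write the generated waveform to disk at the model's sample rate; optionally
        # trim silence from the written file in place.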
					
						
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)
					
						
    def export_spectrogram(self, spect, file_spect):
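        # Save the mel spectrogram returned by infer() as an image file.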
					
						
        save_spectrogram(spect, file_spect)
					
						
    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
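        # Synthesize gen_text in the voice of the reference clip (ref_file plus its
        # transcript ref_text) and return (waveform, sample_rate, mel spectrogram).
        # A seed of -1 picks a random seed; the seed actually used is kept in self.seed.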
					
						
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect
					
						
if __name__ == "__main__":
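    # Minimal example: clone the bundled English reference clip and write the audio
    # and spectrogram under tests/.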
					
						
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
        seed=-1,  # -1 = pick a random seed
    )

    print("seed :", f5tts.seed)