Spaces:
Runtime error
Runtime error
File size: 4,498 Bytes
5c81b55 d7016b3 5f635fb d7016b3 5c81b55 d7016b3 55fd1c7 5f635fb 55fd1c7 5c81b55 d7016b3 5c81b55 d7016b3 5c81b55 6a1a9b3 5c81b55 d7016b3 5c81b55 5f635fb 6a1a9b3 5f635fb 5c81b55 d7016b3 5c81b55 d7016b3 5c81b55 d7016b3 5c81b55 d7016b3 5c81b55 d7016b3 5c81b55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import spaces
import torch
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import gradio as gr
import traceback
from huggingface_hub import snapshot_download
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
def download_weights(
    repo_id: str = "mrfakename/MegaTTS3-VoiceCloning",
    weights_dir: str = "checkpoints",
) -> str:
    """Make sure the model weights are present locally.

    The HuggingFace snapshot is downloaded only when ``weights_dir`` is
    missing or empty; otherwise the existing files are reused as-is.

    Args:
        repo_id: HuggingFace Hub repository to fetch the weights from.
        weights_dir: Local directory the snapshot is stored in.

    Returns:
        ``weights_dir``, whether or not a download was needed.
    """
    # A directory that exists but is empty (e.g. left behind by an
    # interrupted first run) holds no weights, so it must not count as
    # "already downloaded".
    if os.path.isdir(weights_dir) and os.listdir(weights_dir):
        print("Model weights already exist.")
        return weights_dir

    print("Downloading model weights from HuggingFace...")
    snapshot_download(
        repo_id=repo_id,
        local_dir=weights_dir,
        local_dir_use_symlinks=False,
    )
    print("Model weights downloaded successfully!")
    return weights_dir
# Download weights and initialize model.
# These statements run at import time, so the weight download and the
# pipeline construction both finish before the UI further down is built.
download_weights()
print("Initializing MegaTTS3 model...")
# Single module-level pipeline instance used by generate_speech.
# NOTE(review): presumably loads from the "checkpoints" directory populated
# by download_weights() — confirm against tts.infer_cli.
infer_pipe = MegaTTS3DiTInfer()
print("Model loaded successfully!")
@spaces.GPU
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
    """Synthesize ``inp_text`` using ``inp_audio`` as the reference voice.

    Args:
        inp_audio: Filesystem path of the reference audio clip.
        inp_text: Text to synthesize.
        infer_timestep: Forwarded to ``infer_pipe.forward`` as ``time_step``.
        p_w: Forwarded to ``infer_pipe.forward`` as ``p_w``.
        t_w: Forwarded to ``infer_pipe.forward`` as ``t_w``.

    Returns:
        The value returned by ``infer_pipe.forward``, or ``None`` when an
        input is missing or generation fails (a ``gr.Warning`` is shown in
        both of those cases).
    """
    # Whitespace-only text carries nothing to synthesize, so it is treated
    # the same as missing text instead of being sent to the model.
    if not inp_audio or not inp_text or not inp_text.strip():
        gr.Warning("Please provide both reference audio and text to generate.")
        return None

    try:
        print(f"Generating speech with: {inp_text}...")

        # Convert and prepare audio.
        # NOTE(review): presumably convert_to_wav writes the converted file
        # next to the input with a .wav extension, which is the path computed
        # below — confirm against tts.infer_cli.
        convert_to_wav(inp_audio)
        wav_path = os.path.splitext(inp_audio)[0] + '.wav'
        cut_wav(wav_path, max_len=28)

        # Read audio file
        with open(wav_path, 'rb') as file:
            file_content = file.read()

        # Generate speech
        resource_context = infer_pipe.preprocess(file_content)
        wav_bytes = infer_pipe.forward(
            resource_context,
            inp_text,
            time_step=infer_timestep,
            p_w=p_w,
            t_w=t_w,
        )
        return wav_bytes
    except Exception as e:
        # UI boundary: print the full traceback to the server log and show
        # the user a short warning rather than letting the request crash.
        traceback.print_exc()
        gr.Warning(f"Speech generation failed: {e}")
        return None
with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
    # Introductory copy: each entry becomes its own Markdown component,
    # rendered in the order listed.
    intro_paragraphs = (
        "# MegaTTS 3 Voice Cloning",
        "MegaTTS 3 is a text-to-speech model trained by ByteDance with exceptional voice cloning capabilities. The original authors did not release the WavVAE encoder, so voice cloning was not publicly available; however, thanks to [@ACoderPassBy](https://modelscope.cn/models/ACoderPassBy/MegaTTS-SFT)'s WavVAE encoder, we can now clone voices with MegaTTS 3!",
        "This is by no means the best voice cloning solution, but it works pretty well for some specific use-cases. Try out multiple and see which one works best for you.",
        "**Please use this Space responsibly and do not abuse it!**",
        "h/t to MysteryShack on Discord for the info about the unofficial WavVAE encoder!",
        "Upload a reference audio clip and enter text to generate speech with the cloned voice.",
    )
    for paragraph in intro_paragraphs:
        gr.Markdown(paragraph)

    with gr.Row():
        # Left column: inputs, advanced options and the trigger button.
        with gr.Column():
            reference_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            text_input = gr.Textbox(
                label="Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3,
            )

            with gr.Accordion("Advanced Options", open=False):
                infer_timestep = gr.Number(
                    label="Inference Timesteps", value=32, minimum=1, maximum=100, step=1
                )
                p_w = gr.Number(
                    label="Intelligibility Weight", value=1.4, minimum=0.1, maximum=5.0, step=0.1
                )
                t_w = gr.Number(
                    label="Similarity Weight", value=3.0, minimum=0.1, maximum=10.0, step=0.1
                )

            generate_btn = gr.Button("Generate Speech", variant="primary")

        # Right column: the synthesized result.
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio")

    # Component order here matches generate_speech's positional parameters.
    speech_inputs = [reference_audio, text_input, infer_timestep, p_w, t_w]
    generate_btn.click(
        fn=generate_speech,
        inputs=speech_inputs,
        outputs=[output_audio],
    )
if __name__ == '__main__':
    # Launch only when run as a script: listen on all interfaces, port 7860,
    # with Gradio's debug mode enabled.
    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)