File size: 4,498 Bytes
5c81b55
d7016b3
 
5f635fb
d7016b3
 
5c81b55
d7016b3
 
 
55fd1c7
 
 
5f635fb
55fd1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c81b55
 
 
 
 
d7016b3
5c81b55
 
d7016b3
 
 
 
5c81b55
6a1a9b3
5c81b55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7016b3
 
 
5c81b55
5f635fb
 
6a1a9b3
 
5f635fb
5c81b55
d7016b3
5c81b55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7016b3
5c81b55
 
 
 
 
 
 
 
 
 
 
 
 
d7016b3
 
5c81b55
d7016b3
5c81b55
 
d7016b3
5c81b55
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import spaces
import torch
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import gradio as gr
import traceback
from huggingface_hub import snapshot_download
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav


def download_weights():
    """Ensure the local checkpoint directory is available and return its path.

    If a path named ``checkpoints`` already exists, nothing is fetched.
    Otherwise the ``mrfakename/MegaTTS3-VoiceCloning`` repository is
    downloaded into that directory via ``snapshot_download``.

    Returns:
        The relative directory name ``"checkpoints"``.
    """
    target_dir = "checkpoints"

    # Existence of the path is the only check made; its contents are not
    # inspected, so a pre-existing directory is always trusted as-is.
    if os.path.exists(target_dir):
        print("Model weights already exist.")
        return target_dir

    print("Downloading model weights from HuggingFace...")
    snapshot_download(
        repo_id="mrfakename/MegaTTS3-VoiceCloning",
        local_dir=target_dir,
        local_dir_use_symlinks=False,
    )
    print("Model weights downloaded successfully!")
    return target_dir


# Module-level startup: these statements run once, when the file is imported
# or executed, before the Gradio UI below is constructed.
# Step 1: make sure the checkpoint directory exists (downloads on first run).
download_weights()
print("Initializing MegaTTS3 model...")
# Step 2: build the single pipeline object that generate_speech() reuses for
# every request.
# NOTE(review): presumably MegaTTS3DiTInfer picks up its weights from the
# "checkpoints" directory by default — confirm in tts.infer_cli.
infer_pipe = MegaTTS3DiTInfer()
print("Model loaded successfully!")

@spaces.GPU
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
    """Synthesize ``inp_text`` using the voice from a reference recording.

    Args:
        inp_audio: Path of the reference recording (the UI supplies a
            filepath).
        inp_text: Text to synthesize.
        infer_timestep: Forwarded to ``infer_pipe.forward`` as ``time_step``.
        p_w: Forwarded to ``infer_pipe.forward`` as ``p_w``.
        t_w: Forwarded to ``infer_pipe.forward`` as ``t_w``.

    Returns:
        The value produced by ``infer_pipe.forward``, or ``None`` when
        either input is missing or any step raises. In both ``None`` cases
        a Gradio warning is shown to the user instead of an exception
        propagating.
    """
    if not (inp_audio and inp_text):
        gr.Warning("Please provide both reference audio and text to generate.")
        return None

    try:
        print(f"Generating speech with: {inp_text}...")

        # The converted file is looked up next to the upload, under the same
        # stem with a .wav extension, and then passed to cut_wav.
        # NOTE(review): max_len=28 is presumably a length cap in seconds —
        # confirm in tts.infer_cli.
        convert_to_wav(inp_audio)
        stem, _ext = os.path.splitext(inp_audio)
        reference_wav = stem + '.wav'
        cut_wav(reference_wav, max_len=28)

        # The pipeline takes the raw file bytes rather than a path.
        with open(reference_wav, 'rb') as wav_file:
            reference_bytes = wav_file.read()

        prompt_context = infer_pipe.preprocess(reference_bytes)
        result = infer_pipe.forward(prompt_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
    except Exception as exc:
        # UI boundary: log the full traceback server-side and surface a
        # short message to the user.
        traceback.print_exc()
        gr.Warning(f"Speech generation failed: {str(exc)}")
        return None
    else:
        return result


# Gradio UI definition. Components are laid out in the order they are created
# inside the nested context managers, so statement order here is the layout.
with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
    # Header / informational text shown at the top of the page.
    gr.Markdown("# MegaTTS 3 Voice Cloning")
    gr.Markdown("MegaTTS 3 is a text-to-speech model trained by ByteDance with exceptional voice cloning capabilities. The original authors did not release the WavVAE encoder, so voice cloning was not publicly available; however, thanks to [@ACoderPassBy](https://modelscope.cn/models/ACoderPassBy/MegaTTS-SFT)'s WavVAE encoder, we can now clone voices with MegaTTS 3!")
    gr.Markdown("This is by no means the best voice cloning solution, but it works pretty well for some specific use-cases. Try out multiple and see which one works best for you.")
    gr.Markdown("**Please use this Space responsibly and do not abuse it!**")
    gr.Markdown("h/t to MysteryShack on Discord for the info about the unofficial WavVAE encoder!")
    gr.Markdown("Upload a reference audio clip and enter text to generate speech with the cloned voice.")
    
    with gr.Row():
        # Left column: all inputs plus the trigger button.
        with gr.Column():
            # type="filepath" means generate_speech receives a path string,
            # which it passes to convert_to_wav and os.path.splitext.
            reference_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            text_input = gr.Textbox(
                label="Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3
            )
            
            # Tuning knobs, collapsed by default. Each value is forwarded
            # unchanged by generate_speech to infer_pipe.forward.
            with gr.Accordion("Advanced Options", open=False):
                # Passed to infer_pipe.forward as time_step.
                infer_timestep = gr.Number(
                    label="Inference Timesteps",
                    value=32,
                    minimum=1,
                    maximum=100,
                    step=1
                )
                # Passed to infer_pipe.forward as p_w.
                p_w = gr.Number(
                    label="Intelligibility Weight",
                    value=1.4,
                    minimum=0.1,
                    maximum=5.0,
                    step=0.1
                )
                # Passed to infer_pipe.forward as t_w.
                t_w = gr.Number(
                    label="Similarity Weight", 
                    value=3.0,
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1
                )
            
            generate_btn = gr.Button("Generate Speech", variant="primary")
        
        # Right column: playback of the synthesized result.
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio")
    
    # Wire the button to the inference function; the order of `inputs` must
    # match the positional parameters of generate_speech.
    generate_btn.click(
        fn=generate_speech,
        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
        outputs=[output_audio]
    )

# Script entry point: start the Gradio server only when run directly.
# server_name='0.0.0.0' listens on all network interfaces (not just
# localhost) on port 7860, with Gradio's debug mode enabled.
if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)