File size: 4,765 Bytes
0102e16
 
 
 
 
 
 
 
e001c3b
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e001c3b
0102e16
 
 
 
 
 
 
 
 
 
e001c3b
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import torch
import torchaudio
from huggingface_hub import snapshot_download
from tts import StepAudioTTS
from tokenizer import StepAudioTokenizer
import os
import tempfile
import spaces

class StepAudioDemo:
    """Back-end object whose methods serve as callbacks for the Gradio UI.

    Construction downloads the TTS model and tokenizer snapshots from the
    Hugging Face Hub, builds the tokenizer and TTS engine from them, and
    creates a temporary directory that receives the generated WAV files.
    """

    def __init__(self):
        # Download models from HuggingFace
        print("Downloading models from HuggingFace...")
        self.model_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-TTS-3B")
        self.tokenizer_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-Tokenizer")

        # Initialize models
        print("Initializing models...")
        self.encoder = StepAudioTokenizer(self.tokenizer_path)
        self.tts_engine = StepAudioTTS(self.model_path, self.encoder)

        # Create temporary directory for outputs
        self.temp_dir = tempfile.mkdtemp()
        print("Models loaded and ready!")

    def _save_wav(self, filename, audio, sample_rate):
        """Write ``audio`` to ``filename`` inside the temp directory.

        Args:
            filename: Base name of the WAV file to create or overwrite.
            audio: Audio data as returned by the TTS engine; passed straight
                to ``torchaudio.save`` (presumably a 2-D tensor — confirm
                against the engine).
            sample_rate: Sample rate returned by the TTS engine.

        Returns:
            The full path of the written file.
        """
        output_path = os.path.join(self.temp_dir, filename)
        torchaudio.save(output_path, audio, sample_rate)
        return output_path

    @spaces.GPU
    def generate_tts(self, text, speaker_name):
        """Synthesize ``text`` with the named speaker and save it as a WAV.

        Args:
            text: Text to synthesize.
            speaker_name: Speaker identifier forwarded to the TTS engine.

        Returns:
            Path of the generated WAV file.

        Raises:
            gr.Error: If synthesis or saving fails. The output component is
                an audio player, so returning an error *string* would be
                mistaken for a file path; raising lets Gradio display the
                message to the user instead.
        """
        try:
            output_audio, sr = self.tts_engine(text, speaker_name)
            return self._save_wav("output_tts.wav", output_audio, sr)
        except Exception as e:
            raise gr.Error(f"Error generating audio: {e}") from e

    @spaces.GPU
    def generate_clone(self, text, prompt_audio, prompt_text):
        """Synthesize ``text`` in the voice of an uploaded prompt recording.

        Args:
            text: Text to synthesize.
            prompt_audio: File path of the prompt recording; ``None`` (or
                empty) when the user has not uploaded anything.
            prompt_text: Transcript of the prompt recording.

        Returns:
            Path of the generated WAV file.

        Raises:
            gr.Error: If no prompt audio was supplied, or if synthesis or
                saving fails (see ``generate_tts`` for why this is raised
                rather than returned).
        """
        # Checked outside the try block so the message is not re-wrapped.
        if not prompt_audio:
            raise gr.Error("Error generating cloned audio: no prompt audio was provided")

        try:
            clone_speaker = {
                "speaker": "clone",
                "prompt_text": prompt_text,
                "wav_path": prompt_audio,
            }
            output_audio, sr = self.tts_engine(text, "", clone_speaker)
            return self._save_wav("output_clone.wav", output_audio, sr)
        except Exception as e:
            raise gr.Error(f"Error generating cloned audio: {e}") from e

def create_demo():
    """Assemble the Gradio Blocks UI for the Step Audio demo.

    Instantiates ``StepAudioDemo`` and lays out two tabs, "Text-to-Speech"
    and "Voice Cloning", whose buttons are bound to ``generate_tts`` and
    ``generate_clone`` respectively.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    backend = StepAudioDemo()

    with gr.Blocks() as ui:
        gr.Markdown("# Step Audio TTS Demo")

        with gr.Tabs():
            # Tab 1: synthesis with a named speaker
            with gr.TabItem("Text-to-Speech"):
                with gr.Row():
                    with gr.Column():
                        synth_text = gr.Textbox(label="Input Text", placeholder="Enter text to synthesize...", lines=5)
                        speaker_box = gr.Textbox(label="Speaker Name", placeholder="Enter speaker name (e.g., 闫雨婷)", value="闫雨婷")
                        synth_btn = gr.Button("Generate Speech")
                    with gr.Column():
                        synth_audio = gr.Audio(label="Generated Audio")

                synth_btn.click(fn=backend.generate_tts, inputs=[synth_text, speaker_box], outputs=synth_audio)

            # Tab 2: synthesis in the voice of an uploaded recording
            with gr.TabItem("Voice Cloning"):
                with gr.Row():
                    with gr.Column():
                        target_text = gr.Textbox(label="Input Text", placeholder="Enter text to synthesize with cloned voice...", lines=5)
                        transcript_box = gr.Textbox(label="Prompt Text", placeholder="Enter the transcript of your prompt audio...", lines=3)
                        reference_audio = gr.Audio(label="Upload Prompt Audio", type="filepath")
                        cloning_btn = gr.Button("Generate Cloned Speech")
                    with gr.Column():
                        cloned_audio = gr.Audio(label="Generated Audio")

                # Input order matches generate_clone(text, prompt_audio, prompt_text).
                cloning_btn.click(fn=backend.generate_clone, inputs=[target_text, reference_audio, transcript_box], outputs=cloned_audio)

        gr.Markdown("""
        ## Usage Notes:
        - For basic TTS: Enter text and speaker name in the Text-to-Speech tab
        - For voice cloning: Upload a prompt audio file, enter its transcript, and the text you want to synthesize
        - Generation may take a few moments depending on text length
        """)

    return ui

if __name__ == "__main__":
    # Build the UI, turn on request queuing, then start serving.
    demo = create_demo()
    demo.queue()
    demo.launch()