# Provenance (residue from the Hugging Face page this file was copied from):
# author: mrfakename — commit e001c3b "Update app.py" (verified)
import gradio as gr
import torch
import torchaudio
from huggingface_hub import snapshot_download
from tts import StepAudioTTS
from tokenizer import StepAudioTokenizer
import os
import tempfile
import spaces
class StepAudioDemo:
    """Gradio-facing wrapper around the Step-Audio TTS engine and tokenizer.

    On construction it downloads the ``stepfun-ai/Step-Audio-TTS-3B`` and
    ``stepfun-ai/Step-Audio-Tokenizer`` repositories from the Hugging Face Hub,
    builds the tokenizer and the TTS engine, and creates a private temporary
    directory for generated audio. It exposes two Gradio callbacks:
    ``generate_tts`` (named built-in speaker) and ``generate_clone``
    (voice cloning from a reference clip plus its transcript).
    """

    def __init__(self):
        # Download models from HuggingFace
        print("Downloading models from HuggingFace...")
        self.model_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-TTS-3B")
        self.tokenizer_path = snapshot_download(repo_id="stepfun-ai/Step-Audio-Tokenizer")
        # Initialize models; the TTS engine is given the tokenizer as its encoder.
        print("Initializing models...")
        self.encoder = StepAudioTokenizer(self.tokenizer_path)
        self.tts_engine = StepAudioTTS(self.model_path, self.encoder)
        # Directory that holds every generated WAV file for this process.
        self.temp_dir = tempfile.mkdtemp()
        print("Models loaded and ready!")

    def _new_output_path(self, prefix):
        """Return a fresh, unique ``.wav`` path inside ``self.temp_dir``.

        A distinct file per request keeps one request from overwriting the
        output another request (or user) is still being served.
        """
        fd, path = tempfile.mkstemp(prefix=prefix, suffix=".wav", dir=self.temp_dir)
        # mkstemp opens the file; only the name is needed, torchaudio reopens it.
        os.close(fd)
        return path

    @spaces.GPU
    def generate_tts(self, text, speaker_name):
        """Synthesize ``text`` with the built-in voice ``speaker_name``.

        Args:
            text: Text to synthesize.
            speaker_name: Speaker identifier forwarded to the TTS engine.

        Returns:
            Path to the generated WAV file inside ``self.temp_dir``.

        Raises:
            gr.Error: If ``text`` is empty or synthesis/saving fails. Raising
                (rather than returning a message) lets Gradio show the error
                in the UI instead of treating the message as an audio path.
        """
        if not text or not text.strip():
            raise gr.Error("Please enter some text to synthesize.")
        try:
            output_audio, sr = self.tts_engine(text, speaker_name)
            output_path = self._new_output_path("output_tts_")
            torchaudio.save(output_path, output_audio, sr)
        except Exception as e:
            raise gr.Error(f"Error generating audio: {e}") from e
        return output_path

    @spaces.GPU
    def generate_clone(self, text, prompt_audio, prompt_text):
        """Synthesize ``text`` in the voice of the uploaded prompt clip.

        Args:
            text: Text to synthesize with the cloned voice.
            prompt_audio: Filesystem path of the reference recording (the UI
                component uses ``type="filepath"``); ``None`` if nothing was
                uploaded.
            prompt_text: Transcript of ``prompt_audio``.

        Returns:
            Path to the generated WAV file inside ``self.temp_dir``.

        Raises:
            gr.Error: If ``text`` is empty, no prompt audio was provided, or
                synthesis/saving fails.
        """
        if not text or not text.strip():
            raise gr.Error("Please enter some text to synthesize.")
        if not prompt_audio:
            raise gr.Error("Please upload a prompt audio file.")
        # Speaker spec understood by the TTS engine for cloning; the engine is
        # called with an empty speaker name so this dict takes effect.
        clone_speaker = {
            "speaker": "clone",
            "prompt_text": prompt_text,
            "wav_path": prompt_audio,
        }
        try:
            output_audio, sr = self.tts_engine(text, "", clone_speaker)
            output_path = self._new_output_path("output_clone_")
            torchaudio.save(output_path, output_audio, sr)
        except Exception as e:
            raise gr.Error(f"Error generating cloned audio: {e}") from e
        return output_path
_USAGE_NOTES = (
    "\n"
    "## Usage Notes:\n"
    "- For basic TTS: Enter text and speaker name in the Text-to-Speech tab\n"
    "- For voice cloning: Upload a prompt audio file, enter its transcript, and the text you want to synthesize\n"
    "- Generation may take a few moments depending on text length\n"
)


def _build_tts_tab(app):
    """Lay out the Text-to-Speech tab and wire its button to ``app.generate_tts``."""
    with gr.TabItem("Text-to-Speech"):
        with gr.Row():
            with gr.Column():
                text_box = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to synthesize...",
                    lines=5,
                )
                speaker_box = gr.Textbox(
                    label="Speaker Name",
                    placeholder="Enter speaker name (e.g., 闫雨婷)",
                    value="闫雨婷",
                )
                run_button = gr.Button("Generate Speech")
            with gr.Column():
                audio_out = gr.Audio(label="Generated Audio")
        run_button.click(
            fn=app.generate_tts,
            inputs=[text_box, speaker_box],
            outputs=audio_out,
        )


def _build_clone_tab(app):
    """Lay out the Voice Cloning tab and wire its button to ``app.generate_clone``."""
    with gr.TabItem("Voice Cloning"):
        with gr.Row():
            with gr.Column():
                target_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to synthesize with cloned voice...",
                    lines=5,
                )
                transcript = gr.Textbox(
                    label="Prompt Text",
                    placeholder="Enter the transcript of your prompt audio...",
                    lines=3,
                )
                reference_audio = gr.Audio(
                    label="Upload Prompt Audio",
                    type="filepath",
                )
                run_button = gr.Button("Generate Cloned Speech")
            with gr.Column():
                audio_out = gr.Audio(label="Generated Audio")
        # Input order must match generate_clone(text, prompt_audio, prompt_text).
        run_button.click(
            fn=app.generate_clone,
            inputs=[target_text, reference_audio, transcript],
            outputs=audio_out,
        )


def create_demo():
    """Build and return the Gradio Blocks UI for the Step Audio demo.

    Constructs a ``StepAudioDemo`` (which downloads and loads the models),
    then lays out a title, a Text-to-Speech tab, a Voice Cloning tab and a
    usage-notes section, wiring each tab's button to the matching callback.
    """
    app = StepAudioDemo()
    with gr.Blocks() as interface:
        gr.Markdown("# Step Audio TTS Demo")
        with gr.Tabs():
            _build_tts_tab(app)
            _build_clone_tab(app)
        gr.Markdown(_USAGE_NOTES)
    return interface
if __name__ == "__main__":
    # Build the UI, enable request queuing (queue() configures the Blocks
    # in place), then start the web server.
    demo = create_demo()
    demo.queue()
    demo.launch()