""" | |
VibeVoice Gradio Demo - High-Quality Dialogue Generation Interface with Streaming Support | |
""" | |
import spaces | |
import argparse | |
import json | |
import os | |
import sys | |
import tempfile | |
import time | |
from pathlib import Path | |
from typing import List, Dict, Any, Iterator | |
from datetime import datetime | |
import threading | |
import numpy as np | |
import gradio as gr | |
import librosa | |
import soundfile as sf | |
import torch | |
import os | |
import traceback | |
from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig | |
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference | |
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor | |
from vibevoice.modular.streamer import AudioStreamer | |
from transformers.utils import logging | |
from transformers import set_seed | |
logging.set_verbosity_info() | |
logger = logging.get_logger(__name__) | |
class VibeVoiceDemo:
    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
        """Initialize the VibeVoice demo with model loading."""
        self.model_path = model_path
        self.device = device
        self.inference_steps = inference_steps
        self.is_generating = False    # Track generation state
        self.stop_generation = False  # Flag to stop generation
        self.current_streamer = None  # Track current audio streamer
        self.load_model()
        self.setup_voice_presets()
        self.load_example_scripts()   # Load example scripts
    def load_model(self):
        """Load the VibeVoice model and processor."""
        print(f"Loading processor & model from {self.model_path}")
        # Load processor
        self.processor = VibeVoiceProcessor.from_pretrained(
            self.model_path,
        )
        # Load model
        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            attn_implementation="flash_attention_2",
        )
        self.model.eval()
        # Use the SDE solver by default
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type='sde-dpmsolver++',
            beta_schedule='squaredcos_cap_v2'
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        if hasattr(self.model.model, 'language_model'):
            print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")
    def setup_voice_presets(self):
        """Setup voice presets by scanning the voices directory."""
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        # Check if the voices directory exists
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return
        # Scan for all supported audio files in the voices directory
        self.voice_presets = {}
        audio_files = [f for f in os.listdir(voices_dir)
                       if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'))
                       and os.path.isfile(os.path.join(voices_dir, f))]
        # Create a dictionary with the filename (without extension) as the key
        for audio_file in audio_files:
            name = os.path.splitext(audio_file)[0]
            full_path = os.path.join(voices_dir, audio_file)
            self.voice_presets[name] = full_path
        # Sort the voice presets alphabetically by name for a stable UI order
        self.voice_presets = dict(sorted(self.voice_presets.items()))
        # Filter out entries whose files no longer exist (redundant after the scan, kept for safety)
        self.available_voices = {
            name: path for name, path in self.voice_presets.items()
            if os.path.exists(path)
        }
        if not self.available_voices:
            raise gr.Error("No voice presets found. Please add audio files (.wav, .mp3, .flac, .ogg, .m4a, .aac) to the demo/voices directory.")
        print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
        print(f"Available voices: {', '.join(self.available_voices.keys())}")
    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        """Read an audio file, downmix to mono, and resample to the target rate."""
        try:
            wav, sr = sf.read(audio_path)
            if len(wav.shape) > 1:
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])
    def generate_podcast_streaming(self,
                                   num_speakers: int,
                                   script: str,
                                   speaker_1: str = None,
                                   speaker_2: str = None,
                                   speaker_3: str = None,
                                   speaker_4: str = None,
                                   cfg_scale: float = 1.3) -> Iterator[tuple]:
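        """Generate a podcast from the script, streaming audio as it is produced.

        Yields 4-tuples of (streaming_audio, complete_audio, log_text, status_update):
        streaming_audio is a (sample_rate, np.ndarray) pair with only the *new* audio
        for incremental playback (or None), complete_audio is the full
        (sample_rate, np.ndarray) result on the final yield (None until then),
        log_text is a human-readable progress log, and status_update is a gr.update
        toggling the live-streaming indicator.
        """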
        try:
            # Reset the stop flag and set the generating state
            self.stop_generation = False
            self.is_generating = True
            # Validate inputs
            if not script.strip():
                self.is_generating = False
                raise gr.Error("Error: Please provide a script.")
            if num_speakers < 1 or num_speakers > 4:
                self.is_generating = False
                raise gr.Error("Error: Number of speakers must be between 1 and 4.")
            # Collect selected speakers
            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            # Validate speaker selections
            for i, speaker in enumerate(selected_speakers):
                if not speaker or speaker not in self.available_voices:
                    self.is_generating = False
                    raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")
            # Build the initial log
            log = f"🎙️ Generating podcast with {num_speakers} speakers\n"
            log += f"📊 Parameters: CFG Scale={cfg_scale}, Inference Steps={self.inference_steps}\n"
            log += f"🎭 Speakers: {', '.join(selected_speakers)}\n"
            # Check for the stop signal (yield a full 4-tuple so the caller can unpack it)
            if self.stop_generation:
                self.is_generating = False
                yield None, None, "🛑 Generation stopped by user", gr.update(visible=False)
                return
            # Load voice samples
            voice_samples = []
            for speaker_name in selected_speakers:
                audio_path = self.available_voices[speaker_name]
                audio_data = self.read_audio(audio_path)
                if len(audio_data) == 0:
                    self.is_generating = False
                    raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
                voice_samples.append(audio_data)
            # log += f"✅ Loaded {len(voice_samples)} voice samples\n"
            # Check for the stop signal
            if self.stop_generation:
                self.is_generating = False
                yield None, None, "🛑 Generation stopped by user", gr.update(visible=False)
                return
            # Parse the script to assign speaker IDs
            lines = script.strip().split('\n')
            formatted_script_lines = []
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                # Check if the line already has the speaker format
                if line.startswith('Speaker ') and ':' in line:
                    formatted_script_lines.append(line)
                else:
                    # Auto-assign to speakers in rotation
                    speaker_id = len(formatted_script_lines) % num_speakers
                    formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
            formatted_script = '\n'.join(formatted_script_lines)
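            # Example: with num_speakers=2, unprefixed lines alternate between speakers:
            #   "Hello there."          -> "Speaker 0: Hello there."
            #   "Hi, great to be here." -> "Speaker 1: Hi, great to be here."
            # The rotation counts all formatted lines so far, so mixing prefixed and
            # unprefixed lines can shift which speaker an unprefixed line lands on.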
log += f"π Formatted script with {len(formatted_script_lines)} turns\n\n" | |
log += "π Processing with VibeVoice (streaming mode)...\n" | |
# Check for stop signal before processing | |
if self.stop_generation: | |
self.is_generating = False | |
yield None, "π Generation stopped by user", gr.update(visible=False) | |
return | |
start_time = time.time() | |
inputs = self.processor( | |
text=[formatted_script], | |
voice_samples=[voice_samples], | |
padding=True, | |
return_tensors="pt", | |
return_attention_mask=True, | |
) | |
# Create audio streamer | |
audio_streamer = AudioStreamer( | |
batch_size=1, | |
stop_signal=None, | |
timeout=None | |
) | |
# Store current streamer for potential stopping | |
self.current_streamer = audio_streamer | |
# Start generation in a separate thread | |
generation_thread = threading.Thread( | |
target=self._generate_with_streamer, | |
args=(inputs, cfg_scale, audio_streamer) | |
) | |
generation_thread.start() | |
# Wait for generation to actually start producing audio | |
time.sleep(1) # Reduced from 3 to 1 second | |
# Check for stop signal after thread start | |
if self.stop_generation: | |
audio_streamer.end() | |
generation_thread.join(timeout=5.0) # Wait up to 5 seconds for thread to finish | |
self.is_generating = False | |
yield None, "π Generation stopped by user", gr.update(visible=False) | |
return | |
            # Collect audio chunks as they arrive
            sample_rate = 24000
            all_audio_chunks = []  # For final statistics
            pending_chunks = []    # Buffer for accumulating small chunks
            chunk_count = 0
            last_yield_time = time.time()
            min_yield_interval = 15            # Yield at most every 15 seconds
            min_chunk_size = sample_rate * 30  # Buffer at least 30 seconds of audio per yield
            # Get the stream for the first (and only) sample
            audio_stream = audio_streamer.get_stream(0)
            has_yielded_audio = False
            has_received_chunks = False  # Track if we received any chunks at all
            for audio_chunk in audio_stream:
                # Check for the stop signal in the streaming loop
                if self.stop_generation:
                    audio_streamer.end()
                    break
                chunk_count += 1
                has_received_chunks = True  # Mark that we received at least one chunk
                # Convert tensor to numpy
                if torch.is_tensor(audio_chunk):
                    # Convert bfloat16 to float32 first (numpy has no bfloat16), then to numpy
                    if audio_chunk.dtype == torch.bfloat16:
                        audio_chunk = audio_chunk.float()
                    audio_np = audio_chunk.cpu().numpy().astype(np.float32)
                else:
                    audio_np = np.array(audio_chunk, dtype=np.float32)
                # Ensure audio is 1D
                if len(audio_np.shape) > 1:
                    audio_np = audio_np.squeeze()
                # Convert to 16-bit for Gradio
                audio_16bit = convert_to_16_bit_wav(audio_np)
                # Store for final statistics
                all_audio_chunks.append(audio_16bit)
                # Add to the pending chunks buffer
                pending_chunks.append(audio_16bit)
                # Calculate the pending audio size
                pending_audio_size = sum(len(chunk) for chunk in pending_chunks)
                current_time = time.time()
                time_since_last_yield = current_time - last_yield_time
                # Decide whether to yield
                should_yield = False
                if not has_yielded_audio and pending_audio_size >= min_chunk_size:
                    # First yield: wait for the minimum chunk size
                    should_yield = True
                    has_yielded_audio = True
                elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval):
                    # Subsequent yields: either enough audio or enough time has passed
                    should_yield = True
                if should_yield and pending_chunks:
                    # Concatenate and yield only the new audio chunks
                    new_audio = np.concatenate(pending_chunks)
                    new_duration = len(new_audio) / sample_rate
                    total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
                    log_update = log + f"🎵 Streaming: {total_duration:.1f}s generated (chunk {chunk_count})\n"
                    # Yield the streaming chunk; keep complete_audio as None during streaming
                    yield (sample_rate, new_audio), None, log_update, gr.update(visible=True)
                    # Clear pending chunks after yielding
                    pending_chunks = []
                    last_yield_time = current_time
            # Yield any remaining chunks
            if pending_chunks:
                final_new_audio = np.concatenate(pending_chunks)
                total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
                log_update = log + f"🎵 Streaming final chunk: {total_duration:.1f}s total\n"
                yield (sample_rate, final_new_audio), None, log_update, gr.update(visible=True)
                has_yielded_audio = True  # Mark that we yielded audio
            # Wait for generation to complete (with a timeout to prevent hanging)
            generation_thread.join(timeout=5.0)
            # If the thread is still alive after the timeout, force the stream to end
            if generation_thread.is_alive():
                print("Warning: Generation thread did not complete within timeout")
                audio_streamer.end()
                generation_thread.join(timeout=5.0)
            # Clean up
            self.current_streamer = None
            self.is_generating = False
            generation_time = time.time() - start_time
            # Check if stopped by user
            if self.stop_generation:
                yield None, None, "🛑 Generation stopped by user", gr.update(visible=False)
                return
            # Debug logging
            # print(f"Debug: has_received_chunks={has_received_chunks}, chunk_count={chunk_count}, all_audio_chunks length={len(all_audio_chunks)}")
            # Check if we received chunks but never met the streaming yield criteria
            if has_received_chunks and not has_yielded_audio and all_audio_chunks:
                # Yield everything we have as the complete audio
                complete_audio = np.concatenate(all_audio_chunks)
                final_duration = len(complete_audio) / sample_rate
                final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
                final_log += f"🎵 Final audio duration: {final_duration:.2f} seconds\n"
                final_log += f"📊 Total chunks: {chunk_count}\n"
                final_log += "✨ Generation successful! Complete audio is ready.\n"
                final_log += "💡 Not satisfied? You can regenerate or adjust the CFG scale for different results."
                # Yield the complete audio
                yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
                return
            if not has_received_chunks:
                error_log = log + f"\n❌ Error: No audio chunks were received from the model. Generation time: {generation_time:.2f}s"
                yield None, None, error_log, gr.update(visible=False)
                return
            if not has_yielded_audio:
                error_log = log + f"\n❌ Error: Audio was generated but not streamed. Chunk count: {chunk_count}"
                yield None, None, error_log, gr.update(visible=False)
                return
            # Prepare the complete audio
            if all_audio_chunks:
                complete_audio = np.concatenate(all_audio_chunks)
                final_duration = len(complete_audio) / sample_rate
                final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
                final_log += f"🎵 Final audio duration: {final_duration:.2f} seconds\n"
                final_log += f"📊 Total chunks: {chunk_count}\n"
                final_log += "✨ Generation successful! The complete audio is ready in the player below.\n"
                final_log += "💡 Not satisfied? You can regenerate or adjust the CFG scale for different results."
                # Final yield: clear the streaming audio and provide the complete audio
                yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
            else:
                final_log = log + "❌ No audio was generated."
                yield None, None, final_log, gr.update(visible=False)
        except gr.Error as e:
            # Handle Gradio-specific errors (like input validation)
            self.is_generating = False
            self.current_streamer = None
            error_msg = f"❌ Input Error: {str(e)}"
            print(error_msg)
            yield None, None, error_msg, gr.update(visible=False)
        except Exception as e:
            self.is_generating = False
            self.current_streamer = None
            error_msg = f"❌ An unexpected error occurred: {str(e)}"
            print(error_msg)
            traceback.print_exc()
            yield None, None, error_msg, gr.update(visible=False)
    def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
        """Helper method to run generation with the streamer in a separate thread."""
        try:
            # Check for the stop signal before starting generation
            if self.stop_generation:
                audio_streamer.end()
                return
            # Define a stop check function that can be called from generate
            def check_stop_generation():
                return self.stop_generation
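            # Note on the kwargs below: cfg_scale is the classifier-free guidance
            # strength (per the UI hint, higher values increase adherence to the
            # text), and do_sample=False keeps decoding deterministic.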
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={
                    'do_sample': False,
                },
                audio_streamer=audio_streamer,
                stop_check_fn=check_stop_generation,  # Pass the stop check function
                verbose=False,  # Disable verbose output in streaming mode
                refresh_negative=True,
            )
        except Exception as e:
            print(f"Error in generation thread: {e}")
            traceback.print_exc()
            # Make sure to end the stream on error
            audio_streamer.end()
    def stop_audio_generation(self):
        """Stop the current audio generation process."""
        self.stop_generation = True
        if self.current_streamer is not None:
            try:
                self.current_streamer.end()
            except Exception as e:
                print(f"Error stopping streamer: {e}")
        print("🛑 Audio generation stop requested")
    def load_example_scripts(self):
        """Load example scripts from the text_examples directory."""
        import re
        examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
        self.example_scripts = []
        # Check if the text_examples directory exists
        if not os.path.exists(examples_dir):
            print(f"Warning: text_examples directory not found at {examples_dir}")
            return
        # Get all .txt files in the text_examples directory
        txt_files = sorted([f for f in os.listdir(examples_dir)
                            if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
        for txt_file in txt_files:
            file_path = os.path.join(examples_dir, txt_file)
            # Check if the filename contains a time pattern like "45min", "90min", etc.
            time_pattern = re.search(r'(\d+)min', txt_file.lower())
            if time_pattern:
                minutes = int(time_pattern.group(1))
                if minutes > 15:
                    print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit")
                    continue
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    script_content = f.read().strip()
                # Remove empty lines and lines with only whitespace
                script_content = '\n'.join(line for line in script_content.split('\n') if line.strip())
                if not script_content:
                    continue
                # Parse the script to determine the number of speakers
                num_speakers = self._get_num_speakers_from_script(script_content)
                # Add to the examples list as [num_speakers, script_content]
                self.example_scripts.append([num_speakers, script_content])
                print(f"Loaded example: {txt_file} with {num_speakers} speakers")
            except Exception as e:
                print(f"Error loading example script {txt_file}: {e}")
        if self.example_scripts:
            print(f"Successfully loaded {len(self.example_scripts)} example scripts")
        else:
            print("No example scripts were loaded")
    def _get_num_speakers_from_script(self, script: str) -> int:
        """Determine the number of unique speakers in a script."""
        import re
        speakers = set()
        lines = script.strip().split('\n')
        for line in lines:
            # Use a regex to find speaker patterns
            match = re.match(r'^Speaker\s+(\d+)\s*:', line.strip(), re.IGNORECASE)
            if match:
                speaker_id = int(match.group(1))
                speakers.add(speaker_id)
        # If no speakers found, default to 1
        if not speakers:
            return 1
        # Return the maximum speaker ID + 1 for 0-based indexing,
        # or the count of unique speakers for 1-based indexing
        max_speaker = max(speakers)
        min_speaker = min(speakers)
        if min_speaker == 0:
            return max_speaker + 1
        else:
            # Assume 1-based indexing, return the count
            return len(speakers)
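    # Examples of the speaker-count heuristic above:
    #   "Speaker 0: hi\nSpeaker 1: hey" -> min ID is 0, max ID is 1 -> 2 speakers
    #   "Speaker 1: hi\nSpeaker 3: hey" -> 1-based, unique IDs {1, 3} -> 2 speakers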
def create_demo_interface(demo_instance: VibeVoiceDemo):
    """Create the Gradio interface with streaming support."""
    # Custom CSS for high-end aesthetics with a lighter theme
    custom_css = """
    /* Modern light theme with gradients */
    .gradio-container {
        background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
        font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
    }
    /* Header styling */
    .main-header {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 20px;
        margin-bottom: 2rem;
        text-align: center;
        box-shadow: 0 10px 40px rgba(102, 126, 234, 0.3);
    }
    .main-header h1 {
        color: white;
        font-size: 2.5rem;
        font-weight: 700;
        margin: 0;
        text-shadow: 0 2px 4px rgba(0,0,0,0.3);
    }
    .main-header p {
        color: rgba(255,255,255,0.9);
        font-size: 1.1rem;
        margin: 0.5rem 0 0 0;
    }
    /* Card styling */
    .settings-card, .generation-card {
        background: rgba(255, 255, 255, 0.8);
        backdrop-filter: blur(10px);
        border: 1px solid rgba(226, 232, 240, 0.8);
        border-radius: 16px;
        padding: 1.5rem;
        margin-bottom: 1rem;
        box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
    }
    /* Speaker selection styling */
    .speaker-grid {
        display: grid;
        gap: 1rem;
        margin-bottom: 1rem;
    }
    .speaker-item {
        background: linear-gradient(135deg, #e2e8f0 0%, #cbd5e1 100%);
        border: 1px solid rgba(148, 163, 184, 0.4);
        border-radius: 12px;
        padding: 1rem;
        color: #374151;
        font-weight: 500;
    }
    /* Streaming indicator */
    .streaming-indicator {
        display: inline-block;
        width: 10px;
        height: 10px;
        background: #22c55e;
        border-radius: 50%;
        margin-right: 8px;
        animation: pulse 1.5s infinite;
    }
    @keyframes pulse {
        0% { opacity: 1; transform: scale(1); }
        50% { opacity: 0.5; transform: scale(1.1); }
        100% { opacity: 1; transform: scale(1); }
    }
    /* Queue status styling */
    .queue-status {
        background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
        border: 1px solid rgba(14, 165, 233, 0.3);
        border-radius: 8px;
        padding: 0.75rem;
        margin: 0.5rem 0;
        text-align: center;
        font-size: 0.9rem;
        color: #0369a1;
    }
    .generate-btn {
        background: linear-gradient(135deg, #059669 0%, #0d9488 100%);
        border: none;
        border-radius: 12px;
        padding: 1rem 2rem;
        color: white;
        font-weight: 600;
        font-size: 1.1rem;
        box-shadow: 0 4px 20px rgba(5, 150, 105, 0.4);
        transition: all 0.3s ease;
    }
    .generate-btn:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 25px rgba(5, 150, 105, 0.6);
    }
    .stop-btn {
        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
        border: none;
        border-radius: 12px;
        padding: 1rem 2rem;
        color: white;
        font-weight: 600;
        font-size: 1.1rem;
        box-shadow: 0 4px 20px rgba(239, 68, 68, 0.4);
        transition: all 0.3s ease;
    }
    .stop-btn:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 25px rgba(239, 68, 68, 0.6);
    }
    /* Audio player styling */
    .audio-output {
        background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%);
        border-radius: 16px;
        padding: 1.5rem;
        border: 1px solid rgba(148, 163, 184, 0.3);
    }
    .complete-audio-section {
        margin-top: 1rem;
        padding: 1rem;
        background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%);
        border: 1px solid rgba(34, 197, 94, 0.3);
        border-radius: 12px;
    }
    /* Text areas */
    .script-input, .log-output {
        background: rgba(255, 255, 255, 0.9) !important;
        border: 1px solid rgba(148, 163, 184, 0.4) !important;
        border-radius: 12px !important;
        color: #1e293b !important;
        font-family: 'JetBrains Mono', monospace !important;
    }
    .script-input::placeholder {
        color: #64748b !important;
    }
    /* Sliders */
    .slider-container {
        background: rgba(248, 250, 252, 0.8);
        border: 1px solid rgba(226, 232, 240, 0.6);
        border-radius: 8px;
        padding: 1rem;
        margin: 0.5rem 0;
    }
    /* Labels and text */
    .gradio-container label {
        color: #374151 !important;
        font-weight: 600 !important;
    }
    .gradio-container .markdown {
        color: #1f2937 !important;
    }
    /* Responsive design */
    @media (max-width: 768px) {
        .main-header h1 { font-size: 2rem; }
        .settings-card, .generation-card { padding: 1rem; }
    }
    /* Random example button styling - subtler professional color */
    .random-btn {
        background: linear-gradient(135deg, #64748b 0%, #475569 100%);
        border: none;
        border-radius: 12px;
        padding: 1rem 1.5rem;
        color: white;
        font-weight: 600;
        font-size: 1rem;
        box-shadow: 0 4px 20px rgba(100, 116, 139, 0.3);
        transition: all 0.3s ease;
        display: inline-flex;
        align-items: center;
        gap: 0.5rem;
    }
    .random-btn:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 25px rgba(100, 116, 139, 0.4);
        background: linear-gradient(135deg, #475569 0%, #334155 100%);
    }
    """
    with gr.Blocks(
        title="VibeVoice - AI Podcast Generator",
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
            neutral_hue="slate",
        )
    ) as interface:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🎙️ Vibe Podcasting</h1>
            <p>Generating Long-form Multi-speaker AI Podcasts with VibeVoice</p>
        </div>
        """)
        with gr.Row():
            # Left column - Settings
            with gr.Column(scale=1, elem_classes="settings-card"):
                gr.Markdown("### 🎛️ **Podcast Settings**")
                # Number of speakers
                num_speakers = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=2,
                    step=1,
                    label="Number of Speakers",
                    elem_classes="slider-container"
                )
                # Speaker selection
                gr.Markdown("### 🎭 **Speaker Selection**")
                available_speaker_names = list(demo_instance.available_voices.keys())
                # default_speakers = available_speaker_names[:4] if len(available_speaker_names) >= 4 else available_speaker_names
                default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman']
                speaker_selections = []
                for i in range(4):
                    default_value = default_speakers[i] if i < len(default_speakers) else None
                    speaker = gr.Dropdown(
                        choices=available_speaker_names,
                        value=default_value,
                        label=f"Speaker {i+1}",
                        visible=(i < 2),  # Initially show only the first 2 speakers
                        elem_classes="speaker-item"
                    )
                    speaker_selections.append(speaker)
                # Advanced settings
                gr.Markdown("### ⚙️ **Advanced Settings**")
                # Sampling parameters (contains all generation settings)
                with gr.Accordion("Generation Parameters", open=False):
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.3,
                        step=0.05,
                        label="CFG Scale (Guidance Strength)",
                        # info="Higher values increase adherence to text",
                        elem_classes="slider-container"
                    )
            # Right column - Generation
            with gr.Column(scale=2, elem_classes="generation-card"):
                gr.Markdown("### 📝 **Script Input**")
                script_input = gr.Textbox(
                    label="Conversation Script",
                    placeholder="""Enter your podcast script here. You can format it as:
Speaker 0: Welcome to our podcast today!
Speaker 1: Thanks for having me. I'm excited to discuss...
Or paste text directly and it will auto-assign speakers.""",
                    lines=12,
                    max_lines=20,
                    elem_classes="script-input"
                )
                # Button row with Random Example on the left and Generate on the right
                with gr.Row():
                    # Random example button (on the left)
                    random_example_btn = gr.Button(
                        "🎲 Random Example",
                        size="lg",
                        variant="secondary",
                        elem_classes="random-btn",
                        scale=1  # Smaller width
                    )
                    # Generate button (on the right)
                    generate_btn = gr.Button(
                        "🚀 Generate Podcast",
                        size="lg",
                        variant="primary",
                        elem_classes="generate-btn",
                        scale=2  # Wider than the random button
                    )
                # Stop button
                stop_btn = gr.Button(
                    "🛑 Stop Generation",
                    size="lg",
                    variant="stop",
                    elem_classes="stop-btn",
                    visible=False
                )
                # Streaming status indicator
                streaming_status = gr.HTML(
                    value="""
                    <div style="background: linear-gradient(135deg, #dcfce7 0%, #bbf7d0 100%);
                                border: 1px solid rgba(34, 197, 94, 0.3);
                                border-radius: 8px;
                                padding: 0.75rem;
                                margin: 0.5rem 0;
                                text-align: center;
                                font-size: 0.9rem;
                                color: #166534;">
                        <span class="streaming-indicator"></span>
                        <strong>LIVE STREAMING</strong> - Audio is being generated in real-time
                    </div>
                    """,
                    visible=False,
                    elem_id="streaming-status"
                )
                # Output section
                gr.Markdown("### 🎵 **Generated Podcast**")
                # Streaming audio output (outside of tabs for simpler handling)
                audio_output = gr.Audio(
                    label="Streaming Audio (Real-time)",
                    type="numpy",
                    elem_classes="audio-output",
                    streaming=True,  # Enable streaming mode
                    autoplay=True,
                    show_download_button=False,  # Hide the download button while streaming
                    visible=True
                )
                # Complete audio output (non-streaming)
                complete_audio_output = gr.Audio(
                    label="Complete Podcast (Download after generation)",
                    type="numpy",
                    elem_classes="audio-output complete-audio-section",
                    streaming=False,  # Non-streaming mode
                    autoplay=False,
                    show_download_button=True,  # Explicitly show the download button
                    visible=False  # Initially hidden, shown when audio is ready
                )
                gr.Markdown("""
                *💡 **Streaming**: Audio plays as it's being generated (may have slight pauses)*

                *💡 **Complete Audio**: Will appear below after generation finishes*
                """)
                # Generation log
                log_output = gr.Textbox(
                    label="Generation Log",
                    lines=8,
                    max_lines=15,
                    interactive=False,
                    elem_classes="log-output"
                )

        def update_speaker_visibility(num_speakers):
            updates = []
            for i in range(4):
                updates.append(gr.update(visible=(i < num_speakers)))
            return updates

        num_speakers.change(
            fn=update_speaker_visibility,
            inputs=[num_speakers],
            outputs=speaker_selections
        )
        # Main generation function with streaming
        def generate_podcast_wrapper(num_speakers, script, *speakers_and_params):
            """Wrapper function to handle the streaming generation call."""
            try:
                # Extract speakers and parameters
                speakers = speakers_and_params[:4]  # First 4 are speaker selections
                cfg_scale = speakers_and_params[4]  # CFG scale
                # Clear outputs and reset visibility at the start
                yield None, gr.update(value=None, visible=False), "🎙️ Starting generation...", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
                # The generator will yield multiple times
                final_log = "Starting generation..."
                for streaming_audio, complete_audio, log, streaming_visible in demo_instance.generate_podcast_streaming(
                    num_speakers=int(num_speakers),
                    script=script,
                    speaker_1=speakers[0],
                    speaker_2=speakers[1],
                    speaker_3=speakers[2],
                    speaker_4=speakers[3],
                    cfg_scale=cfg_scale
                ):
                    final_log = log
                    # Check if we have complete audio (final yield)
                    if complete_audio is not None:
                        # Final state: clear streaming, show the complete audio
                        yield None, gr.update(value=complete_audio, visible=True), log, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
                    else:
                        # Streaming state: update the streaming audio only
                        if streaming_audio is not None:
                            yield streaming_audio, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True)
                        else:
                            # No new audio, just update the status
                            yield None, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True)
            except Exception as e:
                error_msg = f"❌ A critical error occurred in the wrapper: {str(e)}"
                print(error_msg)
                traceback.print_exc()
                # Reset button states on error
                yield None, gr.update(value=None, visible=False), error_msg, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        def stop_generation_handler():
            """Handle stopping generation."""
            demo_instance.stop_audio_generation()
            # Return values for: log_output, streaming_status, generate_btn, stop_btn
            return "🛑 Generation stopped.", gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)

        # Clear audio outputs before starting a new generation
        def clear_audio_outputs():
            """Clear both audio outputs before starting new generation."""
            return None, gr.update(value=None, visible=False)

        # Connect the generate button with streaming outputs
        generate_btn.click(
            fn=clear_audio_outputs,
            inputs=[],
            outputs=[audio_output, complete_audio_output],
            queue=False
        ).then(
            fn=generate_podcast_wrapper,
            inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale],
            outputs=[audio_output, complete_audio_output, log_output, streaming_status, generate_btn, stop_btn],
            queue=True  # Use Gradio's built-in queue
        )

        # Connect the stop button
        stop_btn.click(
            fn=stop_generation_handler,
            inputs=[],
            outputs=[log_output, streaming_status, generate_btn, stop_btn],
            queue=False  # Don't queue stop requests
        ).then(
            # Clear both audio outputs after stopping
            fn=lambda: (None, None),
            inputs=[],
            outputs=[audio_output, complete_audio_output],
            queue=False
        )
        # Function to randomly select an example
        def load_random_example():
            """Randomly select and load an example script."""
            import random
            # Get available examples
            if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts:
                example_scripts = demo_instance.example_scripts
            else:
                # Fallback to a default example
                example_scripts = [
                    [2, "Speaker 0: Welcome to our AI podcast demonstration!\nSpeaker 1: Thanks for having me. This is exciting!"]
                ]
            # Randomly select one
            if example_scripts:
                selected = random.choice(example_scripts)
                num_speakers_value = selected[0]
                script_value = selected[1]
                # Return the values to update the UI
                return num_speakers_value, script_value
            # Default values if no examples are available
            return 2, ""

        # Connect the random example button
        random_example_btn.click(
            fn=load_random_example,
            inputs=[],
            outputs=[num_speakers, script_input],
            queue=False  # Don't queue this simple operation
        )
        # Add usage tips
        gr.Markdown("""
        ### 💡 **Usage Tips**

        - Click **🚀 Generate Podcast** to start audio generation
        - The **Streaming Audio** player plays audio as it's generated (may have slight pauses)
        - The **Complete Podcast** player provides the full, uninterrupted audio after generation
        - During generation, you can click **🛑 Stop Generation** to interrupt the process
        - The streaming indicator shows real-time generation progress
        """)

        # Add example scripts
        gr.Markdown("### 📚 **Example Scripts**")
        # Use dynamically loaded examples if available, otherwise provide a default
        if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts:
            example_scripts = demo_instance.example_scripts
        else:
            # Fallback to a simple default example if no scripts loaded
            example_scripts = [
                [1, "Speaker 1: Welcome to our AI podcast demonstration! This is a sample script showing how VibeVoice can generate natural-sounding speech."]
            ]
        gr.Examples(
            examples=example_scripts,
            inputs=[num_speakers, script_input],
            label="Try these example scripts:"
        )

    return interface
return interface | |
def convert_to_16_bit_wav(data): | |
# Check if data is a tensor and move to cpu | |
if torch.is_tensor(data): | |
data = data.detach().cpu().numpy() | |
# Ensure data is numpy array | |
data = np.array(data) | |
# Normalize to range [-1, 1] if it's not already | |
if np.max(np.abs(data)) > 1.0: | |
data = data / np.max(np.abs(data)) | |
# Scale to 16-bit integer range | |
data = (data * 32767).astype(np.int16) | |
return data | |
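# Illustrative behavior: a float32 waveform in [-1, 1] maps to int16 in
# [-32767, 32767] (e.g. 0.5 -> 16383); inputs that exceed [-1, 1] are rescaled
# by their peak before conversion rather than being clipped.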
def parse_args():
    parser = argparse.ArgumentParser(description="VibeVoice Gradio Demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="microsoft/VibeVoice-1.5B",
        help="Path to the VibeVoice model directory or Hugging Face model ID",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for inference",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=10,
        help="Number of inference steps for DDPM (not exposed to users)",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Share the demo publicly via Gradio",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run the demo on",
    )
    return parser.parse_args()
def main():
    """Main function to run the demo."""
    args = parse_args()
    set_seed(42)  # Set a fixed seed for reproducibility
    print("🎙️ Initializing VibeVoice Demo with Streaming Support...")
    # Initialize the demo instance from the parsed arguments
    demo_instance = VibeVoiceDemo(
        model_path=args.model_path,
        device=args.device,
        inference_steps=args.inference_steps
    )
    # Create the interface
    interface = create_demo_interface(demo_instance)
    print(f"🚀 Launching demo on port {args.port}")
    print(f"📁 Model path: {args.model_path}")
    print(f"🎭 Available voices: {len(demo_instance.available_voices)}")
    print("🔴 Streaming mode: ENABLED")
    print("🔒 Session isolation: ENABLED")
    # Launch the interface
    try:
        interface.queue(
            max_size=20,  # Maximum queue size
        ).launch(
            share=args.share,
            server_port=args.port,
            show_error=True,
            show_api=False  # Hide API docs for a cleaner interface
        )
    except KeyboardInterrupt:
        print("\n🛑 Shutting down gracefully...")
    except Exception as e:
        print(f"❌ Server error: {e}")
        raise
if __name__ == "__main__":
    main()
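# Example usage (a sketch; flags are defined in parse_args above, a CUDA GPU and
# the vibevoice package are assumed, and the filename is assumed to be demo.py):
#   python demo.py --model_path microsoft/VibeVoice-1.5B --inference_steps 10 --port 7860 --share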