import os
import torch
import tempfile
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
import subprocess
import logging
from urllib.parse import urlparse
from typing import Optional
import aiohttp
import asyncio
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
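# NOTE: an ELEVENLABS_API_KEY entry is expected in the environment or a local
# .env file; text-to-speech requests will fail without it.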

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
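# NOTE: the wildcard CORS policy above is convenient for a public demo Space
# but should be narrowed to known origins in production.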

# Pydantic models for request/response
class GenerateRequest(BaseModel):
    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    elevenlabs_audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Default ElevenLabs voice
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1
    tea_cache_l1_thresh: Optional[float] = None

class GenerateResponse(BaseModel):
    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False

class ElevenLabsClient:
    def __init__(self, api_key: str = None):
        # Read the key from the environment; secrets must never be hardcoded in source
        self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
        self.base_url = "https://api.elevenlabs.io/v1"

    async def text_to_speech(self, text: str, voice_id: str = "21m00Tcm4TlvDq8ikWAM") -> str:
        """Convert text to speech using ElevenLabs and return temporary file path"""
        url = f"{self.base_url}/text-to-speech/{voice_id}"
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key
        }
        data = {
            "text": text,
            "model_id": "eleven_monolingual_v1",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5
            }
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise HTTPException(
                            status_code=400,
                            detail=f"ElevenLabs API error: {response.status} - {error_text}"
                        )
                    audio_content = await response.read()

            # Save to temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
            temp_file.write(audio_content)
            temp_file.close()

            logger.info(f"Generated speech audio: {temp_file.name}")
            return temp_file.name
        except HTTPException:
            # Re-raise the 400 above instead of rewrapping it as a 500
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error calling ElevenLabs: {e}")
            raise HTTPException(status_code=400, detail=f"Network error calling ElevenLabs: {e}")
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating speech: {e}")

class OmniAvatarAPI:
    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.elevenlabs_client = ElevenLabsClient()
        logger.info(f"Using device: {self.device}")
        logger.info(f"ElevenLabs API key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")

    def load_model(self):
        """Verify that the OmniAvatar model weights are available on disk"""
        try:
            # Check if models are downloaded
            model_paths = [
                "./pretrained_models/Wan2.1-T2V-14B",
                "./pretrained_models/OmniAvatar-14B",
                "./pretrained_models/wav2vec2-base-960h"
            ]
            for path in model_paths:
                if not os.path.exists(path):
                    logger.error(f"Model path not found: {path}")
                    return False
            self.model_loaded = True
            logger.info("Models loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download file from URL and save to temporary location"""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # Create temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            temp_file.write(content)
            temp_file.close()
            return temp_file.name
        except HTTPException:
            # Preserve the 400 raised above instead of rewrapping it as a 500
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")

    def validate_audio_url(self, url: str) -> bool:
        """Validate if URL is likely an audio file"""
        try:
            parsed = urlparse(url)
            # Check for common audio file extensions or ElevenLabs patterns
            audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac']
            is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
            is_elevenlabs = 'elevenlabs' in parsed.netloc.lower()
            return is_audio_ext or is_elevenlabs or 'audio' in url.lower()
        except Exception:
            return False

    def validate_image_url(self, url: str) -> bool:
        """Validate if URL is likely an image file"""
        try:
            parsed = urlparse(url)
            image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
            return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
        except Exception:
            return False

    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
        """Generate avatar video from prompt and audio/text"""
        import time
        start_time = time.time()
        audio_generated = False

        try:
            # Determine audio source
            audio_path = None
            if request.text_to_speech:
                # Generate speech from text using ElevenLabs
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path = await self.elevenlabs_client.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.elevenlabs_audio_url:
                # Download audio from provided URL
                logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
                if not self.validate_audio_url(str(request.elevenlabs_audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
                audio_path = await self.download_file(str(request.elevenlabs_audio_url), ".mp3")
            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or elevenlabs_audio_url must be provided"
                )

            # Download image if provided
            image_path = None
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                # Determine image extension from URL or default to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # Create temporary input file for inference
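            # One sample per line, in OmniAvatar's "prompt@@image@@audio" form;
            # the image field stays empty ("prompt@@@@audio") when no reference
            # image is supplied.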
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                if image_path:
                    input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
                else:
                    input_line = f"{request.prompt}@@@@{audio_path}"
                f.write(input_line)
                temp_input_file = f.name

            # Prepare inference command
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]
            # Explicit None check so a threshold of 0.0 is still forwarded
            if request.tea_cache_l1_thresh is not None:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # Run inference
            result = subprocess.run(cmd, capture_output=True, text=True)

            # Clean up temporary files
            os.unlink(temp_input_file)
            os.unlink(audio_path)
            if image_path:
                os.unlink(image_path)

            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise Exception(f"Inference failed: {result.stderr}")

            # Find output video file
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
                if video_files:
                    # Return the most recent video file
                    video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
                    output_path = os.path.join(output_dir, video_files[0])
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated

            raise Exception("No output video generated")
        except Exception as e:
            # Clean up any temporary files in case of error
            try:
                if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
                    os.unlink(audio_path)
                if 'image_path' in locals() and image_path and os.path.exists(image_path):
                    os.unlink(image_path)
                if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
                    os.unlink(temp_input_file)
            except OSError:
                pass
            logger.error(f"Generation error: {str(e)}")
            # Keep the original status code for HTTPExceptions raised above
            if isinstance(e, HTTPException):
                raise
            raise HTTPException(status_code=500, detail=str(e))

# Initialize API
omni_api = OmniAvatarAPI()

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    success = omni_api.load_model()
    if not success:
        logger.warning("Model loading failed on startup")

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "device": omni_api.device,
        "supports_elevenlabs": True,
        "supports_image_urls": True,
        "supports_text_to_speech": True,
        "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key)
    }

@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL"""
    if not omni_api.model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.elevenlabs_audio_url:
        logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
        return GenerateResponse(
            message="Avatar generation completed successfully",
            output_path=output_path,
            processing_time=processing_time,
            audio_generated=audio_generated
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
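
# A minimal example request against the endpoint above (port and payload fields
# come from this file; the /generate route name is reconstructed, since the
# original decorators were missing):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "A friendly person explaining a concept",
#          "text_to_speech": "Hello and welcome!",
#          "voice_id": "21m00Tcm4TlvDq8ikWAM"}'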

# Enhanced Gradio interface with text-to-speech option
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with text-to-speech support"""
    if not omni_api.model_loaded:
        return "Error: Model not loaded"

    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Add audio source
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            request_data["elevenlabs_audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or audio URL"

        if image_url and image_url.strip():
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)
        # Run the async generation in a dedicated event loop; close it even on failure
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            output_path, processing_time, audio_generated = loop.run_until_complete(
                omni_api.generate_avatar(request)
            )
        finally:
            loop.close()

        return output_path
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"

# Updated Gradio interface with text-to-speech support
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech using ElevenLabs",
            lines=3,
            info="This will be converted to speech automatically"
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
            info="Direct URL to audio file (alternative to text-to-speech)"
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (JPG, PNG, etc.)"
        ),
        gr.Dropdown(
            choices=["21m00Tcm4TlvDq8ikWAM", "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
            value="21m00Tcm4TlvDq8ikWAM",
            label="ElevenLabs Voice ID",
            info="Choose voice for text-to-speech"
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
    ],
    outputs=gr.Video(label="Generated Avatar Video"),
    title="OmniAvatar-14B with ElevenLabs TTS",
    description="""
Generate avatar videos with lip-sync from text prompts and speech.

**Features:**
- **Text-to-Speech**: enter text and it is converted to speech automatically
- **ElevenLabs integration**: high-quality voice synthesis
- **Audio URL support**: use pre-generated audio files
- **Image URL support**: reference images for character appearance
- **Customizable parameters**: fine-tune generation quality

**Usage:**
1. Enter a character description in the prompt
2. **Either** enter text for speech generation **or** provide an audio URL
3. Optionally add a reference image URL
4. Choose a voice and adjust parameters
5. Generate your avatar video!

**Tips:**
- A guidance scale of 4-6 gives the best prompt adherence
- Increase the audio scale for tighter lip-sync
- Clear, descriptive prompts work best
""",
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
            "",
            "https://example.com/teacher.jpg",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence and its applications!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35
        ]
    ]
)

# Mount Gradio app
app = gr.mount_gradio_app(app, iface, path="/gradio")
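
# The JSON API (/health, /generate) and the Gradio UI (/gradio) share one app,
# so the single port below serves both.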

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)