"""OmniAvatar-14B HTTP API and Gradio UI.

Exposes a FastAPI app that turns a prompt plus speech (text-to-speech or an
audio URL) and an optional reference image into an avatar video by running
``scripts/inference.py`` in a subprocess. A Gradio UI is mounted at
``/gradio``.
"""

import asyncio
import io  # NOTE(review): unused in this module; kept in case something imports it from here
import json  # NOTE(review): unused in this module
import logging
import os
import subprocess
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import aiohttp
import gradio as gr
import requests  # NOTE(review): unused in this module
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image  # NOTE(review): unused in this module
from pydantic import BaseModel, HttpUrl

# Load environment variables (before the optional TTS client imports below,
# in case they read configuration from the environment).
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory scanned for generated videos after inference and served at
# /outputs. Presumably scripts/inference.py writes here (configured via
# configs/inference.yaml, not visible in this file) -- confirm.
OUTPUT_DIR = "./outputs"

# Public origin used to build video URLs (the HuggingFace Space).
DEFAULT_BASE_URL = "https://bravedims-ai-avatar-chat.hf.space"

# Voice profile used when the caller does not pick one.
DEFAULT_VOICE_ID = "21m00Tcm4TlvDq8ikWAM"

# Fallback voice catalogue; also drives the Gradio dropdown (insertion order
# is the dropdown order).
# NOTE(review): "TxGEqnHWrfGW9XjX" is 16 characters while the others are 20;
# confirm it matches what the TTS clients expect.
DEFAULT_VOICES = {
    "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
    "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
    "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
    "ErXwobaYiN019PkySvjV": "Male (Professional)",
    "TxGEqnHWrfGW9XjX": "Male (Deep)",
    "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
    "AZnzlk1XvdvUeBnXmlld": "Female (Strong)",
}

AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".ogg", ".aac", ".flac")
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif")
VIDEO_EXTENSIONS = (".mp4", ".avi")

app = FastAPI(title="OmniAvatar-14B API with Advanced TTS", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# StaticFiles raises at construction time if the directory does not exist,
# so create it before mounting (e.g. on a fresh checkout/container).
os.makedirs(OUTPUT_DIR, exist_ok=True)
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")


def get_video_url(output_path: str, base_url: Optional[str] = None) -> str:
    """Convert a local output file path into a public URL under ``/outputs``.

    Only the file name is kept, so this assumes the file sits directly inside
    OUTPUT_DIR (the directory mounted at ``/outputs``).

    Args:
        output_path: Local path of the generated video.
        base_url: Public origin of the server; defaults to DEFAULT_BASE_URL.

    Returns:
        The public URL, or ``output_path`` unchanged if it could not be built.
    """
    try:
        filename = Path(output_path).name
        root = (base_url or DEFAULT_BASE_URL).rstrip("/")
        video_url = f"{root}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except (TypeError, AttributeError) as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path


# Pydantic models for request/response
class GenerateRequest(BaseModel):
    """Body of ``POST /generate``; one of text_to_speech / audio_url is required."""

    prompt: str
    text_to_speech: Optional[str] = None  # Text to convert to speech
    audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = DEFAULT_VOICE_ID  # Voice profile ID
    image_url: Optional[HttpUrl] = None
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1  # passed to torch.distributed.run as --nproc_per_node
    tea_cache_l1_thresh: Optional[float] = None


class GenerateResponse(BaseModel):
    """Response of ``POST /generate``; output_path holds the public video URL."""

    message: str
    output_path: str
    processing_time: float
    audio_generated: bool = False
    tts_method: Optional[str] = None


# Try to import TTS clients, but make them optional
try:
    from advanced_tts_client_fixed import AdvancedTTSClient

    ADVANCED_TTS_AVAILABLE = True
    logger.info("✅ Advanced TTS client available")
except ImportError as e:
    ADVANCED_TTS_AVAILABLE = False
    logger.warning(f"⚠️ Advanced TTS client not available: {e}")

# Always import the robust fallback
try:
    from robust_tts_client import RobustTTSClient

    ROBUST_TTS_AVAILABLE = True
    logger.info("✅ Robust TTS client available")
except ImportError as e:
    ROBUST_TTS_AVAILABLE = False
    logger.error(f"❌ Robust TTS client not available: {e}")


class TTSManager:
    """Own the optional TTS clients and fall back from one to the other.

    The advanced client (if importable and constructible) is tried first,
    then the robust client. Either one may be absent.
    """

    def __init__(self):
        # Initialize TTS clients based on availability; a failed constructor
        # leaves the corresponding attribute as None.
        self.advanced_tts = None
        self.robust_tts = None
        self.clients_loaded = False

        if ADVANCED_TTS_AVAILABLE:
            try:
                self.advanced_tts = AdvancedTTSClient()
                logger.info("✅ Advanced TTS client initialized")
            except Exception as e:
                logger.warning(f"⚠️ Advanced TTS client initialization failed: {e}")

        if ROBUST_TTS_AVAILABLE:
            try:
                self.robust_tts = RobustTTSClient()
                logger.info("✅ Robust TTS client initialized")
            except Exception as e:
                logger.error(f"❌ Robust TTS client initialization failed: {e}")

        if self.advanced_tts is None and self.robust_tts is None:
            logger.error("❌ No TTS clients available!")

    async def load_models(self):
        """Load models for every available client.

        Per-client failures are logged, not raised, and ``clients_loaded`` is
        set afterwards regardless, so loading is attempted once.

        Returns:
            True once loading has been attempted; False only if an unexpected
            error escapes the per-client handlers.
        """
        try:
            logger.info("Loading TTS models...")

            # Try to load advanced TTS first
            if self.advanced_tts is not None:
                try:
                    success = await self.advanced_tts.load_models()
                    if success:
                        logger.info("✅ Advanced TTS models loaded successfully")
                    else:
                        logger.warning("⚠️ Advanced TTS models failed to load")
                except Exception as e:
                    logger.warning(f"⚠️ Advanced TTS loading error: {e}")

            # Always ensure robust TTS is available
            if self.robust_tts is not None:
                try:
                    await self.robust_tts.load_model()
                    logger.info("✅ Robust TTS fallback ready")
                except Exception as e:
                    logger.error(f"❌ Robust TTS loading failed: {e}")

            self.clients_loaded = True
            return True

        except Exception as e:
            logger.error(f"❌ TTS manager initialization failed: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
        """Convert text to speech, trying the advanced client then the robust one.

        Args:
            text: Text to synthesize.
            voice_id: Voice profile ID forwarded to the client.

        Returns:
            ``(audio_file_path, method_used)``; the caller owns the file.

        Raises:
            HTTPException: 500 if every available client failed.
        """
        if not self.clients_loaded:
            logger.info("TTS models not loaded, loading now...")
            await self.load_models()

        logger.info(f"Generating speech: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        # Try Advanced TTS first (Facebook VITS / SpeechT5)
        if self.advanced_tts is not None:
            try:
                audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
                return audio_path, "Facebook VITS/SpeechT5"
            except Exception as advanced_error:
                logger.warning(f"Advanced TTS failed: {advanced_error}")

        # Fall back to robust TTS
        if self.robust_tts is not None:
            try:
                logger.info("Falling back to robust TTS...")
                audio_path = await self.robust_tts.text_to_speech(text, voice_id)
                return audio_path, "Robust TTS (Fallback)"
            except Exception as robust_error:
                logger.error(f"Robust TTS also failed: {robust_error}")

        # If we get here, all methods failed
        logger.error("All TTS methods failed!")
        raise HTTPException(
            status_code=500,
            detail="All TTS methods failed. Please check system configuration.",
        )

    async def get_available_voices(self):
        """Return the advanced client's voice catalogue, else DEFAULT_VOICES (a copy)."""
        if self.advanced_tts is not None and hasattr(self.advanced_tts, "get_available_voices"):
            try:
                return await self.advanced_tts.get_available_voices()
            except Exception as e:
                # Best effort: fall through to the static catalogue.
                logger.warning(f"Could not get voices from advanced TTS: {e}")

        # Return default voices if advanced TTS not available
        return dict(DEFAULT_VOICES)

    def get_tts_info(self):
        """Summarize which TTS clients exist and, if reported, their model status."""
        info = {
            "clients_loaded": self.clients_loaded,
            "advanced_tts_available": self.advanced_tts is not None,
            "robust_tts_available": self.robust_tts is not None,
            "primary_method": "Robust TTS",
        }

        try:
            if self.advanced_tts is not None and hasattr(self.advanced_tts, "get_model_info"):
                advanced_info = self.advanced_tts.get_model_info()
                info.update({
                    "advanced_tts_loaded": advanced_info.get("models_loaded", False),
                    "transformers_available": advanced_info.get("transformers_available", False),
                    "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
                    "device": advanced_info.get("device", "cpu"),
                    "vits_available": advanced_info.get("vits_available", False),
                    "speecht5_available": advanced_info.get("speecht5_available", False),
                })
        except Exception as e:
            logger.debug(f"Could not get advanced TTS info: {e}")

        return info


class OmniAvatarAPI:
    """Prepare inputs and run OmniAvatar inference as a subprocess."""

    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_manager = TTSManager()
        # Serializes inference subprocesses across the FastAPI and Gradio
        # paths (both reach generate_avatar); see _run_inference.
        self._inference_lock = threading.Lock()
        logger.info(f"Using device: {self.device}")
        logger.info("Initialized with robust TTS system")

    def load_model(self):
        """Check that all pretrained model directories exist.

        No weights are loaded in this process -- inference runs in a
        subprocess -- so this only gates ``model_loaded``.

        Returns:
            True if every path exists (and sets ``model_loaded``), else False.
        """
        model_paths = [
            "./pretrained_models/Wan2.1-T2V-14B",
            "./pretrained_models/OmniAvatar-14B",
            "./pretrained_models/wav2vec2-base-960h",
        ]
        missing = [path for path in model_paths if not os.path.exists(path)]
        for path in missing:
            logger.error(f"Model path not found: {path}")
        if missing:
            return False

        self.model_loaded = True
        logger.info("Models loaded successfully")
        return True

    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download ``url`` into a new temporary file and return its path.

        The caller owns the returned file and must delete it.

        Args:
            url: URL to fetch.
            suffix: Suffix (e.g. ``".jpg"``) for the temporary file name.

        Raises:
            HTTPException: 400 for a non-200 response or a network/timeout
                error; 500 for any other failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # delete=False: the file must outlive this handle for inference.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
            return temp_file.name

        except HTTPException:
            # Already carries the intended status code; don't re-wrap as 500.
            raise
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}") from e
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}") from e

    def validate_audio_url(self, url: str) -> bool:
        """Heuristically check whether ``url`` looks like an audio file.

        True if the path ends with a known audio extension or the URL
        contains "audio" anywhere (case-insensitive).
        """
        try:
            path = urlparse(url).path.lower()
            return path.endswith(AUDIO_EXTENSIONS) or "audio" in url.lower()
        except (TypeError, ValueError, AttributeError):
            return False

    def validate_image_url(self, url: str) -> bool:
        """Heuristically check whether ``url``'s path ends with an image extension."""
        try:
            return urlparse(url).path.lower().endswith(IMAGE_EXTENSIONS)
        except (TypeError, ValueError, AttributeError):
            return False

    @staticmethod
    def _find_latest_output(since: float, output_dir: str = OUTPUT_DIR) -> Optional[str]:
        """Return the newest video in ``output_dir`` modified at/after ``since``.

        Older files are ignored so a video left over from a previous run is
        never reported as this run's result. Returns None if none qualify.
        """
        if not os.path.isdir(output_dir):
            return None
        candidates = []
        for name in os.listdir(output_dir):
            if not name.endswith(VIDEO_EXTENSIONS):
                continue
            path = os.path.join(output_dir, name)
            mtime = os.path.getmtime(path)
            if mtime >= since:
                candidates.append((mtime, path))
        return max(candidates)[1] if candidates else None

    def _run_inference(self, cmd: list[str]) -> tuple[subprocess.CompletedProcess, Optional[str]]:
        """Run the inference command and locate the video it produced (blocking).

        Called through ``asyncio.to_thread`` so the event loop stays
        responsive during inference. The lock keeps runs one at a time
        (presumably two concurrent 14B runs would not fit on the device --
        confirm before relaxing). Holding the lock through the output lookup
        means the newest video since ``run_started`` belongs to this run.

        Returns:
            ``(completed_process, output_path_or_None)``; the lookup is only
            done when the process exits with code 0.
        """
        with self._inference_lock:
            run_started = time.time()
            result = subprocess.run(cmd, capture_output=True, text=True)
            output_path = self._find_latest_output(run_started) if result.returncode == 0 else None
        return result, output_path

    @staticmethod
    def _cleanup_files(*paths: Optional[str]) -> None:
        """Best-effort removal of temporary files; None and missing paths are ignored."""
        for path in paths:
            if not path:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Could not remove temporary file {path}: {e}")

    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
        """Generate an avatar video from a prompt plus text or audio.

        Audio comes from TTS when ``text_to_speech`` is set, otherwise from
        ``audio_url``. An optional reference image is downloaded too. All
        temporary inputs (including the TTS audio) are deleted afterwards,
        on success or failure.

        Returns:
            ``(output_path, processing_time_seconds, audio_generated, tts_method)``.

        Raises:
            HTTPException: 400 if neither audio source is given or a download
                is rejected; 500 for TTS/inference failures or when no new
                video appears.
        """
        start_time = time.time()
        audio_generated = False
        tts_method = None
        audio_path = None
        image_path = None
        temp_input_file = None

        try:
            # Determine audio source
            if request.text_to_speech:
                # Generate speech from text using TTS manager
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path, tts_method = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or DEFAULT_VOICE_ID,
                )
                audio_generated = True

            elif request.audio_url:
                # Download audio from provided URL
                audio_url = str(request.audio_url)
                logger.info(f"Downloading audio from URL: {request.audio_url}")
                if not self.validate_audio_url(audio_url):
                    logger.warning(f"Audio URL may not be valid: {request.audio_url}")

                # Keep a known audio extension from the URL; default to .mp3.
                audio_ext = os.path.splitext(urlparse(audio_url).path)[1].lower()
                if audio_ext not in AUDIO_EXTENSIONS:
                    audio_ext = ".mp3"
                audio_path = await self.download_file(audio_url, audio_ext)
                tts_method = "External Audio URL"

            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or audio_url must be provided",
                )

            # Download image if provided
            if request.image_url:
                image_url = str(request.image_url)
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(image_url):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")

                # Determine image extension from URL or default to .jpg
                ext = os.path.splitext(urlparse(image_url).path)[1] or ".jpg"
                image_path = await self.download_file(image_url, ext)

            # Input file line: "prompt@@image_path@@audio_path" (image part
            # empty when no image). NOTE(review): a prompt containing "@@" or
            # a newline would presumably confuse the parser in
            # scripts/inference.py -- confirm.
            with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
                temp_input_file = f.name
                f.write(f"{request.prompt}@@{image_path or ''}@@{audio_path}")

            # Use the current interpreter so the subprocess sees the same env.
            cmd = [
                sys.executable or "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps),
            ]

            # `is not None` so an explicit 0.0 is forwarded rather than dropped.
            if request.tea_cache_l1_thresh is not None:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])

            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # Run inference off the event loop (subprocess.run blocks).
            result, output_path = await asyncio.to_thread(self._run_inference, cmd)

            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise HTTPException(status_code=500, detail=f"Inference failed: {result.stderr}")

            if output_path is None:
                raise HTTPException(status_code=500, detail="No output video generated")

            processing_time = time.time() - start_time
            return output_path, processing_time, audio_generated, tts_method

        except HTTPException as e:
            # Preserve the status code chosen where the error was raised.
            logger.error(f"Generation error: {e.detail}")
            raise
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Always remove temporary inputs, whatever the outcome.
            self._cleanup_files(temp_input_file, audio_path, image_path)


# Initialize API
omni_api = OmniAvatarAPI()


@app.on_event("startup")
async def startup_event():
    """Check model files and load TTS models when the server starts."""
    success = omni_api.load_model()
    if not success:
        logger.warning("OmniAvatar model loading failed on startup")

    # Load TTS models
    try:
        await omni_api.tts_manager.load_models()
        logger.info("TTS models initialization completed")
    except Exception as e:
        logger.error(f"TTS initialization failed: {e}")


@app.get("/health")
async def health_check():
    """Report model/device status and TTS client information."""
    tts_info = omni_api.tts_manager.get_tts_info()

    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": True,
        "supports_audio_urls": True,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
        **tts_info,
    }


@app.get("/voices")
async def get_voices():
    """Return the available voice profiles (errors are reported in the body)."""
    try:
        voices = await omni_api.tts_manager.get_available_voices()
        return {"voices": voices}
    except Exception as e:
        logger.error(f"Error getting voices: {e}")
        return {"error": str(e)}


@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate an avatar video from a prompt, text/audio, and optional image URL.

    Returns 503 while model files are missing; ``output_path`` in the
    response is the public URL of the video.
    """
    if not omni_api.model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.audio_url:
        logger.info(f"Audio URL: {request.audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)

        return GenerateResponse(
            message="Avatar generation completed successfully",
            output_path=get_video_url(output_path),
            processing_time=processing_time,
            audio_generated=audio_generated,
            tts_method=tts_method,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")


# Enhanced Gradio interface
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio callback: build a GenerateRequest and run generation synchronously.

    Returns:
        Local path of the generated video for the ``gr.Video`` output.

    Raises:
        gr.Error: on any failure. An error string returned here would be
            interpreted by ``gr.Video`` as a file path, so failures are
            raised instead for Gradio to display.
    """
    if not omni_api.model_loaded:
        raise gr.Error("Model not loaded")

    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps),
        }

        # Add audio source (text takes precedence over a URL)
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or DEFAULT_VOICE_ID
        elif audio_url and audio_url.strip():
            request_data["audio_url"] = audio_url
        else:
            raise gr.Error("Please provide either text to speech or audio URL")

        if image_url and image_url.strip():
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)

        # Drive the coroutine from this sync callback. Like the previous
        # new_event_loop/run_until_complete approach, this assumes no event
        # loop is already running in the calling thread; asyncio.run also
        # guarantees the loop is closed when generation raises.
        output_path, processing_time, audio_generated, tts_method = asyncio.run(
            omni_api.generate_avatar(request)
        )

    except gr.Error:
        raise
    except HTTPException as e:
        logger.error(f"Gradio generation error: {e.detail}")
        raise gr.Error(str(e.detail)) from e
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        raise gr.Error(str(e)) from e

    logger.info(f"✅ Generation completed in {processing_time:.1f}s using {tts_method}")
    return output_path


# Gradio interface
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2,
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech",
            lines=3,
            info="Will use best available TTS system (Advanced or Fallback)",
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://example.com/audio.mp3",
            info="Direct URL to audio file (alternative to text-to-speech)",
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (JPG, PNG, etc.)",
        ),
        gr.Dropdown(
            choices=list(DEFAULT_VOICES),
            value=DEFAULT_VOICE_ID,
            label="Voice Profile",
            info="Choose voice characteristics for TTS generation",
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended"),
    ],
    outputs=gr.Video(label="Generated Avatar Video"),
    title="🎭 OmniAvatar-14B with Advanced TTS System",
    # Kept flush-left: indented lines would render as Markdown code blocks.
    description="""
Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.

**🔧 Robust TTS Architecture**
- 🤖 **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
- 🔄 **Fallback**: Robust tone generation for 100% reliability
- ⚡ **Automatic**: Seamless switching between methods

**Features:**
- ✅ **Guaranteed Generation**: Always produces audio output
- ✅ **No Dependencies**: Works even without advanced models
- ✅ **High Availability**: Multiple fallback layers
- ✅ **Voice Profiles**: Multiple voice characteristics
- ✅ **Audio URL Support**: Use external audio files
- ✅ **Image URL Support**: Reference images for characters

**Usage:**
1. Enter a character description in the prompt
2. **Either** enter text for speech generation **OR** provide an audio URL
3. Optionally add a reference image URL
4. Choose voice profile and adjust parameters
5. Generate your avatar video!

**System Status:**
- The system will automatically use the best available TTS method
- If advanced models are available, you'll get high-quality speech
- If not, robust fallback ensures the system always works
""",
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and derivatives.",
            "",
            "",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30,
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35,
        ],
    ],
)

# Mount Gradio app
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)