Developer commited on
Commit
447c811
Β·
1 Parent(s): 84c5aae

πŸš€ NUCLEAR REBUILD: Force HF Spaces to rebuild completely

Browse files

πŸ”₯ AGGRESSIVE CACHE BUSTING:
- Updated Dockerfile timestamp to 2025-08-08_02-37-55-NUCLEAR-REBUILD
- Updated requirements.txt with new rebuild timestamp
- Updated README.md with build status and timestamp
- Added .rebuild_trigger file to force complete rebuild
- Updated .spacesrc with force_rebuild and ignore_cache flags

⚑ WHY THIS SHOULD WORK:
- Multiple timestamp changes across critical build files
- New .rebuild_trigger file that HF Spaces has never seen
- Updated .spacesrc with aggressive rebuild flags
- Changed Dockerfile first line (breaks ALL Docker cache layers)

🎯 GOAL: Force HF Spaces to:
1. Ignore all previous Docker cache layers
2. Rebuild container from scratch
3. Load app.py with model loading fix
4. Enable video generation with downloaded models

This is the most aggressive rebuild approach possible.

.rebuild_trigger ADDED
Binary file (582 Bytes). View file
 
.spacesrc CHANGED
Binary files a/.spacesrc and b/.spacesrc differ
 
Dockerfile CHANGED
Binary files a/Dockerfile and b/Dockerfile differ
 
README.md CHANGED
@@ -12,10 +12,15 @@ short_description: Avatar video generation with model downloads
12
 
13
  # AI Avatar Chat - Video Generation
14
 
15
- **Build: 2025-08-08_02-16-VIDEO-ENABLED - Fixed Model Download & Enabled Video Generation**
16
 
17
  Real avatar video generation with downloadable models for HF Spaces.
18
 
 
 
 
 
 
19
  ## Features:
20
 
21
  - βœ… Video generation ENABLED in HF Spaces
@@ -24,26 +29,7 @@ Real avatar video generation with downloadable models for HF Spaces.
24
  - 🎬 Text-to-video generation using downloaded models
25
  - πŸ”Œ API endpoints for programmatic access
26
 
27
- ## Quick Start:
28
-
29
- 1. **Download Models First**: Use `/download-models` API endpoint or web interface
30
- 2. **Wait for Download**: Models are ~30GB and will take time to download
31
- 3. **Generate Videos**: Once models are downloaded, video generation will work
32
- 4. **Check Status**: Use `/model-status` endpoint to verify models are loaded
33
-
34
- ## API Usage:
35
-
36
- ```bash
37
- # Download models first
38
- curl -X POST https://bravedims-ai-avatar-chat.hf.space/download-models
39
-
40
- # Check model status
41
- curl https://bravedims-ai-avatar-chat.hf.space/model-status
42
-
43
- # Generate video (after models are downloaded)
44
- curl -X POST https://bravedims-ai-avatar-chat.hf.space/generate \
45
- -H "Content-Type: application/json" \
46
- -d '{"text": "Hello world", "image_url": "https://example.com/image.jpg"}'
47
- ```
48
 
49
- **Status**: βœ… Video generation ENABLED - Download models to start generating videos!
 
12
 
13
  # AI Avatar Chat - Video Generation
14
 
15
+ **Build: 2025-08-08_02-37-55-NUCLEAR-REBUILD - Model Loading FIXED & Forced Rebuild**
16
 
17
  Real avatar video generation with downloadable models for HF Spaces.
18
 
19
+ ## πŸ”₯ LATEST FIX:
20
+ - βœ… Fixed model loading to check downloaded_models/ directory
21
+ - βœ… Video generation will work after models are downloaded
22
+ - βœ… Forced complete Docker cache bust
23
+
24
  ## Features:
25
 
26
  - βœ… Video generation ENABLED in HF Spaces
 
29
  - 🎬 Text-to-video generation using downloaded models
30
  - πŸ”Œ API endpoints for programmatic access
31
 
32
+ ## Status:
33
+ **βœ… READY - Models downloaded, loading logic fixed, forced rebuild**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ **Status**: Model loading FIXED - Should work after rebuild!
app.py.backup_before_model_fix ADDED
@@ -0,0 +1,985 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Avatar Chat - HF Spaces Optimized Version
3
+ BUILD: 2025-01-08_00-44-FORCE-REBUILD - With Model Download Controls
4
+ FEATURES: Real video generation, model download UI, storage optimization
5
+ """
6
+ import os
7
+
8
+ # STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
9
+ IS_HF_SPACE = any([
10
+ os.getenv("SPACE_ID"),
11
+ os.getenv("SPACE_AUTHOR_NAME"),
12
+ os.getenv("SPACES_BUILDKIT_VERSION"),
13
+ "/home/user/app" in os.getcwd()
14
+ ])
15
+
16
+ if IS_HF_SPACE:
17
+ # Force TTS-only mode to prevent storage limit exceeded
18
+ # os.environ[\"DISABLE_MODEL_DOWNLOAD\"] = \"1\" # ENABLED FOR VIDEO GENERATION
19
+ # os.environ[\"TTS_ONLY_MODE\"] = \"1\" # ENABLED FOR VIDEO GENERATION
20
+ os.environ["HF_SPACE_STORAGE_OPTIMIZED"] = "1"
21
+ print("?? STORAGE OPTIMIZATION: Detected HF Space environment")
22
+ print("?? Video generation ENABLED (models need manual download)")
23
+ print("?? WARNING: Use /download-models endpoint to download ~30GB models first")
24
+ import os
25
+ import torch
26
+ import tempfile
27
+ import gradio as gr
28
+ from fastapi import FastAPI, HTTPException
29
+ from fastapi.staticfiles import StaticFiles
30
+ from fastapi.middleware.cors import CORSMiddleware
31
+ from pydantic import BaseModel, HttpUrl
32
+ import subprocess
33
+ import json
34
+ from pathlib import Path
35
+ import logging
36
+ import requests
37
+ from urllib.parse import urlparse
38
+ from PIL import Image
39
+ import io
40
+ from typing import Optional
41
+ import aiohttp
42
+ import asyncio
43
+ # Safe dotenv import
44
+ try:
45
+ from dotenv import load_dotenv
46
+ load_dotenv()
47
+ except ImportError:
48
+ print("Warning: python-dotenv not found, continuing without .env support")
49
+ def load_dotenv():
50
+ pass
51
+
52
+ # CRITICAL: HF Spaces compatibility fix
53
+ try:
54
+ from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible
55
+ setup_hf_spaces_environment()
56
+ except ImportError:
57
+ print('Warning: HF Spaces fix not available')
58
+
59
+ # Load environment variables
60
+ load_dotenv()
61
+
62
+ # Set up logging
63
+ logging.basicConfig(level=logging.INFO)
64
+ logger = logging.getLogger(__name__)
65
+
66
+ # Set environment variables for matplotlib, gradio, and huggingface cache
67
+ os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
68
+ os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
69
+ os.environ['HF_HOME'] = '/tmp/huggingface'
70
+ # Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
71
+ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
72
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
73
+
74
+ # FastAPI app will be created after lifespan is defined
75
+
76
+
77
+
78
+ # Create directories with proper permissions
79
+ os.makedirs("outputs", exist_ok=True)
80
+ os.makedirs("/tmp/matplotlib", exist_ok=True)
81
+ os.makedirs("/tmp/huggingface", exist_ok=True)
82
+ os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
83
+ os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
84
+ os.makedirs("/tmp/huggingface/hub", exist_ok=True)
85
+
86
+ # Mount static files for serving generated videos
87
+
88
+
89
+ def get_video_url(output_path: str) -> str:
90
+ """Convert local file path to accessible URL"""
91
+ try:
92
+ from pathlib import Path
93
+ filename = Path(output_path).name
94
+
95
+ # For HuggingFace Spaces, construct the URL
96
+ base_url = "https://bravedims-ai-avatar-chat.hf.space"
97
+ video_url = f"{base_url}/outputs/{filename}"
98
+ logger.info(f"Generated video URL: {video_url}")
99
+ return video_url
100
+ except Exception as e:
101
+ logger.error(f"Error creating video URL: {e}")
102
+ return output_path # Fallback to original path
103
+
104
+ # Pydantic models for request/response
105
+ class GenerateRequest(BaseModel):
106
+ prompt: str
107
+ text_to_speech: Optional[str] = None # Text to convert to speech
108
+ audio_url: Optional[HttpUrl] = None # Direct audio URL
109
+ voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Voice profile ID
110
+ image_url: Optional[HttpUrl] = None
111
+ guidance_scale: float = 5.0
112
+ audio_scale: float = 3.0
113
+ num_steps: int = 30
114
+ sp_size: int = 1
115
+ tea_cache_l1_thresh: Optional[float] = None
116
+
117
+ class GenerateResponse(BaseModel):
118
+ message: str
119
+ output_path: str
120
+ processing_time: float
121
+ audio_generated: bool = False
122
+ tts_method: Optional[str] = None
123
+
124
+ # Try to import TTS clients, but make them optional
125
+ try:
126
+ from advanced_tts_client import AdvancedTTSClient
127
+ ADVANCED_TTS_AVAILABLE = True
128
+ logger.info("SUCCESS: Advanced TTS client available")
129
+ except ImportError as e:
130
+ ADVANCED_TTS_AVAILABLE = False
131
+ logger.warning(f"WARNING: Advanced TTS client not available: {e}")
132
+
133
+ # Always import the robust fallback
134
+ try:
135
+ from robust_tts_client import RobustTTSClient
136
+ ROBUST_TTS_AVAILABLE = True
137
+ logger.info("SUCCESS: Robust TTS client available")
138
+ except ImportError as e:
139
+ ROBUST_TTS_AVAILABLE = False
140
+ logger.error(f"ERROR: Robust TTS client not available: {e}")
141
+
142
+ class TTSManager:
143
+ """Manages multiple TTS clients with fallback chain"""
144
+
145
+ def __init__(self):
146
+ # Initialize TTS clients based on availability
147
+ self.advanced_tts = None
148
+ self.robust_tts = None
149
+ self.clients_loaded = False
150
+
151
+ if ADVANCED_TTS_AVAILABLE:
152
+ try:
153
+ self.advanced_tts = AdvancedTTSClient()
154
+ logger.info("SUCCESS: Advanced TTS client initialized")
155
+ except Exception as e:
156
+ logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")
157
+
158
+ if ROBUST_TTS_AVAILABLE:
159
+ try:
160
+ self.robust_tts = RobustTTSClient()
161
+ logger.info("SUCCESS: Robust TTS client initialized")
162
+ except Exception as e:
163
+ logger.error(f"ERROR: Robust TTS client initialization failed: {e}")
164
+
165
+ if not self.advanced_tts and not self.robust_tts:
166
+ logger.error("ERROR: No TTS clients available!")
167
+
168
+ async def load_models(self):
169
+ """Load TTS models"""
170
+ try:
171
+ logger.info("Loading TTS models...")
172
+
173
+ # Try to load advanced TTS first
174
+ if self.advanced_tts:
175
+ try:
176
+ logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
177
+ success = await self.advanced_tts.load_models()
178
+ if success:
179
+ logger.info("SUCCESS: Advanced TTS models loaded successfully")
180
+ else:
181
+ logger.warning("WARNING: Advanced TTS models failed to load")
182
+ except Exception as e:
183
+ logger.warning(f"WARNING: Advanced TTS loading error: {e}")
184
+
185
+ # Always ensure robust TTS is available
186
+ if self.robust_tts:
187
+ try:
188
+ await self.robust_tts.load_model()
189
+ logger.info("SUCCESS: Robust TTS fallback ready")
190
+ except Exception as e:
191
+ logger.error(f"ERROR: Robust TTS loading failed: {e}")
192
+
193
+ self.clients_loaded = True
194
+ return True
195
+
196
+ except Exception as e:
197
+ logger.error(f"ERROR: TTS manager initialization failed: {e}")
198
+ return False
199
+
200
+ async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
201
+ """
202
+ Convert text to speech with fallback chain
203
+ Returns: (audio_file_path, method_used)
204
+ """
205
+ if not self.clients_loaded:
206
+ logger.info("TTS models not loaded, loading now...")
207
+ await self.load_models()
208
+
209
+ logger.info(f"Generating speech: {text[:50]}...")
210
+ logger.info(f"Voice ID: {voice_id}")
211
+
212
+ # Try Advanced TTS first (Facebook VITS / SpeechT5)
213
+ if self.advanced_tts:
214
+ try:
215
+ audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
216
+ return audio_path, "Facebook VITS/SpeechT5"
217
+ except Exception as advanced_error:
218
+ logger.warning(f"Advanced TTS failed: {advanced_error}")
219
+
220
+ # Fall back to robust TTS
221
+ if self.robust_tts:
222
+ try:
223
+ logger.info("Falling back to robust TTS...")
224
+ audio_path = await self.robust_tts.text_to_speech(text, voice_id)
225
+ return audio_path, "Robust TTS (Fallback)"
226
+ except Exception as robust_error:
227
+ logger.error(f"Robust TTS also failed: {robust_error}")
228
+
229
+ # If we get here, all methods failed
230
+ logger.error("All TTS methods failed!")
231
+ raise HTTPException(
232
+ status_code=500,
233
+ detail="All TTS methods failed. Please check system configuration."
234
+ )
235
+
236
+ async def get_available_voices(self):
237
+ """Get available voice configurations"""
238
+ try:
239
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
240
+ return await self.advanced_tts.get_available_voices()
241
+ except:
242
+ pass
243
+
244
+ # Return default voices if advanced TTS not available
245
+ return {
246
+ "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
247
+ "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
248
+ "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
249
+ "ErXwobaYiN019PkySvjV": "Male (Professional)",
250
+ "TxGEqnHWrfGW9XjX": "Male (Deep)",
251
+ "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
252
+ "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
253
+ }
254
+
255
+ def get_tts_info(self):
256
+ """Get TTS system information"""
257
+ info = {
258
+ "clients_loaded": self.clients_loaded,
259
+ "advanced_tts_available": self.advanced_tts is not None,
260
+ "robust_tts_available": self.robust_tts is not None,
261
+ "primary_method": "Robust TTS"
262
+ }
263
+
264
+ try:
265
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
266
+ advanced_info = self.advanced_tts.get_model_info()
267
+ info.update({
268
+ "advanced_tts_loaded": advanced_info.get("models_loaded", False),
269
+ "transformers_available": advanced_info.get("transformers_available", False),
270
+ "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
271
+ "device": advanced_info.get("device", "cpu"),
272
+ "vits_available": advanced_info.get("vits_available", False),
273
+ "speecht5_available": advanced_info.get("speecht5_available", False)
274
+ })
275
+ except Exception as e:
276
+ logger.debug(f"Could not get advanced TTS info: {e}")
277
+
278
+ return info
279
+
280
+ # Import the VIDEO-FOCUSED engine
281
+ try:
282
+ from omniavatar_video_engine import video_engine
283
+ VIDEO_ENGINE_AVAILABLE = True
284
+ logger.info("SUCCESS: OmniAvatar Video Engine available")
285
+ except ImportError as e:
286
+ VIDEO_ENGINE_AVAILABLE = False
287
+ logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
288
+
289
+ class OmniAvatarAPI:
290
+ def __init__(self):
291
+ self.model_loaded = False
292
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
293
+ self.tts_manager = TTSManager()
294
+ logger.info(f"Using device: {self.device}")
295
+ logger.info("Initialized with robust TTS system")
296
+
297
+ def load_model(self):
298
+ """Load the OmniAvatar model - now more flexible"""
299
+ try:
300
+ # Check if models are downloaded (but don't require them)
301
+ # Check both traditional and downloaded model paths
302
+ downloaded_video = "./downloaded_models/video"
303
+ downloaded_audio = "./downloaded_models/audio"
304
+
305
+ # Check downloaded models first
306
+ if os.path.exists(downloaded_video) and os.path.exists(downloaded_audio):
307
+ video_files = len([f for f in os.listdir(downloaded_video) if os.path.isfile(os.path.join(downloaded_video, f))]) if os.path.isdir(downloaded_video) else 0
308
+ audio_files = len([f for f in os.listdir(downloaded_audio) if os.path.isfile(os.path.join(downloaded_audio, f))]) if os.path.isdir(downloaded_audio) else 0
309
+ if video_files > 5 and audio_files > 5:
310
+ missing_models.append(path)
311
+
312
+ if missing_models:
313
+ logger.warning("WARNING: Some OmniAvatar models not found:")
314
+ for model in missing_models:
315
+ logger.warning(f" - {model}")
316
+ logger.info("TIP: App will run in TTS-only mode (no video generation)")
317
+ logger.info("TIP: To enable full avatar generation, download the required models")
318
+
319
+ # Set as loaded but in limited mode
320
+ self.model_loaded = False # Video generation disabled
321
+ return True # But app can still run
322
+ else:
323
+ self.model_loaded = True
324
+ logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
325
+ return True
326
+
327
+ except Exception as e:
328
+ logger.error(f"Error checking models: {str(e)}")
329
+ logger.info("TIP: Continuing in TTS-only mode")
330
+ self.model_loaded = False
331
+ return True # Continue running
332
+
333
+ async def download_file(self, url: str, suffix: str = "") -> str:
334
+ """Download file from URL and save to temporary location"""
335
+ try:
336
+ async with aiohttp.ClientSession() as session:
337
+ async with session.get(str(url)) as response:
338
+ if response.status != 200:
339
+ raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
340
+
341
+ content = await response.read()
342
+
343
+ # Create temporary file
344
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
345
+ temp_file.write(content)
346
+ temp_file.close()
347
+
348
+ return temp_file.name
349
+
350
+ except aiohttp.ClientError as e:
351
+ logger.error(f"Network error downloading {url}: {e}")
352
+ raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
353
+ except Exception as e:
354
+ logger.error(f"Error downloading file from {url}: {e}")
355
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
356
+
357
+ def validate_audio_url(self, url: str) -> bool:
358
+ """Validate if URL is likely an audio file"""
359
+ try:
360
+ parsed = urlparse(url)
361
+ # Check for common audio file extensions
362
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
363
+ is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
364
+
365
+ return is_audio_ext or 'audio' in url.lower()
366
+ except:
367
+ return False
368
+
369
+ def validate_image_url(self, url: str) -> bool:
370
+ """Validate if URL is likely an image file"""
371
+ try:
372
+ parsed = urlparse(url)
373
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
374
+ return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
375
+ except:
376
+ return False
377
+
378
+ async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
379
+ """Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
380
+ import time
381
+ start_time = time.time()
382
+ audio_generated = False
383
+ method_used = "Unknown"
384
+
385
+ logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
386
+ logger.info(f"[INFO] Prompt: {request.prompt}")
387
+
388
+ if VIDEO_ENGINE_AVAILABLE:
389
+ try:
390
+ # PRIORITIZE VIDEO GENERATION
391
+ logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")
392
+
393
+ # Handle audio source
394
+ audio_path = None
395
+ if request.text_to_speech:
396
+ logger.info("[MIC] Generating audio from text...")
397
+ audio_path, method_used = await self.tts_manager.text_to_speech(
398
+ request.text_to_speech,
399
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
400
+ )
401
+ audio_generated = True
402
+ elif request.audio_url:
403
+ logger.info("πŸ“₯ Downloading audio from URL...")
404
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
405
+ method_used = "External Audio"
406
+ else:
407
+ raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")
408
+
409
+ # Handle image if provided
410
+ image_path = None
411
+ if request.image_url:
412
+ logger.info("[IMAGE] Downloading reference image...")
413
+ parsed = urlparse(str(request.image_url))
414
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
415
+ image_path = await self.download_file(str(request.image_url), ext)
416
+
417
+ # GENERATE VIDEO using OmniAvatar engine
418
+ logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
419
+ video_path, generation_time = video_engine.generate_avatar_video(
420
+ prompt=request.prompt,
421
+ audio_path=audio_path,
422
+ image_path=image_path,
423
+ guidance_scale=request.guidance_scale,
424
+ audio_scale=request.audio_scale,
425
+ num_steps=request.num_steps
426
+ )
427
+
428
+ processing_time = time.time() - start_time
429
+ logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
430
+
431
+ # Cleanup temporary files
432
+ if audio_path and os.path.exists(audio_path):
433
+ os.unlink(audio_path)
434
+ if image_path and os.path.exists(image_path):
435
+ os.unlink(image_path)
436
+
437
+ return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
438
+
439
+ except Exception as e:
440
+ logger.error(f"ERROR: Video generation failed: {e}")
441
+ # For a VIDEO generation app, we should NOT fall back to audio-only
442
+ # Instead, provide clear guidance
443
+ if "models" in str(e).lower():
444
+ raise HTTPException(
445
+ status_code=503,
446
+ detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
447
+ )
448
+ else:
449
+ raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
450
+
451
+ # If video engine not available, this is a critical error for a VIDEO app
452
+ raise HTTPException(
453
+ status_code=503,
454
+ detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
455
+ )
456
+
457
+ async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
458
+ """OLD TTS-ONLY METHOD - kept as backup reference.
459
+ Generate avatar video from prompt and audio/text - now handles missing models"""
460
+ import time
461
+ start_time = time.time()
462
+ audio_generated = False
463
+ tts_method = None
464
+
465
+ try:
466
+ # Check if video generation is available
467
+ if not self.model_loaded:
468
+ logger.info("πŸŽ™οΈ Running in TTS-only mode (OmniAvatar models not available)")
469
+
470
+ # Only generate audio, no video
471
+ if request.text_to_speech:
472
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
473
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
474
+ request.text_to_speech,
475
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
476
+ )
477
+
478
+ # Return the audio file as the "output"
479
+ processing_time = time.time() - start_time
480
+ logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
481
+ return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
482
+ else:
483
+ raise HTTPException(
484
+ status_code=503,
485
+ detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
486
+ )
487
+
488
+ # Original video generation logic (when models are available)
489
+ # Determine audio source
490
+ audio_path = None
491
+
492
+ if request.text_to_speech:
493
+ # Generate speech from text using TTS manager
494
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
495
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
496
+ request.text_to_speech,
497
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
498
+ )
499
+ audio_generated = True
500
+
501
+ elif request.audio_url:
502
+ # Download audio from provided URL
503
+ logger.info(f"Downloading audio from URL: {request.audio_url}")
504
+ if not self.validate_audio_url(str(request.audio_url)):
505
+ logger.warning(f"Audio URL may not be valid: {request.audio_url}")
506
+
507
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
508
+ tts_method = "External Audio URL"
509
+
510
+ else:
511
+ raise HTTPException(
512
+ status_code=400,
513
+ detail="Either text_to_speech or audio_url must be provided"
514
+ )
515
+
516
+ # Download image if provided
517
+ image_path = None
518
+ if request.image_url:
519
+ logger.info(f"Downloading image from URL: {request.image_url}")
520
+ if not self.validate_image_url(str(request.image_url)):
521
+ logger.warning(f"Image URL may not be valid: {request.image_url}")
522
+
523
+ # Determine image extension from URL or default to .jpg
524
+ parsed = urlparse(str(request.image_url))
525
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
526
+ image_path = await self.download_file(str(request.image_url), ext)
527
+
528
+ # Create temporary input file for inference
529
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
530
+ if image_path:
531
+ input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
532
+ else:
533
+ input_line = f"{request.prompt}@@@@{audio_path}"
534
+ f.write(input_line)
535
+ temp_input_file = f.name
536
+
537
+ # Prepare inference command
538
+ cmd = [
539
+ "python", "-m", "torch.distributed.run",
540
+ "--standalone", f"--nproc_per_node={request.sp_size}",
541
+ "scripts/inference.py",
542
+ "--config", "configs/inference.yaml",
543
+ "--input_file", temp_input_file,
544
+ "--guidance_scale", str(request.guidance_scale),
545
+ "--audio_scale", str(request.audio_scale),
546
+ "--num_steps", str(request.num_steps)
547
+ ]
548
+
549
+ if request.tea_cache_l1_thresh:
550
+ cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
551
+
552
+ logger.info(f"Running inference with command: {' '.join(cmd)}")
553
+
554
+ # Run inference
555
+ result = subprocess.run(cmd, capture_output=True, text=True)
556
+
557
+ # Clean up temporary files
558
+ os.unlink(temp_input_file)
559
+ os.unlink(audio_path)
560
+ if image_path:
561
+ os.unlink(image_path)
562
+
563
+ if result.returncode != 0:
564
+ logger.error(f"Inference failed: {result.stderr}")
565
+ raise Exception(f"Inference failed: {result.stderr}")
566
+
567
+ # Find output video file
568
+ output_dir = "./outputs"
569
+ if os.path.exists(output_dir):
570
+ video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
571
+ if video_files:
572
+ # Return the most recent video file
573
+ video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
574
+ output_path = os.path.join(output_dir, video_files[0])
575
+ processing_time = time.time() - start_time
576
+ return output_path, processing_time, audio_generated, tts_method
577
+
578
+ raise Exception("No output video generated")
579
+
580
+ except Exception as e:
581
+ # Clean up any temporary files in case of error
582
+ try:
583
+ if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
584
+ os.unlink(audio_path)
585
+ if 'image_path' in locals() and image_path and os.path.exists(image_path):
586
+ os.unlink(image_path)
587
+ if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
588
+ os.unlink(temp_input_file)
589
+ except:
590
+ pass
591
+
592
+ logger.error(f"Generation error: {str(e)}")
593
+ raise HTTPException(status_code=500, detail=str(e))
594
+
595
+ # Initialize API
596
+ omni_api = OmniAvatarAPI()
597
+
598
+ # Use FastAPI lifespan instead of deprecated on_event
599
+ from contextlib import asynccontextmanager
600
+
601
+ @asynccontextmanager
602
+ async def lifespan(app: FastAPI):
603
+ # Startup
604
+ success = omni_api.load_model()
605
+ if not success:
606
+ logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")
607
+
608
+ # Load TTS models
609
+ try:
610
+ await omni_api.tts_manager.load_models()
611
+ logger.info("SUCCESS: TTS models initialization completed")
612
+ except Exception as e:
613
+ logger.error(f"ERROR: TTS initialization failed: {e}")
614
+
615
+ yield
616
+
617
+ # Shutdown (if needed)
618
+ logger.info("Application shutting down...")
619
+
620
+ # Create FastAPI app WITH lifespan parameter
621
+ app = FastAPI(
622
+ title="OmniAvatar-14B API with Advanced TTS",
623
+ version="1.0.0",
624
+ lifespan=lifespan
625
+ )
626
+
627
+ # Add CORS middleware
628
+ app.add_middleware(
629
+ CORSMiddleware,
630
+ allow_origins=["*"],
631
+ allow_credentials=True,
632
+ allow_methods=["*"],
633
+ allow_headers=["*"],
634
+ )
635
+
636
+ # Mount static files for serving generated videos
637
+ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
638
+
639
+ @app.get("/health")
640
+ async def health_check():
641
+ """Health check endpoint"""
642
+ tts_info = omni_api.tts_manager.get_tts_info()
643
+
644
+ return {
645
+ "status": "healthy",
646
+ "model_loaded": omni_api.model_loaded,
647
+ "video_generation_available": omni_api.model_loaded,
648
+ "tts_only_mode": not omni_api.model_loaded,
649
+ "device": omni_api.device,
650
+ "supports_text_to_speech": True,
651
+ "supports_image_urls": omni_api.model_loaded,
652
+ "supports_audio_urls": omni_api.model_loaded,
653
+ "tts_system": "Advanced TTS with Robust Fallback",
654
+ "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
655
+ "robust_tts_available": ROBUST_TTS_AVAILABLE,
656
+ **tts_info
657
+ }
658
+
659
+ @app.get("/voices")
660
+ async def get_voices():
661
+ """Get available voice configurations"""
662
+ try:
663
+ voices = await omni_api.tts_manager.get_available_voices()
664
+ return {"voices": voices}
665
+ except Exception as e:
666
+ logger.error(f"Error getting voices: {e}")
667
+ return {"error": str(e)}
668
+
669
+ @app.post("/generate", response_model=GenerateResponse)
670
+ async def generate_avatar(request: GenerateRequest):
671
+ """Generate avatar video from prompt, text/audio, and optional image URL"""
672
+
673
+ logger.info(f"Generating avatar with prompt: {request.prompt}")
674
+ if request.text_to_speech:
675
+ logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
676
+ logger.info(f"Voice ID: {request.voice_id}")
677
+ if request.audio_url:
678
+ logger.info(f"Audio URL: {request.audio_url}")
679
+ if request.image_url:
680
+ logger.info(f"Image URL: {request.image_url}")
681
+
682
+ try:
683
+ output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
684
+
685
+ return GenerateResponse(
686
+ message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
687
+ output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
688
+ processing_time=processing_time,
689
+ audio_generated=audio_generated,
690
+ tts_method=tts_method
691
+ )
692
+
693
+ except HTTPException:
694
+ raise
695
+ except Exception as e:
696
+ logger.error(f"Unexpected error: {e}")
697
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
698
+
699
+ @app.post("/download-models")
700
+ async def download_video_models():
701
+ """Manually trigger video model downloads"""
702
+ logger.info("?? Manual model download requested...")
703
+
704
+ try:
705
+ from huggingface_hub import snapshot_download
706
+ import shutil
707
+
708
+ # Check storage first
709
+ _, _, free_bytes = shutil.disk_usage(".")
710
+ free_gb = free_bytes / (1024**3)
711
+
712
+ logger.info(f"?? Available storage: {free_gb:.1f}GB")
713
+
714
+ if free_gb < 10: # Need at least 10GB free
715
+ return {
716
+ "success": False,
717
+ "message": f"Insufficient storage: {free_gb:.1f}GB available, 10GB+ required",
718
+ "storage_gb": free_gb
719
+ }
720
+
721
+ # Download small video generation model
722
+ logger.info("?? Downloading text-to-video model...")
723
+
724
+ model_path = snapshot_download(
725
+ repo_id="ali-vilab/text-to-video-ms-1.7b",
726
+ cache_dir="./downloaded_models/video",
727
+ local_files_only=False
728
+ )
729
+
730
+ logger.info(f"? Video model downloaded: {model_path}")
731
+
732
+ # Download audio model
733
+ audio_model_path = snapshot_download(
734
+ repo_id="facebook/wav2vec2-base-960h",
735
+ cache_dir="./downloaded_models/audio",
736
+ local_files_only=False
737
+ )
738
+
739
+ logger.info(f"? Audio model downloaded: {audio_model_path}")
740
+
741
+ # Check final storage usage
742
+ _, _, free_bytes_after = shutil.disk_usage(".")
743
+ free_gb_after = free_bytes_after / (1024**3)
744
+ used_gb = free_gb - free_gb_after
745
+
746
+ return {
747
+ "success": True,
748
+ "message": "? Video generation models downloaded successfully!",
749
+ "models_downloaded": [
750
+ "ali-vilab/text-to-video-ms-1.7b",
751
+ "facebook/wav2vec2-base-960h"
752
+ ],
753
+ "storage_used_gb": round(used_gb, 2),
754
+ "storage_remaining_gb": round(free_gb_after, 2),
755
+ "video_model_path": model_path,
756
+ "audio_model_path": audio_model_path,
757
+ "status": "READY FOR VIDEO GENERATION"
758
+ }
759
+
760
+ except Exception as e:
761
+ logger.error(f"? Model download failed: {e}")
762
+ return {
763
+ "success": False,
764
+ "message": f"Model download failed: {str(e)}",
765
+ "error": str(e)
766
+ }
767
+
768
+ @app.get("/model-status")
769
+ async def get_model_status():
770
+ """Check status of downloaded models"""
771
+ try:
772
+ models_dir = Path("./downloaded_models")
773
+
774
+ status = {
775
+ "models_downloaded": models_dir.exists(),
776
+ "available_models": [],
777
+ "storage_info": {}
778
+ }
779
+
780
+ if models_dir.exists():
781
+ for model_dir in models_dir.iterdir():
782
+ if model_dir.is_dir():
783
+ status["available_models"].append({
784
+ "name": model_dir.name,
785
+ "path": str(model_dir),
786
+ "files": len(list(model_dir.rglob("*")))
787
+ })
788
+
789
+ # Storage info
790
+ import shutil
791
+ _, _, free_bytes = shutil.disk_usage(".")
792
+ status["storage_info"] = {
793
+ "free_gb": round(free_bytes / (1024**3), 2),
794
+ "models_dir_exists": models_dir.exists()
795
+ }
796
+
797
+ return status
798
+
799
+ except Exception as e:
800
+ return {"error": str(e)}
801
+
802
+
803
+ # Enhanced Gradio interface
804
+ def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
805
+ """Gradio interface wrapper with robust TTS support"""
806
+ try:
807
+ # Create request object
808
+ request_data = {
809
+ "prompt": prompt,
810
+ "guidance_scale": guidance_scale,
811
+ "audio_scale": audio_scale,
812
+ "num_steps": int(num_steps)
813
+ }
814
+
815
+ # Add audio source
816
+ if text_to_speech and text_to_speech.strip():
817
+ request_data["text_to_speech"] = text_to_speech
818
+ request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
819
+ elif audio_url and audio_url.strip():
820
+ if omni_api.model_loaded:
821
+ request_data["audio_url"] = audio_url
822
+ else:
823
+ return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
824
+ else:
825
+ return "Error: Please provide either text to speech or audio URL"
826
+
827
+ if image_url and image_url.strip():
828
+ if omni_api.model_loaded:
829
+ request_data["image_url"] = image_url
830
+ else:
831
+ return "Error: Image URL input requires full OmniAvatar models for video generation."
832
+
833
+ request = GenerateRequest(**request_data)
834
+
835
+ # Run async function in sync context
836
+ loop = asyncio.new_event_loop()
837
+ asyncio.set_event_loop(loop)
838
+ output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
839
+ loop.close()
840
+
841
+ success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
842
+ print(success_message)
843
+
844
+ if omni_api.model_loaded:
845
+ return output_path
846
+ else:
847
+ return f"πŸŽ™οΈ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
848
+
849
+ except Exception as e:
850
+ logger.error(f"Gradio generation error: {e}")
851
+ return f"Error: {str(e)}"
852
+
853
+ # Create Gradio interface
854
+ mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
855
+ description_extra = """
856
+ WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
857
+ To enable full video generation, the required model files need to be downloaded.
858
+ """ if not omni_api.model_loaded else ""
859
+
860
+ iface = gr.Interface(
861
+ fn=gradio_generate,
862
+ inputs=[
863
+ gr.Textbox(
864
+ label="Prompt",
865
+ placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
866
+ lines=2
867
+ ),
868
+ gr.Textbox(
869
+ label="Text to Speech",
870
+ placeholder="Enter text to convert to speech",
871
+ lines=3,
872
+ info="Will use best available TTS system (Advanced or Fallback)"
873
+ ),
874
+ gr.Textbox(
875
+ label="OR Audio URL",
876
+ placeholder="https://example.com/audio.mp3",
877
+ info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
878
+ ),
879
+ gr.Textbox(
880
+ label="Image URL (Optional)",
881
+ placeholder="https://example.com/image.jpg",
882
+ info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
883
+ ),
884
+ gr.Dropdown(
885
+ choices=[
886
+ "21m00Tcm4TlvDq8ikWAM",
887
+ "pNInz6obpgDQGcFmaJgB",
888
+ "EXAVITQu4vr4xnSDxMaL",
889
+ "ErXwobaYiN019PkySvjV",
890
+ "TxGEqnHWrfGW9XjX",
891
+ "yoZ06aMxZJJ28mfd3POQ",
892
+ "AZnzlk1XvdvUeBnXmlld"
893
+ ],
894
+ value="21m00Tcm4TlvDq8ikWAM",
895
+ label="Voice Profile",
896
+ info="Choose voice characteristics for TTS generation"
897
+ ),
898
+ gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
899
+ gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
900
+ gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
901
+ ],
902
+ outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
903
+ title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
904
+ description=f"""
905
+ Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
906
+
907
+ {description_extra}
908
+
909
+ **Robust TTS Architecture**
910
+ - **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
911
+ - **Fallback**: Robust tone generation for 100% reliability
912
+ - **Automatic**: Seamless switching between methods
913
+
914
+ **Features:**
915
+ - **Guaranteed Generation**: Always produces audio output
916
+ - **No Dependencies**: Works even without advanced models
917
+ - **High Availability**: Multiple fallback layers
918
+ - **Voice Profiles**: Multiple voice characteristics
919
+ - **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
920
+ - **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
921
+
922
+ **Usage:**
923
+ 1. Enter a character description in the prompt
924
+ 2. **Enter text for speech generation** (recommended in current mode)
925
+ 3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
926
+ 4. Choose voice profile and adjust parameters
927
+ 5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
928
+ """,
929
+ examples=[
930
+ [
931
+ "A professional teacher explaining a mathematical concept with clear gestures",
932
+ "Hello students! Today we're going to learn about calculus and derivatives.",
933
+ "",
934
+ "",
935
+ "21m00Tcm4TlvDq8ikWAM",
936
+ 5.0,
937
+ 3.5,
938
+ 30
939
+ ],
940
+ [
941
+ "A friendly presenter speaking confidently to an audience",
942
+ "Welcome everyone to our presentation on artificial intelligence!",
943
+ "",
944
+ "",
945
+ "pNInz6obpgDQGcFmaJgB",
946
+ 5.5,
947
+ 4.0,
948
+ 35
949
+ ]
950
+ ],
951
+ allow_flagging="never",
952
+ flagging_dir="/tmp/gradio_flagged"
953
+ )
954
+
955
+ # Mount Gradio app
956
+ app = gr.mount_gradio_app(app, iface, path="/gradio")
957
+
958
+ if __name__ == "__main__":
959
+ import uvicorn
960
+ uvicorn.run(app, host="0.0.0.0", port=7860)
961
+
962
+
963
+
964
+
965
+
966
+
967
+
968
+
969
+
970
+
971
+
972
+
973
+
974
+
975
+
976
+
977
+
978
+
979
+
980
+
981
+
982
+
983
+
984
+
985
+
app.py.backup_before_video_fix ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Avatar Chat - HF Spaces Optimized Version
3
+ BUILD: 2025-01-08_00-44-FORCE-REBUILD - With Model Download Controls
4
+ FEATURES: Real video generation, model download UI, storage optimization
5
+ """
6
+ import os
7
+
8
+ # STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
9
+ IS_HF_SPACE = any([
10
+ os.getenv("SPACE_ID"),
11
+ os.getenv("SPACE_AUTHOR_NAME"),
12
+ os.getenv("SPACES_BUILDKIT_VERSION"),
13
+ "/home/user/app" in os.getcwd()
14
+ ])
15
+
16
+ if IS_HF_SPACE:
17
+ # Force TTS-only mode to prevent storage limit exceeded
18
+ # os.environ[\"DISABLE_MODEL_DOWNLOAD\"] = \"1\" # ENABLED FOR VIDEO GENERATION
19
+ # os.environ[\"TTS_ONLY_MODE\"] = \"1\" # ENABLED FOR VIDEO GENERATION
20
+ os.environ["HF_SPACE_STORAGE_OPTIMIZED"] = "1"
21
+ print("?? STORAGE OPTIMIZATION: Detected HF Space environment")
22
+ print("??? TTS-only mode ENABLED (video generation disabled for storage limits)")
23
+ print("?? Model auto-download DISABLED to prevent storage exceeded error")
24
+ import os
25
+ import torch
26
+ import tempfile
27
+ import gradio as gr
28
+ from fastapi import FastAPI, HTTPException
29
+ from fastapi.staticfiles import StaticFiles
30
+ from fastapi.middleware.cors import CORSMiddleware
31
+ from pydantic import BaseModel, HttpUrl
32
+ import subprocess
33
+ import json
34
+ from pathlib import Path
35
+ import logging
36
+ import requests
37
+ from urllib.parse import urlparse
38
+ from PIL import Image
39
+ import io
40
+ from typing import Optional
41
+ import aiohttp
42
+ import asyncio
43
+ # Safe dotenv import
44
+ try:
45
+ from dotenv import load_dotenv
46
+ load_dotenv()
47
+ except ImportError:
48
+ print("Warning: python-dotenv not found, continuing without .env support")
49
+ def load_dotenv():
50
+ pass
51
+
52
+ # CRITICAL: HF Spaces compatibility fix
53
+ try:
54
+ from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible
55
+ setup_hf_spaces_environment()
56
+ except ImportError:
57
+ print('Warning: HF Spaces fix not available')
58
+
59
+ # Load environment variables
60
+ load_dotenv()
61
+
62
+ # Set up logging
63
+ logging.basicConfig(level=logging.INFO)
64
+ logger = logging.getLogger(__name__)
65
+
66
+ # Set environment variables for matplotlib, gradio, and huggingface cache
67
+ os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
68
+ os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
69
+ os.environ['HF_HOME'] = '/tmp/huggingface'
70
+ # Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
71
+ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
72
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
73
+
74
+ # FastAPI app will be created after lifespan is defined
75
+
76
+
77
+
78
+ # Create directories with proper permissions
79
+ os.makedirs("outputs", exist_ok=True)
80
+ os.makedirs("/tmp/matplotlib", exist_ok=True)
81
+ os.makedirs("/tmp/huggingface", exist_ok=True)
82
+ os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
83
+ os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
84
+ os.makedirs("/tmp/huggingface/hub", exist_ok=True)
85
+
86
+ # Mount static files for serving generated videos
87
+
88
+
89
+ def get_video_url(output_path: str) -> str:
90
+ """Convert local file path to accessible URL"""
91
+ try:
92
+ from pathlib import Path
93
+ filename = Path(output_path).name
94
+
95
+ # For HuggingFace Spaces, construct the URL
96
+ base_url = "https://bravedims-ai-avatar-chat.hf.space"
97
+ video_url = f"{base_url}/outputs/{filename}"
98
+ logger.info(f"Generated video URL: {video_url}")
99
+ return video_url
100
+ except Exception as e:
101
+ logger.error(f"Error creating video URL: {e}")
102
+ return output_path # Fallback to original path
103
+
104
+ # Pydantic models for request/response
105
+ class GenerateRequest(BaseModel):
106
+ prompt: str
107
+ text_to_speech: Optional[str] = None # Text to convert to speech
108
+ audio_url: Optional[HttpUrl] = None # Direct audio URL
109
+ voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Voice profile ID
110
+ image_url: Optional[HttpUrl] = None
111
+ guidance_scale: float = 5.0
112
+ audio_scale: float = 3.0
113
+ num_steps: int = 30
114
+ sp_size: int = 1
115
+ tea_cache_l1_thresh: Optional[float] = None
116
+
117
+ class GenerateResponse(BaseModel):
118
+ message: str
119
+ output_path: str
120
+ processing_time: float
121
+ audio_generated: bool = False
122
+ tts_method: Optional[str] = None
123
+
124
+ # Try to import TTS clients, but make them optional
125
+ try:
126
+ from advanced_tts_client import AdvancedTTSClient
127
+ ADVANCED_TTS_AVAILABLE = True
128
+ logger.info("SUCCESS: Advanced TTS client available")
129
+ except ImportError as e:
130
+ ADVANCED_TTS_AVAILABLE = False
131
+ logger.warning(f"WARNING: Advanced TTS client not available: {e}")
132
+
133
+ # Always import the robust fallback
134
+ try:
135
+ from robust_tts_client import RobustTTSClient
136
+ ROBUST_TTS_AVAILABLE = True
137
+ logger.info("SUCCESS: Robust TTS client available")
138
+ except ImportError as e:
139
+ ROBUST_TTS_AVAILABLE = False
140
+ logger.error(f"ERROR: Robust TTS client not available: {e}")
141
+
142
+ class TTSManager:
143
+ """Manages multiple TTS clients with fallback chain"""
144
+
145
+ def __init__(self):
146
+ # Initialize TTS clients based on availability
147
+ self.advanced_tts = None
148
+ self.robust_tts = None
149
+ self.clients_loaded = False
150
+
151
+ if ADVANCED_TTS_AVAILABLE:
152
+ try:
153
+ self.advanced_tts = AdvancedTTSClient()
154
+ logger.info("SUCCESS: Advanced TTS client initialized")
155
+ except Exception as e:
156
+ logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")
157
+
158
+ if ROBUST_TTS_AVAILABLE:
159
+ try:
160
+ self.robust_tts = RobustTTSClient()
161
+ logger.info("SUCCESS: Robust TTS client initialized")
162
+ except Exception as e:
163
+ logger.error(f"ERROR: Robust TTS client initialization failed: {e}")
164
+
165
+ if not self.advanced_tts and not self.robust_tts:
166
+ logger.error("ERROR: No TTS clients available!")
167
+
168
+ async def load_models(self):
169
+ """Load TTS models"""
170
+ try:
171
+ logger.info("Loading TTS models...")
172
+
173
+ # Try to load advanced TTS first
174
+ if self.advanced_tts:
175
+ try:
176
+ logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
177
+ success = await self.advanced_tts.load_models()
178
+ if success:
179
+ logger.info("SUCCESS: Advanced TTS models loaded successfully")
180
+ else:
181
+ logger.warning("WARNING: Advanced TTS models failed to load")
182
+ except Exception as e:
183
+ logger.warning(f"WARNING: Advanced TTS loading error: {e}")
184
+
185
+ # Always ensure robust TTS is available
186
+ if self.robust_tts:
187
+ try:
188
+ await self.robust_tts.load_model()
189
+ logger.info("SUCCESS: Robust TTS fallback ready")
190
+ except Exception as e:
191
+ logger.error(f"ERROR: Robust TTS loading failed: {e}")
192
+
193
+ self.clients_loaded = True
194
+ return True
195
+
196
+ except Exception as e:
197
+ logger.error(f"ERROR: TTS manager initialization failed: {e}")
198
+ return False
199
+
200
+ async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
201
+ """
202
+ Convert text to speech with fallback chain
203
+ Returns: (audio_file_path, method_used)
204
+ """
205
+ if not self.clients_loaded:
206
+ logger.info("TTS models not loaded, loading now...")
207
+ await self.load_models()
208
+
209
+ logger.info(f"Generating speech: {text[:50]}...")
210
+ logger.info(f"Voice ID: {voice_id}")
211
+
212
+ # Try Advanced TTS first (Facebook VITS / SpeechT5)
213
+ if self.advanced_tts:
214
+ try:
215
+ audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
216
+ return audio_path, "Facebook VITS/SpeechT5"
217
+ except Exception as advanced_error:
218
+ logger.warning(f"Advanced TTS failed: {advanced_error}")
219
+
220
+ # Fall back to robust TTS
221
+ if self.robust_tts:
222
+ try:
223
+ logger.info("Falling back to robust TTS...")
224
+ audio_path = await self.robust_tts.text_to_speech(text, voice_id)
225
+ return audio_path, "Robust TTS (Fallback)"
226
+ except Exception as robust_error:
227
+ logger.error(f"Robust TTS also failed: {robust_error}")
228
+
229
+ # If we get here, all methods failed
230
+ logger.error("All TTS methods failed!")
231
+ raise HTTPException(
232
+ status_code=500,
233
+ detail="All TTS methods failed. Please check system configuration."
234
+ )
235
+
236
+ async def get_available_voices(self):
237
+ """Get available voice configurations"""
238
+ try:
239
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
240
+ return await self.advanced_tts.get_available_voices()
241
+ except:
242
+ pass
243
+
244
+ # Return default voices if advanced TTS not available
245
+ return {
246
+ "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
247
+ "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
248
+ "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
249
+ "ErXwobaYiN019PkySvjV": "Male (Professional)",
250
+ "TxGEqnHWrfGW9XjX": "Male (Deep)",
251
+ "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
252
+ "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
253
+ }
254
+
255
+ def get_tts_info(self):
256
+ """Get TTS system information"""
257
+ info = {
258
+ "clients_loaded": self.clients_loaded,
259
+ "advanced_tts_available": self.advanced_tts is not None,
260
+ "robust_tts_available": self.robust_tts is not None,
261
+ "primary_method": "Robust TTS"
262
+ }
263
+
264
+ try:
265
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
266
+ advanced_info = self.advanced_tts.get_model_info()
267
+ info.update({
268
+ "advanced_tts_loaded": advanced_info.get("models_loaded", False),
269
+ "transformers_available": advanced_info.get("transformers_available", False),
270
+ "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
271
+ "device": advanced_info.get("device", "cpu"),
272
+ "vits_available": advanced_info.get("vits_available", False),
273
+ "speecht5_available": advanced_info.get("speecht5_available", False)
274
+ })
275
+ except Exception as e:
276
+ logger.debug(f"Could not get advanced TTS info: {e}")
277
+
278
+ return info
279
+
280
+ # Import the VIDEO-FOCUSED engine
281
+ try:
282
+ from omniavatar_video_engine import video_engine
283
+ VIDEO_ENGINE_AVAILABLE = True
284
+ logger.info("SUCCESS: OmniAvatar Video Engine available")
285
+ except ImportError as e:
286
+ VIDEO_ENGINE_AVAILABLE = False
287
+ logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
288
+
289
+ class OmniAvatarAPI:
290
+ def __init__(self):
291
+ self.model_loaded = False
292
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
293
+ self.tts_manager = TTSManager()
294
+ logger.info(f"Using device: {self.device}")
295
+ logger.info("Initialized with robust TTS system")
296
+
297
+ def load_model(self):
298
+ """Load the OmniAvatar model - now more flexible"""
299
+ try:
300
+ # Check if models are downloaded (but don't require them)
301
+ model_paths = [
302
+ "./pretrained_models/Wan2.1-T2V-14B",
303
+ "./pretrained_models/OmniAvatar-14B",
304
+ "./pretrained_models/wav2vec2-base-960h"
305
+ ]
306
+
307
+ missing_models = []
308
+ for path in model_paths:
309
+ if not os.path.exists(path):
310
+ missing_models.append(path)
311
+
312
+ if missing_models:
313
+ logger.warning("WARNING: Some OmniAvatar models not found:")
314
+ for model in missing_models:
315
+ logger.warning(f" - {model}")
316
+ logger.info("TIP: App will run in TTS-only mode (no video generation)")
317
+ logger.info("TIP: To enable full avatar generation, download the required models")
318
+
319
+ # Set as loaded but in limited mode
320
+ self.model_loaded = False # Video generation disabled
321
+ return True # But app can still run
322
+ else:
323
+ self.model_loaded = True
324
+ logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
325
+ return True
326
+
327
+ except Exception as e:
328
+ logger.error(f"Error checking models: {str(e)}")
329
+ logger.info("TIP: Continuing in TTS-only mode")
330
+ self.model_loaded = False
331
+ return True # Continue running
332
+
333
+ async def download_file(self, url: str, suffix: str = "") -> str:
334
+ """Download file from URL and save to temporary location"""
335
+ try:
336
+ async with aiohttp.ClientSession() as session:
337
+ async with session.get(str(url)) as response:
338
+ if response.status != 200:
339
+ raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
340
+
341
+ content = await response.read()
342
+
343
+ # Create temporary file
344
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
345
+ temp_file.write(content)
346
+ temp_file.close()
347
+
348
+ return temp_file.name
349
+
350
+ except aiohttp.ClientError as e:
351
+ logger.error(f"Network error downloading {url}: {e}")
352
+ raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
353
+ except Exception as e:
354
+ logger.error(f"Error downloading file from {url}: {e}")
355
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
356
+
357
+ def validate_audio_url(self, url: str) -> bool:
358
+ """Validate if URL is likely an audio file"""
359
+ try:
360
+ parsed = urlparse(url)
361
+ # Check for common audio file extensions
362
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
363
+ is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
364
+
365
+ return is_audio_ext or 'audio' in url.lower()
366
+ except:
367
+ return False
368
+
369
+ def validate_image_url(self, url: str) -> bool:
370
+ """Validate if URL is likely an image file"""
371
+ try:
372
+ parsed = urlparse(url)
373
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
374
+ return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
375
+ except:
376
+ return False
377
+
378
+ async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
379
+ """Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
380
+ import time
381
+ start_time = time.time()
382
+ audio_generated = False
383
+ method_used = "Unknown"
384
+
385
+ logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
386
+ logger.info(f"[INFO] Prompt: {request.prompt}")
387
+
388
+ if VIDEO_ENGINE_AVAILABLE:
389
+ try:
390
+ # PRIORITIZE VIDEO GENERATION
391
+ logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")
392
+
393
+ # Handle audio source
394
+ audio_path = None
395
+ if request.text_to_speech:
396
+ logger.info("[MIC] Generating audio from text...")
397
+ audio_path, method_used = await self.tts_manager.text_to_speech(
398
+ request.text_to_speech,
399
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
400
+ )
401
+ audio_generated = True
402
+ elif request.audio_url:
403
+ logger.info("πŸ“₯ Downloading audio from URL...")
404
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
405
+ method_used = "External Audio"
406
+ else:
407
+ raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")
408
+
409
+ # Handle image if provided
410
+ image_path = None
411
+ if request.image_url:
412
+ logger.info("[IMAGE] Downloading reference image...")
413
+ parsed = urlparse(str(request.image_url))
414
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
415
+ image_path = await self.download_file(str(request.image_url), ext)
416
+
417
+ # GENERATE VIDEO using OmniAvatar engine
418
+ logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
419
+ video_path, generation_time = video_engine.generate_avatar_video(
420
+ prompt=request.prompt,
421
+ audio_path=audio_path,
422
+ image_path=image_path,
423
+ guidance_scale=request.guidance_scale,
424
+ audio_scale=request.audio_scale,
425
+ num_steps=request.num_steps
426
+ )
427
+
428
+ processing_time = time.time() - start_time
429
+ logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
430
+
431
+ # Cleanup temporary files
432
+ if audio_path and os.path.exists(audio_path):
433
+ os.unlink(audio_path)
434
+ if image_path and os.path.exists(image_path):
435
+ os.unlink(image_path)
436
+
437
+ return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
438
+
439
+ except Exception as e:
440
+ logger.error(f"ERROR: Video generation failed: {e}")
441
+ # For a VIDEO generation app, we should NOT fall back to audio-only
442
+ # Instead, provide clear guidance
443
+ if "models" in str(e).lower():
444
+ raise HTTPException(
445
+ status_code=503,
446
+ detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
447
+ )
448
+ else:
449
+ raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
450
+
451
+ # If video engine not available, this is a critical error for a VIDEO app
452
+ raise HTTPException(
453
+ status_code=503,
454
+ detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
455
+ )
456
+
457
+ async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
458
+ """OLD TTS-ONLY METHOD - kept as backup reference.
459
+ Generate avatar video from prompt and audio/text - now handles missing models"""
460
+ import time
461
+ start_time = time.time()
462
+ audio_generated = False
463
+ tts_method = None
464
+
465
+ try:
466
+ # Check if video generation is available
467
+ if not self.model_loaded:
468
+ logger.info("πŸŽ™οΈ Running in TTS-only mode (OmniAvatar models not available)")
469
+
470
+ # Only generate audio, no video
471
+ if request.text_to_speech:
472
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
473
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
474
+ request.text_to_speech,
475
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
476
+ )
477
+
478
+ # Return the audio file as the "output"
479
+ processing_time = time.time() - start_time
480
+ logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
481
+ return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
482
+ else:
483
+ raise HTTPException(
484
+ status_code=503,
485
+ detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
486
+ )
487
+
488
+ # Original video generation logic (when models are available)
489
+ # Determine audio source
490
+ audio_path = None
491
+
492
+ if request.text_to_speech:
493
+ # Generate speech from text using TTS manager
494
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
495
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
496
+ request.text_to_speech,
497
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
498
+ )
499
+ audio_generated = True
500
+
501
+ elif request.audio_url:
502
+ # Download audio from provided URL
503
+ logger.info(f"Downloading audio from URL: {request.audio_url}")
504
+ if not self.validate_audio_url(str(request.audio_url)):
505
+ logger.warning(f"Audio URL may not be valid: {request.audio_url}")
506
+
507
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
508
+ tts_method = "External Audio URL"
509
+
510
+ else:
511
+ raise HTTPException(
512
+ status_code=400,
513
+ detail="Either text_to_speech or audio_url must be provided"
514
+ )
515
+
516
+ # Download image if provided
517
+ image_path = None
518
+ if request.image_url:
519
+ logger.info(f"Downloading image from URL: {request.image_url}")
520
+ if not self.validate_image_url(str(request.image_url)):
521
+ logger.warning(f"Image URL may not be valid: {request.image_url}")
522
+
523
+ # Determine image extension from URL or default to .jpg
524
+ parsed = urlparse(str(request.image_url))
525
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
526
+ image_path = await self.download_file(str(request.image_url), ext)
527
+
528
+ # Create temporary input file for inference
529
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
530
+ if image_path:
531
+ input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
532
+ else:
533
+ input_line = f"{request.prompt}@@@@{audio_path}"
534
+ f.write(input_line)
535
+ temp_input_file = f.name
536
+
537
+ # Prepare inference command
538
+ cmd = [
539
+ "python", "-m", "torch.distributed.run",
540
+ "--standalone", f"--nproc_per_node={request.sp_size}",
541
+ "scripts/inference.py",
542
+ "--config", "configs/inference.yaml",
543
+ "--input_file", temp_input_file,
544
+ "--guidance_scale", str(request.guidance_scale),
545
+ "--audio_scale", str(request.audio_scale),
546
+ "--num_steps", str(request.num_steps)
547
+ ]
548
+
549
+ if request.tea_cache_l1_thresh:
550
+ cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
551
+
552
+ logger.info(f"Running inference with command: {' '.join(cmd)}")
553
+
554
+ # Run inference
555
+ result = subprocess.run(cmd, capture_output=True, text=True)
556
+
557
+ # Clean up temporary files
558
+ os.unlink(temp_input_file)
559
+ os.unlink(audio_path)
560
+ if image_path:
561
+ os.unlink(image_path)
562
+
563
+ if result.returncode != 0:
564
+ logger.error(f"Inference failed: {result.stderr}")
565
+ raise Exception(f"Inference failed: {result.stderr}")
566
+
567
+ # Find output video file
568
+ output_dir = "./outputs"
569
+ if os.path.exists(output_dir):
570
+ video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
571
+ if video_files:
572
+ # Return the most recent video file
573
+ video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
574
+ output_path = os.path.join(output_dir, video_files[0])
575
+ processing_time = time.time() - start_time
576
+ return output_path, processing_time, audio_generated, tts_method
577
+
578
+ raise Exception("No output video generated")
579
+
580
+ except Exception as e:
581
+ # Clean up any temporary files in case of error
582
+ try:
583
+ if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
584
+ os.unlink(audio_path)
585
+ if 'image_path' in locals() and image_path and os.path.exists(image_path):
586
+ os.unlink(image_path)
587
+ if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
588
+ os.unlink(temp_input_file)
589
+ except:
590
+ pass
591
+
592
+ logger.error(f"Generation error: {str(e)}")
593
+ raise HTTPException(status_code=500, detail=str(e))
594
+
595
+ # Initialize API
596
+ omni_api = OmniAvatarAPI()
597
+
598
+ # Use FastAPI lifespan instead of deprecated on_event
599
+ from contextlib import asynccontextmanager
600
+
601
+ @asynccontextmanager
602
+ async def lifespan(app: FastAPI):
603
+ # Startup
604
+ success = omni_api.load_model()
605
+ if not success:
606
+ logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")
607
+
608
+ # Load TTS models
609
+ try:
610
+ await omni_api.tts_manager.load_models()
611
+ logger.info("SUCCESS: TTS models initialization completed")
612
+ except Exception as e:
613
+ logger.error(f"ERROR: TTS initialization failed: {e}")
614
+
615
+ yield
616
+
617
+ # Shutdown (if needed)
618
+ logger.info("Application shutting down...")
619
+
620
+ # Create FastAPI app WITH lifespan parameter
621
+ app = FastAPI(
622
+ title="OmniAvatar-14B API with Advanced TTS",
623
+ version="1.0.0",
624
+ lifespan=lifespan
625
+ )
626
+
627
+ # Add CORS middleware
628
+ app.add_middleware(
629
+ CORSMiddleware,
630
+ allow_origins=["*"],
631
+ allow_credentials=True,
632
+ allow_methods=["*"],
633
+ allow_headers=["*"],
634
+ )
635
+
636
+ # Mount static files for serving generated videos
637
+ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
638
+
639
+ @app.get("/health")
640
+ async def health_check():
641
+ """Health check endpoint"""
642
+ tts_info = omni_api.tts_manager.get_tts_info()
643
+
644
+ return {
645
+ "status": "healthy",
646
+ "model_loaded": omni_api.model_loaded,
647
+ "video_generation_available": omni_api.model_loaded,
648
+ "tts_only_mode": not omni_api.model_loaded,
649
+ "device": omni_api.device,
650
+ "supports_text_to_speech": True,
651
+ "supports_image_urls": omni_api.model_loaded,
652
+ "supports_audio_urls": omni_api.model_loaded,
653
+ "tts_system": "Advanced TTS with Robust Fallback",
654
+ "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
655
+ "robust_tts_available": ROBUST_TTS_AVAILABLE,
656
+ **tts_info
657
+ }
658
+
659
+ @app.get("/voices")
660
+ async def get_voices():
661
+ """Get available voice configurations"""
662
+ try:
663
+ voices = await omni_api.tts_manager.get_available_voices()
664
+ return {"voices": voices}
665
+ except Exception as e:
666
+ logger.error(f"Error getting voices: {e}")
667
+ return {"error": str(e)}
668
+
669
+ @app.post("/generate", response_model=GenerateResponse)
670
+ async def generate_avatar(request: GenerateRequest):
671
+ """Generate avatar video from prompt, text/audio, and optional image URL"""
672
+
673
+ logger.info(f"Generating avatar with prompt: {request.prompt}")
674
+ if request.text_to_speech:
675
+ logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
676
+ logger.info(f"Voice ID: {request.voice_id}")
677
+ if request.audio_url:
678
+ logger.info(f"Audio URL: {request.audio_url}")
679
+ if request.image_url:
680
+ logger.info(f"Image URL: {request.image_url}")
681
+
682
+ try:
683
+ output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
684
+
685
+ return GenerateResponse(
686
+ message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
687
+ output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
688
+ processing_time=processing_time,
689
+ audio_generated=audio_generated,
690
+ tts_method=tts_method
691
+ )
692
+
693
+ except HTTPException:
694
+ raise
695
+ except Exception as e:
696
+ logger.error(f"Unexpected error: {e}")
697
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
698
+
699
+ @app.post("/download-models")
700
+ async def download_video_models():
701
+ """Manually trigger video model downloads"""
702
+ logger.info("?? Manual model download requested...")
703
+
704
+ try:
705
+ from huggingface_hub import snapshot_download
706
+ import shutil
707
+
708
+ # Check storage first
709
+ _, _, free_bytes = shutil.disk_usage(".")
710
+ free_gb = free_bytes / (1024**3)
711
+
712
+ logger.info(f"?? Available storage: {free_gb:.1f}GB")
713
+
714
+ if free_gb < 10: # Need at least 10GB free
715
+ return {
716
+ "success": False,
717
+ "message": f"Insufficient storage: {free_gb:.1f}GB available, 10GB+ required",
718
+ "storage_gb": free_gb
719
+ }
720
+
721
+ # Download small video generation model
722
+ logger.info("?? Downloading text-to-video model...")
723
+
724
+ model_path = snapshot_download(
725
+ repo_id="ali-vilab/text-to-video-ms-1.7b",
726
+ cache_dir="./downloaded_models/video",
727
+ local_files_only=False
728
+ )
729
+
730
+ logger.info(f"? Video model downloaded: {model_path}")
731
+
732
+ # Download audio model
733
+ audio_model_path = snapshot_download(
734
+ repo_id="facebook/wav2vec2-base-960h",
735
+ cache_dir="./downloaded_models/audio",
736
+ local_files_only=False
737
+ )
738
+
739
+ logger.info(f"? Audio model downloaded: {audio_model_path}")
740
+
741
+ # Check final storage usage
742
+ _, _, free_bytes_after = shutil.disk_usage(".")
743
+ free_gb_after = free_bytes_after / (1024**3)
744
+ used_gb = free_gb - free_gb_after
745
+
746
+ return {
747
+ "success": True,
748
+ "message": "? Video generation models downloaded successfully!",
749
+ "models_downloaded": [
750
+ "ali-vilab/text-to-video-ms-1.7b",
751
+ "facebook/wav2vec2-base-960h"
752
+ ],
753
+ "storage_used_gb": round(used_gb, 2),
754
+ "storage_remaining_gb": round(free_gb_after, 2),
755
+ "video_model_path": model_path,
756
+ "audio_model_path": audio_model_path,
757
+ "status": "READY FOR VIDEO GENERATION"
758
+ }
759
+
760
+ except Exception as e:
761
+ logger.error(f"? Model download failed: {e}")
762
+ return {
763
+ "success": False,
764
+ "message": f"Model download failed: {str(e)}",
765
+ "error": str(e)
766
+ }
767
+
768
+ @app.get("/model-status")
769
+ async def get_model_status():
770
+ """Check status of downloaded models"""
771
+ try:
772
+ models_dir = Path("./downloaded_models")
773
+
774
+ status = {
775
+ "models_downloaded": models_dir.exists(),
776
+ "available_models": [],
777
+ "storage_info": {}
778
+ }
779
+
780
+ if models_dir.exists():
781
+ for model_dir in models_dir.iterdir():
782
+ if model_dir.is_dir():
783
+ status["available_models"].append({
784
+ "name": model_dir.name,
785
+ "path": str(model_dir),
786
+ "files": len(list(model_dir.rglob("*")))
787
+ })
788
+
789
+ # Storage info
790
+ import shutil
791
+ _, _, free_bytes = shutil.disk_usage(".")
792
+ status["storage_info"] = {
793
+ "free_gb": round(free_bytes / (1024**3), 2),
794
+ "models_dir_exists": models_dir.exists()
795
+ }
796
+
797
+ return status
798
+
799
+ except Exception as e:
800
+ return {"error": str(e)}
801
+
802
+
803
+ # Enhanced Gradio interface
804
+ def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
805
+ """Gradio interface wrapper with robust TTS support"""
806
+ try:
807
+ # Create request object
808
+ request_data = {
809
+ "prompt": prompt,
810
+ "guidance_scale": guidance_scale,
811
+ "audio_scale": audio_scale,
812
+ "num_steps": int(num_steps)
813
+ }
814
+
815
+ # Add audio source
816
+ if text_to_speech and text_to_speech.strip():
817
+ request_data["text_to_speech"] = text_to_speech
818
+ request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
819
+ elif audio_url and audio_url.strip():
820
+ if omni_api.model_loaded:
821
+ request_data["audio_url"] = audio_url
822
+ else:
823
+ return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
824
+ else:
825
+ return "Error: Please provide either text to speech or audio URL"
826
+
827
+ if image_url and image_url.strip():
828
+ if omni_api.model_loaded:
829
+ request_data["image_url"] = image_url
830
+ else:
831
+ return "Error: Image URL input requires full OmniAvatar models for video generation."
832
+
833
+ request = GenerateRequest(**request_data)
834
+
835
+ # Run async function in sync context
836
+ loop = asyncio.new_event_loop()
837
+ asyncio.set_event_loop(loop)
838
+ output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
839
+ loop.close()
840
+
841
+ success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
842
+ print(success_message)
843
+
844
+ if omni_api.model_loaded:
845
+ return output_path
846
+ else:
847
+ return f"πŸŽ™οΈ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
848
+
849
+ except Exception as e:
850
+ logger.error(f"Gradio generation error: {e}")
851
+ return f"Error: {str(e)}"
852
+
853
+ # Create Gradio interface
854
+ mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
855
+ description_extra = """
856
+ WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
857
+ To enable full video generation, the required model files need to be downloaded.
858
+ """ if not omni_api.model_loaded else ""
859
+
860
+ iface = gr.Interface(
861
+ fn=gradio_generate,
862
+ inputs=[
863
+ gr.Textbox(
864
+ label="Prompt",
865
+ placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
866
+ lines=2
867
+ ),
868
+ gr.Textbox(
869
+ label="Text to Speech",
870
+ placeholder="Enter text to convert to speech",
871
+ lines=3,
872
+ info="Will use best available TTS system (Advanced or Fallback)"
873
+ ),
874
+ gr.Textbox(
875
+ label="OR Audio URL",
876
+ placeholder="https://example.com/audio.mp3",
877
+ info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
878
+ ),
879
+ gr.Textbox(
880
+ label="Image URL (Optional)",
881
+ placeholder="https://example.com/image.jpg",
882
+ info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
883
+ ),
884
+ gr.Dropdown(
885
+ choices=[
886
+ "21m00Tcm4TlvDq8ikWAM",
887
+ "pNInz6obpgDQGcFmaJgB",
888
+ "EXAVITQu4vr4xnSDxMaL",
889
+ "ErXwobaYiN019PkySvjV",
890
+ "TxGEqnHWrfGW9XjX",
891
+ "yoZ06aMxZJJ28mfd3POQ",
892
+ "AZnzlk1XvdvUeBnXmlld"
893
+ ],
894
+ value="21m00Tcm4TlvDq8ikWAM",
895
+ label="Voice Profile",
896
+ info="Choose voice characteristics for TTS generation"
897
+ ),
898
+ gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
899
+ gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
900
+ gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
901
+ ],
902
+ outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
903
+ title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
904
+ description=f"""
905
+ Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
906
+
907
+ {description_extra}
908
+
909
+ **Robust TTS Architecture**
910
+ - **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
911
+ - **Fallback**: Robust tone generation for 100% reliability
912
+ - **Automatic**: Seamless switching between methods
913
+
914
+ **Features:**
915
+ - **Guaranteed Generation**: Always produces audio output
916
+ - **No Dependencies**: Works even without advanced models
917
+ - **High Availability**: Multiple fallback layers
918
+ - **Voice Profiles**: Multiple voice characteristics
919
+ - **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
920
+ - **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
921
+
922
+ **Usage:**
923
+ 1. Enter a character description in the prompt
924
+ 2. **Enter text for speech generation** (recommended in current mode)
925
+ 3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
926
+ 4. Choose voice profile and adjust parameters
927
+ 5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
928
+ """,
929
+ examples=[
930
+ [
931
+ "A professional teacher explaining a mathematical concept with clear gestures",
932
+ "Hello students! Today we're going to learn about calculus and derivatives.",
933
+ "",
934
+ "",
935
+ "21m00Tcm4TlvDq8ikWAM",
936
+ 5.0,
937
+ 3.5,
938
+ 30
939
+ ],
940
+ [
941
+ "A friendly presenter speaking confidently to an audience",
942
+ "Welcome everyone to our presentation on artificial intelligence!",
943
+ "",
944
+ "",
945
+ "pNInz6obpgDQGcFmaJgB",
946
+ 5.5,
947
+ 4.0,
948
+ 35
949
+ ]
950
+ ],
951
+ allow_flagging="never",
952
+ flagging_dir="/tmp/gradio_flagged"
953
+ )
954
+
955
+ # Mount Gradio app
956
+ app = gr.mount_gradio_app(app, iface, path="/gradio")
957
+
958
+ if __name__ == "__main__":
959
+ import uvicorn
960
+ uvicorn.run(app, host="0.0.0.0", port=7860)
961
+
962
+
963
+
964
+
965
+
966
+
967
+
968
+
969
+
970
+
971
+
972
+
973
+
974
+
975
+
976
+
977
+
978
+
979
+
980
+
981
+
982
+
983
+
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ