bravedims committed on
Commit
bcba9ba
·
1 Parent(s): da41971

🔥 CRITICAL: Fix unterminated triple-quoted string syntax error

Browse files
Files changed (3) hide show
  1. app.py.backup +482 -137
  2. app_fixed.py +828 -0
  3. app_temp.py +827 -0
app.py.backup CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import tempfile
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
 
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel, HttpUrl
8
  import subprocess
@@ -25,23 +26,50 @@ load_dotenv()
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
 
28
- app = FastAPI(title="OmniAvatar-14B API with ElevenLabs", version="1.0.0")
 
 
 
 
 
 
29
 
30
- # Add CORS middleware
31
- app.add_middleware(
32
- CORSMiddleware,
33
- allow_origins=["*"],
34
- allow_credentials=True,
35
- allow_methods=["*"],
36
- allow_headers=["*"],
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Pydantic models for request/response
40
  class GenerateRequest(BaseModel):
41
  prompt: str
42
  text_to_speech: Optional[str] = None # Text to convert to speech
43
- elevenlabs_audio_url: Optional[HttpUrl] = None # Direct audio URL
44
- voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Default ElevenLabs voice
45
  image_url: Optional[HttpUrl] = None
46
  guidance_scale: float = 5.0
47
  audio_scale: float = 3.0
@@ -54,88 +82,216 @@ class GenerateResponse(BaseModel):
54
  output_path: str
55
  processing_time: float
56
  audio_generated: bool = False
 
57
 
58
class ElevenLabsClient:
    """Small async client for the ElevenLabs text-to-speech REST endpoint.

    The API key comes from the ``api_key`` argument or, failing that, from the
    ``ELEVENLABS_API_KEY`` environment variable. No key is embedded in the
    source: a credential committed to a repository is effectively public and
    has to be treated as compromised.
    """

    def __init__(self, api_key: Optional[str] = None):
        # May legitimately end up as None; `bool(client.api_key)` tells callers
        # whether a key has been configured.
        self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
        self.base_url = "https://api.elevenlabs.io/v1"

    async def text_to_speech(self, text: str, voice_id: str = "21m00Tcm4TlvDq8ikWAM") -> str:
        """Convert text to speech using ElevenLabs and return temporary file path.

        Args:
            text: Text to synthesize.
            voice_id: ElevenLabs voice identifier used in the request URL.

        Returns:
            Path of a temporary ``.mp3`` file holding the response body. The
            file is not deleted here; the caller owns it.

        Raises:
            HTTPException: 400 when ElevenLabs answers with a non-200 status or
                the request fails at the network level; 500 when no API key is
                configured or any other error occurs.
        """
        if not self.api_key:
            raise HTTPException(
                status_code=500,
                detail="ElevenLabs API key is not configured (set ELEVENLABS_API_KEY)"
            )

        url = f"{self.base_url}/text-to-speech/{voice_id}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": self.api_key
        }

        data = {
            "text": text,
            "model_id": "eleven_monolingual_v1",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5
            }
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise HTTPException(
                            status_code=400,
                            detail=f"ElevenLabs API error: {response.status} - {error_text}"
                        )

                    audio_content = await response.read()

            # Save to temporary file; the context manager closes the handle
            # even if the write fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
                temp_file.write(audio_content)

            logger.info(f"Generated speech audio: {temp_file.name}")
            return temp_file.name

        except HTTPException:
            # Already carries the intended status code; without this clause the
            # generic handler below would turn the 400 into a 500.
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error calling ElevenLabs: {e}")
            raise HTTPException(status_code=400, detail=f"Network error calling ElevenLabs: {e}") from e
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating speech: {e}") from e
 
 
 
 
 
 
 
 
 
 
108
 
109
  class OmniAvatarAPI:
110
  def __init__(self):
111
  self.model_loaded = False
112
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
113
- self.elevenlabs_client = ElevenLabsClient()
114
  logger.info(f"Using device: {self.device}")
115
- logger.info(f"ElevenLabs API Key configured: {'Yes' if self.elevenlabs_client.api_key else 'No'}")
116
 
117
  def load_model(self):
118
- """Load the OmniAvatar model"""
119
  try:
120
- # Check if models are downloaded
121
  model_paths = [
122
  "./pretrained_models/Wan2.1-T2V-14B",
123
  "./pretrained_models/OmniAvatar-14B",
124
  "./pretrained_models/wav2vec2-base-960h"
125
  ]
126
 
 
127
  for path in model_paths:
128
  if not os.path.exists(path):
129
- logger.error(f"Model path not found: {path}")
130
- return False
131
-
132
- self.model_loaded = True
133
- logger.info("Models loaded successfully")
134
- return True
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  except Exception as e:
137
- logger.error(f"Error loading model: {str(e)}")
138
- return False
 
 
139
 
140
  async def download_file(self, url: str, suffix: str = "") -> str:
141
  """Download file from URL and save to temporary location"""
@@ -165,12 +321,11 @@ class OmniAvatarAPI:
165
  """Validate if URL is likely an audio file"""
166
  try:
167
  parsed = urlparse(url)
168
- # Check for common audio file extensions or ElevenLabs patterns
169
- audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac']
170
  is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
171
- is_elevenlabs = 'elevenlabs' in parsed.netloc.lower()
172
 
173
- return is_audio_ext or is_elevenlabs or 'audio' in url.lower()
174
  except:
175
  return False
176
 
@@ -183,37 +338,142 @@ class OmniAvatarAPI:
183
  except:
184
  return False
185
 
186
- async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool]:
187
- """Generate avatar video from prompt and audio/text"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  import time
189
  start_time = time.time()
190
  audio_generated = False
 
191
 
192
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # Determine audio source
194
  audio_path = None
195
 
196
  if request.text_to_speech:
197
- # Generate speech from text using ElevenLabs
198
  logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
199
- audio_path = await self.elevenlabs_client.text_to_speech(
200
  request.text_to_speech,
201
  request.voice_id or "21m00Tcm4TlvDq8ikWAM"
202
  )
203
  audio_generated = True
204
 
205
- elif request.elevenlabs_audio_url:
206
  # Download audio from provided URL
207
- logger.info(f"Downloading audio from URL: {request.elevenlabs_audio_url}")
208
- if not self.validate_audio_url(str(request.elevenlabs_audio_url)):
209
- logger.warning(f"Audio URL may not be valid: {request.elevenlabs_audio_url}")
210
 
211
- audio_path = await self.download_file(str(request.elevenlabs_audio_url), ".mp3")
 
212
 
213
  else:
214
  raise HTTPException(
215
  status_code=400,
216
- detail="Either text_to_speech or elevenlabs_audio_url must be provided"
217
  )
218
 
219
  # Download image if provided
@@ -276,7 +536,7 @@ class OmniAvatarAPI:
276
  video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
277
  output_path = os.path.join(output_dir, video_files[0])
278
  processing_time = time.time() - start_time
279
- return output_path, processing_time, audio_generated
280
 
281
  raise Exception("No output video generated")
282
 
@@ -298,50 +558,99 @@ class OmniAvatarAPI:
298
  # Initialize API
299
  omni_api = OmniAvatarAPI()
300
 
301
- @app.on_event("startup")
302
- async def startup_event():
303
- """Load model on startup"""
 
 
 
304
  success = omni_api.load_model()
305
  if not success:
306
- logger.warning("Model loading failed on startup")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  @app.get("/health")
309
  async def health_check():
310
  """Health check endpoint"""
 
 
311
  return {
312
  "status": "healthy",
313
  "model_loaded": omni_api.model_loaded,
 
 
314
  "device": omni_api.device,
315
- "supports_elevenlabs": True,
316
- "supports_image_urls": True,
317
  "supports_text_to_speech": True,
318
- "elevenlabs_api_configured": bool(omni_api.elevenlabs_client.api_key)
 
 
 
 
 
319
  }
320
 
 
 
 
 
 
 
 
 
 
 
321
  @app.post("/generate", response_model=GenerateResponse)
322
  async def generate_avatar(request: GenerateRequest):
323
  """Generate avatar video from prompt, text/audio, and optional image URL"""
324
 
325
- if not omni_api.model_loaded:
326
- raise HTTPException(status_code=503, detail="Model not loaded")
327
-
328
  logger.info(f"Generating avatar with prompt: {request.prompt}")
329
  if request.text_to_speech:
330
  logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
331
  logger.info(f"Voice ID: {request.voice_id}")
332
- if request.elevenlabs_audio_url:
333
- logger.info(f"Audio URL: {request.elevenlabs_audio_url}")
334
  if request.image_url:
335
  logger.info(f"Image URL: {request.image_url}")
336
 
337
  try:
338
- output_path, processing_time, audio_generated = await omni_api.generate_avatar(request)
339
 
340
  return GenerateResponse(
341
- message="Avatar generation completed successfully",
342
- output_path=output_path,
343
  processing_time=processing_time,
344
- audio_generated=audio_generated
 
345
  )
346
 
347
  except HTTPException:
@@ -350,12 +659,9 @@ async def generate_avatar(request: GenerateRequest):
350
  logger.error(f"Unexpected error: {e}")
351
  raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
352
 
353
- # Enhanced Gradio interface with text-to-speech option
354
  def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
355
- """Gradio interface wrapper with text-to-speech support"""
356
- if not omni_api.model_loaded:
357
- return "Error: Model not loaded"
358
-
359
  try:
360
  # Create request object
361
  request_data = {
@@ -370,28 +676,46 @@ def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guid
370
  request_data["text_to_speech"] = text_to_speech
371
  request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
372
  elif audio_url and audio_url.strip():
373
- request_data["elevenlabs_audio_url"] = audio_url
 
 
 
374
  else:
375
  return "Error: Please provide either text to speech or audio URL"
376
 
377
  if image_url and image_url.strip():
378
- request_data["image_url"] = image_url
 
 
 
379
 
380
  request = GenerateRequest(**request_data)
381
 
382
  # Run async function in sync context
383
  loop = asyncio.new_event_loop()
384
  asyncio.set_event_loop(loop)
385
- output_path, processing_time, audio_generated = loop.run_until_complete(omni_api.generate_avatar(request))
386
  loop.close()
387
 
388
- return output_path
 
 
 
 
 
 
389
 
390
  except Exception as e:
391
  logger.error(f"Gradio generation error: {e}")
392
  return f"Error: {str(e)}"
393
 
394
- # Updated Gradio interface with text-to-speech support
 
 
 
 
 
 
395
  iface = gr.Interface(
396
  fn=gradio_generate,
397
  inputs=[
@@ -402,60 +726,71 @@ iface = gr.Interface(
402
  ),
403
  gr.Textbox(
404
  label="Text to Speech",
405
- placeholder="Enter text to convert to speech using ElevenLabs",
406
  lines=3,
407
- info="This will be converted to speech automatically"
408
  ),
409
  gr.Textbox(
410
  label="OR Audio URL",
411
- placeholder="https://api.elevenlabs.io/v1/text-to-speech/...",
412
- info="Direct URL to audio file (alternative to text-to-speech)"
413
  ),
414
  gr.Textbox(
415
  label="Image URL (Optional)",
416
  placeholder="https://example.com/image.jpg",
417
- info="Direct URL to reference image (JPG, PNG, etc.)"
418
  ),
419
  gr.Dropdown(
420
- choices=["21m00Tcm4TlvDq8ikWAM", "pNInz6obpgDQGcFmaJgB", "EXAVITQu4vr4xnSDxMaL"],
 
 
 
 
 
 
 
 
421
  value="21m00Tcm4TlvDq8ikWAM",
422
- label="ElevenLabs Voice ID",
423
- info="Choose voice for text-to-speech"
424
  ),
425
  gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
426
  gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
427
  gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
428
  ],
429
- outputs=gr.Video(label="Generated Avatar Video"),
430
- title="🎭 OmniAvatar-14B with ElevenLabs TTS",
431
- description="""
432
- Generate avatar videos with lip-sync from text prompts and speech.
 
 
 
 
 
 
 
433
 
434
  **Features:**
435
- - **Text-to-Speech**: Enter text to generate speech automatically
436
- - **ElevenLabs Integration**: High-quality voice synthesis
437
- - **Audio URL Support**: Use pre-generated audio files
438
- - **Image URL Support**: Reference images for character appearance
439
- - **Customizable Parameters**: Fine-tune generation quality
 
440
 
441
  **Usage:**
442
  1. Enter a character description in the prompt
443
- 2. **Either** enter text for speech generation **OR** provide an audio URL
444
- 3. Optionally add a reference image URL
445
- 4. Choose voice and adjust parameters
446
- 5. Generate your avatar video!
447
-
448
- **Tips:**
449
- - Use guidance scale 4-6 for best prompt following
450
- - Increase audio scale for better lip-sync
451
- - Clear, descriptive prompts work best
452
  """,
453
  examples=[
454
  [
455
  "A professional teacher explaining a mathematical concept with clear gestures",
456
- "Hello students! Today we're going to learn about calculus and how derivatives work in real life.",
 
457
  "",
458
- "https://example.com/teacher.jpg",
459
  "21m00Tcm4TlvDq8ikWAM",
460
  5.0,
461
  3.5,
@@ -463,7 +798,7 @@ iface = gr.Interface(
463
  ],
464
  [
465
  "A friendly presenter speaking confidently to an audience",
466
- "Welcome everyone to our presentation on artificial intelligence and its applications!",
467
  "",
468
  "",
469
  "pNInz6obpgDQGcFmaJgB",
@@ -471,7 +806,9 @@ iface = gr.Interface(
471
  4.0,
472
  35
473
  ]
474
- ]
 
 
475
  )
476
 
477
  # Mount Gradio app
@@ -480,3 +817,11 @@ app = gr.mount_gradio_app(app, iface, path="/gradio")
480
  if __name__ == "__main__":
481
  import uvicorn
482
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
3
  import tempfile
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
+ from fastapi.staticfiles import StaticFiles
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel, HttpUrl
9
  import subprocess
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
# Redirect the matplotlib config dir and the Hugging Face caches to /tmp and
# turn Gradio flagging off. HF_HOME is used instead of the deprecated
# TRANSFORMERS_CACHE.
os.environ.update({
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'GRADIO_ALLOW_FLAGGING': 'never',
    'HF_HOME': '/tmp/huggingface',
    'HF_DATASETS_CACHE': '/tmp/huggingface/datasets',
    'HUGGINGFACE_HUB_CACHE': '/tmp/huggingface/hub',
})

# FastAPI app will be created after lifespan is defined

# Create the output directory and every cache directory up front so later
# writes do not fail on a missing parent.
for _required_dir in (
    "outputs",
    "/tmp/matplotlib",
    "/tmp/huggingface",
    "/tmp/huggingface/transformers",
    "/tmp/huggingface/datasets",
    "/tmp/huggingface/hub",
):
    os.makedirs(_required_dir, exist_ok=True)
48
+
49
+ # Mount static files for serving generated videos
50
+
51
+
52
def get_video_url(output_path: str, base_url: str = "https://bravedims-ai-avatar-chat.hf.space") -> str:
    """Convert a local output file path to a publicly reachable URL.

    Only the final path component is used; the URL has the form
    ``<base_url>/outputs/<filename>``.

    Args:
        output_path: Local path of the generated file.
        base_url: Origin that serves the ``/outputs`` directory. Defaults to
            the Hugging Face Space this app was written for.

    Returns:
        The URL, or ``output_path`` unchanged if building the URL fails.
    """
    try:
        from pathlib import Path
        from urllib.parse import quote

        filename = Path(output_path).name

        # Percent-encode the name so spaces, '#', '?' and non-ASCII characters
        # do not produce a broken or truncated link.
        video_url = f"{base_url.rstrip('/')}/outputs/{quote(filename)}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
66
 
67
  # Pydantic models for request/response
68
  class GenerateRequest(BaseModel):
69
  prompt: str
70
  text_to_speech: Optional[str] = None # Text to convert to speech
71
+ audio_url: Optional[HttpUrl] = None # Direct audio URL
72
+ voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Voice profile ID
73
  image_url: Optional[HttpUrl] = None
74
  guidance_scale: float = 5.0
75
  audio_scale: float = 3.0
 
82
  output_path: str
83
  processing_time: float
84
  audio_generated: bool = False
85
+ tts_method: Optional[str] = None
86
 
87
+ # Try to import TTS clients, but make them optional
88
+ try:
89
+ from advanced_tts_client import AdvancedTTSClient
90
+ ADVANCED_TTS_AVAILABLE = True
91
+ logger.info("SUCCESS: Advanced TTS client available")
92
+ except ImportError as e:
93
+ ADVANCED_TTS_AVAILABLE = False
94
+ logger.warning(f"WARNING: Advanced TTS client not available: {e}")
95
+
96
+ # Always import the robust fallback
97
+ try:
98
+ from robust_tts_client import RobustTTSClient
99
+ ROBUST_TTS_AVAILABLE = True
100
+ logger.info("SUCCESS: Robust TTS client available")
101
+ except ImportError as e:
102
+ ROBUST_TTS_AVAILABLE = False
103
+ logger.error(f"ERROR: Robust TTS client not available: {e}")
104
+
105
class TTSManager:
    """Manages multiple TTS clients with fallback chain.

    The advanced client is tried first and the robust client second. Either
    may be missing (import or construction failed), in which case the
    corresponding attribute stays ``None``.
    """

    def __init__(self):
        # Initialize TTS clients based on availability
        self.advanced_tts = None
        self.robust_tts = None
        self.clients_loaded = False

        if ADVANCED_TTS_AVAILABLE:
            try:
                self.advanced_tts = AdvancedTTSClient()
                logger.info("SUCCESS: Advanced TTS client initialized")
            except Exception as e:
                logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")

        if ROBUST_TTS_AVAILABLE:
            try:
                self.robust_tts = RobustTTSClient()
                logger.info("SUCCESS: Robust TTS client initialized")
            except Exception as e:
                logger.error(f"ERROR: Robust TTS client initialization failed: {e}")

        if not self.advanced_tts and not self.robust_tts:
            logger.error("ERROR: No TTS clients available!")

    async def load_models(self):
        """Load TTS models.

        Failures of an individual client are logged and do not stop the other
        client from loading.

        Returns:
            True once both load attempts have run (``clients_loaded`` is set
            even if a client failed), False if an unexpected error escaped.
        """
        try:
            logger.info("Loading TTS models...")

            # Try to load advanced TTS first
            if self.advanced_tts:
                try:
                    logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
                    success = await self.advanced_tts.load_models()
                    if success:
                        logger.info("SUCCESS: Advanced TTS models loaded successfully")
                    else:
                        logger.warning("WARNING: Advanced TTS models failed to load")
                except Exception as e:
                    logger.warning(f"WARNING: Advanced TTS loading error: {e}")

            # Always ensure robust TTS is available
            if self.robust_tts:
                try:
                    await self.robust_tts.load_model()
                    logger.info("SUCCESS: Robust TTS fallback ready")
                except Exception as e:
                    logger.error(f"ERROR: Robust TTS loading failed: {e}")

            self.clients_loaded = True
            return True

        except Exception as e:
            logger.error(f"ERROR: TTS manager initialization failed: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
        """Convert text to speech with fallback chain.

        Args:
            text: Text to synthesize.
            voice_id: Voice identifier forwarded to the client; may be None.

        Returns:
            ``(audio_file_path, method_used)``.

        Raises:
            HTTPException: 500 when every available client failed (or none is
                available).
        """
        if not self.clients_loaded:
            logger.info("TTS models not loaded, loading now...")
            await self.load_models()

        logger.info(f"Generating speech: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        # Try Advanced TTS first (Facebook VITS / SpeechT5)
        if self.advanced_tts:
            try:
                audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
                return audio_path, "Facebook VITS/SpeechT5"
            except Exception as advanced_error:
                logger.warning(f"Advanced TTS failed: {advanced_error}")

        # Fall back to robust TTS
        if self.robust_tts:
            try:
                logger.info("Falling back to robust TTS...")
                audio_path = await self.robust_tts.text_to_speech(text, voice_id)
                return audio_path, "Robust TTS (Fallback)"
            except Exception as robust_error:
                logger.error(f"Robust TTS also failed: {robust_error}")

        # If we get here, all methods failed
        logger.error("All TTS methods failed!")
        raise HTTPException(
            status_code=500,
            detail="All TTS methods failed. Please check system configuration."
        )

    async def get_available_voices(self):
        """Get available voice configurations.

        Returns:
            The advanced client's voice listing when that client exists and
            exposes ``get_available_voices``; otherwise (or if that call
            fails) a built-in mapping of voice id to description.
        """
        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
                return await self.advanced_tts.get_available_voices()
        except Exception as e:
            # Best-effort lookup: fall through to the defaults, but leave a
            # trace instead of discarding the error silently. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            logger.debug(f"Could not get advanced TTS voices: {e}")

        # Return default voices if advanced TTS not available
        return {
            "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
            "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
            "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
            "ErXwobaYiN019PkySvjV": "Male (Professional)",
            "TxGEqnHWrfGW9XjX": "Male (Deep)",
            "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
            "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
        }

    def get_tts_info(self):
        """Get TTS system information.

        Returns:
            A dict with ``clients_loaded``, ``advanced_tts_available``,
            ``robust_tts_available`` and ``primary_method``; extended with
            model details when the advanced client exposes ``get_model_info``.
        """
        info = {
            "clients_loaded": self.clients_loaded,
            "advanced_tts_available": self.advanced_tts is not None,
            "robust_tts_available": self.robust_tts is not None,
            "primary_method": "Robust TTS"
        }

        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
                advanced_info = self.advanced_tts.get_model_info()
                info.update({
                    "advanced_tts_loaded": advanced_info.get("models_loaded", False),
                    "transformers_available": advanced_info.get("transformers_available", False),
                    "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
                    "device": advanced_info.get("device", "cpu"),
                    "vits_available": advanced_info.get("vits_available", False),
                    "speecht5_available": advanced_info.get("speecht5_available", False)
                })
        except Exception as e:
            logger.debug(f"Could not get advanced TTS info: {e}")

        return info
242
+
243
+ # Import the VIDEO-FOCUSED engine
244
+ try:
245
+ from omniavatar_video_engine import video_engine
246
+ VIDEO_ENGINE_AVAILABLE = True
247
+ logger.info("SUCCESS: OmniAvatar Video Engine available")
248
+ except ImportError as e:
249
+ VIDEO_ENGINE_AVAILABLE = False
250
+ logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
251
 
252
  class OmniAvatarAPI:
253
  def __init__(self):
254
  self.model_loaded = False
255
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
256
+ self.tts_manager = TTSManager()
257
  logger.info(f"Using device: {self.device}")
258
+ logger.info("Initialized with robust TTS system")
259
 
260
  def load_model(self):
261
+ """Load the OmniAvatar model - now more flexible"""
262
  try:
263
+ # Check if models are downloaded (but don't require them)
264
  model_paths = [
265
  "./pretrained_models/Wan2.1-T2V-14B",
266
  "./pretrained_models/OmniAvatar-14B",
267
  "./pretrained_models/wav2vec2-base-960h"
268
  ]
269
 
270
+ missing_models = []
271
  for path in model_paths:
272
  if not os.path.exists(path):
273
+ missing_models.append(path)
 
 
 
 
 
274
 
275
+ if missing_models:
276
+ logger.warning("WARNING: Some OmniAvatar models not found:")
277
+ for model in missing_models:
278
+ logger.warning(f" - {model}")
279
+ logger.info("TIP: App will run in TTS-only mode (no video generation)")
280
+ logger.info("TIP: To enable full avatar generation, download the required models")
281
+
282
+ # Set as loaded but in limited mode
283
+ self.model_loaded = False # Video generation disabled
284
+ return True # But app can still run
285
+ else:
286
+ self.model_loaded = True
287
+ logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
288
+ return True
289
+
290
  except Exception as e:
291
+ logger.error(f"Error checking models: {str(e)}")
292
+ logger.info("TIP: Continuing in TTS-only mode")
293
+ self.model_loaded = False
294
+ return True # Continue running
295
 
296
  async def download_file(self, url: str, suffix: str = "") -> str:
297
  """Download file from URL and save to temporary location"""
 
321
  """Validate if URL is likely an audio file"""
322
  try:
323
  parsed = urlparse(url)
324
+ # Check for common audio file extensions
325
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
326
  is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
 
327
 
328
+ return is_audio_ext or 'audio' in url.lower()
329
  except:
330
  return False
331
 
 
338
  except:
339
  return False
340
 
341
async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
    """Generate avatar VIDEO - PRIMARY FUNCTIONALITY.

    Obtains audio (TTS from ``request.text_to_speech`` or a download of
    ``request.audio_url``), optionally downloads a reference image, and hands
    both to the OmniAvatar video engine.

    Returns:
        ``(video_path, processing_time_seconds, audio_generated, method)``,
        where ``audio_generated`` is True only when the audio came from TTS.

    Raises:
        HTTPException: 400 when neither text nor an audio URL is supplied;
            503 when the video engine is unavailable or the failure message
            mentions models; 500 for any other generation failure. An
            HTTPException raised by a helper keeps its own status code.
    """
    import time
    start_time = time.time()
    audio_generated = False
    method_used = "Unknown"

    logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
    logger.info(f"[INFO] Prompt: {request.prompt}")

    if not VIDEO_ENGINE_AVAILABLE:
        # If video engine not available, this is a critical error for a VIDEO app
        raise HTTPException(
            status_code=503,
            detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
        )

    # Declared before the try so the finally clause can always clean up.
    audio_path = None
    image_path = None

    try:
        # PRIORITIZE VIDEO GENERATION
        logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")

        # Handle audio source
        if request.text_to_speech:
            logger.info("[MIC] Generating audio from text...")
            audio_path, method_used = await self.tts_manager.text_to_speech(
                request.text_to_speech,
                request.voice_id or "21m00Tcm4TlvDq8ikWAM"
            )
            audio_generated = True
        elif request.audio_url:
            logger.info("📥 Downloading audio from URL...")
            audio_path = await self.download_file(str(request.audio_url), ".mp3")
            method_used = "External Audio"
        else:
            raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")

        # Handle image if provided
        if request.image_url:
            logger.info("[IMAGE] Downloading reference image...")
            parsed = urlparse(str(request.image_url))
            ext = os.path.splitext(parsed.path)[1] or ".jpg"
            image_path = await self.download_file(str(request.image_url), ext)

        # GENERATE VIDEO using OmniAvatar engine
        logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
        video_path, _engine_time = video_engine.generate_avatar_video(
            prompt=request.prompt,
            audio_path=audio_path,
            image_path=image_path,
            guidance_scale=request.guidance_scale,
            audio_scale=request.audio_scale,
            num_steps=request.num_steps
        )

        processing_time = time.time() - start_time
        logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")

        return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"

    except HTTPException:
        # Keep the status code chosen at the raise site (e.g. the 400 above);
        # the generic handler below would otherwise rewrap it as a 500.
        raise
    except Exception as e:
        logger.error(f"ERROR: Video generation failed: {e}")
        # For a VIDEO generation app, we should NOT fall back to audio-only
        # Instead, provide clear guidance
        if "models" in str(e).lower():
            raise HTTPException(
                status_code=503,
                detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
            ) from e
        raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}") from e
    finally:
        # Cleanup temporary files on success AND failure so failed requests
        # do not pile up audio/image downloads on disk.
        for temp_path in (audio_path, image_path):
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
419
+
420
+ async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
421
+ """OLD TTS-ONLY METHOD - kept as backup reference
422
+ """Generate avatar video from prompt and audio/text - now handles missing models"""
423
  import time
424
  start_time = time.time()
425
  audio_generated = False
426
+ tts_method = None
427
 
428
  try:
429
+ # Check if video generation is available
430
+ if not self.model_loaded:
431
+ logger.info("🎙️ Running in TTS-only mode (OmniAvatar models not available)")
432
+
433
+ # Only generate audio, no video
434
+ if request.text_to_speech:
435
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
436
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
437
+ request.text_to_speech,
438
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
439
+ )
440
+
441
+ # Return the audio file as the "output"
442
+ processing_time = time.time() - start_time
443
+ logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
444
+ return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
445
+ else:
446
+ raise HTTPException(
447
+ status_code=503,
448
+ detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
449
+ )
450
+
451
+ # Original video generation logic (when models are available)
452
  # Determine audio source
453
  audio_path = None
454
 
455
  if request.text_to_speech:
456
+ # Generate speech from text using TTS manager
457
  logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
458
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
459
  request.text_to_speech,
460
  request.voice_id or "21m00Tcm4TlvDq8ikWAM"
461
  )
462
  audio_generated = True
463
 
464
+ elif request.audio_url:
465
  # Download audio from provided URL
466
+ logger.info(f"Downloading audio from URL: {request.audio_url}")
467
+ if not self.validate_audio_url(str(request.audio_url)):
468
+ logger.warning(f"Audio URL may not be valid: {request.audio_url}")
469
 
470
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
471
+ tts_method = "External Audio URL"
472
 
473
  else:
474
  raise HTTPException(
475
  status_code=400,
476
+ detail="Either text_to_speech or audio_url must be provided"
477
  )
478
 
479
  # Download image if provided
 
536
  video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
537
  output_path = os.path.join(output_dir, video_files[0])
538
  processing_time = time.time() - start_time
539
+ return output_path, processing_time, audio_generated, tts_method
540
 
541
  raise Exception("No output video generated")
542
 
 
558
  # Initialize API
559
  omni_api = OmniAvatarAPI()
560
 
561
+ # Use FastAPI lifespan instead of deprecated on_event
562
+ from contextlib import asynccontextmanager
563
+
564
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup checks the OmniAvatar models and then loads the TTS models; a
    failure in either step is logged and does not stop the app from starting.
    """
    # --- startup ---
    if not omni_api.load_model():
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    try:
        await omni_api.tts_manager.load_models()
        logger.info("SUCCESS: TTS models initialization completed")
    except Exception as tts_error:
        logger.error(f"ERROR: TTS initialization failed: {tts_error}")

    # Control returns to the framework here for the lifetime of the app.
    yield

    # --- shutdown ---
    logger.info("Application shutting down...")
582
+
583
+ # Create FastAPI app WITH lifespan parameter
584
app = FastAPI(title="OmniAvatar-14B API with Advanced TTS", version="1.0.0", lifespan=lifespan)

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Serve files from the local "outputs" directory under /outputs so generated
# videos can be fetched over HTTP.
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
601
 
602
  @app.get("/health")
603
  async def health_check():
604
  """Health check endpoint"""
605
+ tts_info = omni_api.tts_manager.get_tts_info()
606
+
607
  return {
608
  "status": "healthy",
609
  "model_loaded": omni_api.model_loaded,
610
+ "video_generation_available": omni_api.model_loaded,
611
+ "tts_only_mode": not omni_api.model_loaded,
612
  "device": omni_api.device,
 
 
613
  "supports_text_to_speech": True,
614
+ "supports_image_urls": omni_api.model_loaded,
615
+ "supports_audio_urls": omni_api.model_loaded,
616
+ "tts_system": "Advanced TTS with Robust Fallback",
617
+ "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
618
+ "robust_tts_available": ROBUST_TTS_AVAILABLE,
619
+ **tts_info
620
  }
621
 
622
+ @app.get("/voices")
623
+ async def get_voices():
624
+ """Get available voice configurations"""
625
+ try:
626
+ voices = await omni_api.tts_manager.get_available_voices()
627
+ return {"voices": voices}
628
+ except Exception as e:
629
+ logger.error(f"Error getting voices: {e}")
630
+ return {"error": str(e)}
631
+
632
  @app.post("/generate", response_model=GenerateResponse)
633
  async def generate_avatar(request: GenerateRequest):
634
  """Generate avatar video from prompt, text/audio, and optional image URL"""
635
 
 
 
 
636
  logger.info(f"Generating avatar with prompt: {request.prompt}")
637
  if request.text_to_speech:
638
  logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
639
  logger.info(f"Voice ID: {request.voice_id}")
640
+ if request.audio_url:
641
+ logger.info(f"Audio URL: {request.audio_url}")
642
  if request.image_url:
643
  logger.info(f"Image URL: {request.image_url}")
644
 
645
  try:
646
+ output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
647
 
648
  return GenerateResponse(
649
+ message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
650
+ output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
651
  processing_time=processing_time,
652
+ audio_generated=audio_generated,
653
+ tts_method=tts_method
654
  )
655
 
656
  except HTTPException:
 
659
  logger.error(f"Unexpected error: {e}")
660
  raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
661
 
662
+ # Enhanced Gradio interface
663
  def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
664
+ """Gradio interface wrapper with robust TTS support"""
 
 
 
665
  try:
666
  # Create request object
667
  request_data = {
 
676
  request_data["text_to_speech"] = text_to_speech
677
  request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
678
  elif audio_url and audio_url.strip():
679
+ if omni_api.model_loaded:
680
+ request_data["audio_url"] = audio_url
681
+ else:
682
+ return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
683
  else:
684
  return "Error: Please provide either text to speech or audio URL"
685
 
686
  if image_url and image_url.strip():
687
+ if omni_api.model_loaded:
688
+ request_data["image_url"] = image_url
689
+ else:
690
+ return "Error: Image URL input requires full OmniAvatar models for video generation."
691
 
692
  request = GenerateRequest(**request_data)
693
 
694
  # Run async function in sync context
695
  loop = asyncio.new_event_loop()
696
  asyncio.set_event_loop(loop)
697
+ output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
698
  loop.close()
699
 
700
+ success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
701
+ print(success_message)
702
+
703
+ if omni_api.model_loaded:
704
+ return output_path
705
+ else:
706
+ return f"🎙️ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
707
 
708
  except Exception as e:
709
  logger.error(f"Gradio generation error: {e}")
710
  return f"Error: {str(e)}"
711
 
712
+ # Create Gradio interface
713
+ mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
714
+ description_extra = """
715
+ WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
716
+ To enable full video generation, the required model files need to be downloaded.
717
+ """ if not omni_api.model_loaded else ""
718
+
719
  iface = gr.Interface(
720
  fn=gradio_generate,
721
  inputs=[
 
726
  ),
727
  gr.Textbox(
728
  label="Text to Speech",
729
+ placeholder="Enter text to convert to speech",
730
  lines=3,
731
+ info="Will use best available TTS system (Advanced or Fallback)"
732
  ),
733
  gr.Textbox(
734
  label="OR Audio URL",
735
+ placeholder="https://example.com/audio.mp3",
736
+ info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
737
  ),
738
  gr.Textbox(
739
  label="Image URL (Optional)",
740
  placeholder="https://example.com/image.jpg",
741
+ info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
742
  ),
743
  gr.Dropdown(
744
+ choices=[
745
+ "21m00Tcm4TlvDq8ikWAM",
746
+ "pNInz6obpgDQGcFmaJgB",
747
+ "EXAVITQu4vr4xnSDxMaL",
748
+ "ErXwobaYiN019PkySvjV",
749
+ "TxGEqnHWrfGW9XjX",
750
+ "yoZ06aMxZJJ28mfd3POQ",
751
+ "AZnzlk1XvdvUeBnXmlld"
752
+ ],
753
  value="21m00Tcm4TlvDq8ikWAM",
754
+ label="Voice Profile",
755
+ info="Choose voice characteristics for TTS generation"
756
  ),
757
  gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
758
  gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
759
  gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
760
  ],
761
+ outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
762
+ title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
763
+ description=f"""
764
+ Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
765
+
766
+ {description_extra}
767
+
768
+ **Robust TTS Architecture**
769
+ - **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
770
+ - **Fallback**: Robust tone generation for 100% reliability
771
+ - **Automatic**: Seamless switching between methods
772
 
773
  **Features:**
774
+ - **Guaranteed Generation**: Always produces audio output
775
+ - **No Dependencies**: Works even without advanced models
776
+ - **High Availability**: Multiple fallback layers
777
+ - **Voice Profiles**: Multiple voice characteristics
778
+ - **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
779
+ - **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
780
 
781
  **Usage:**
782
  1. Enter a character description in the prompt
783
+ 2. **Enter text for speech generation** (recommended in current mode)
784
+ 3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
785
+ 4. Choose voice profile and adjust parameters
786
+ 5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
 
 
 
 
 
787
  """,
788
  examples=[
789
  [
790
  "A professional teacher explaining a mathematical concept with clear gestures",
791
+ "Hello students! Today we're going to learn about calculus and derivatives.",
792
+ "",
793
  "",
 
794
  "21m00Tcm4TlvDq8ikWAM",
795
  5.0,
796
  3.5,
 
798
  ],
799
  [
800
  "A friendly presenter speaking confidently to an audience",
801
+ "Welcome everyone to our presentation on artificial intelligence!",
802
  "",
803
  "",
804
  "pNInz6obpgDQGcFmaJgB",
 
806
  4.0,
807
  35
808
  ]
809
+ ],
810
+ allow_flagging="never",
811
+ flagging_dir="/tmp/gradio_flagged"
812
  )
813
 
814
  # Mount Gradio app
 
817
  if __name__ == "__main__":
818
  import uvicorn
819
  uvicorn.run(app, host="0.0.0.0", port=7860)
820
+
821
+
822
+
823
+
824
+
825
+
826
+
827
+
app_fixed.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import tempfile
4
+ import gradio as gr
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, HttpUrl
9
+ import subprocess
10
+ import json
11
+ from pathlib import Path
12
+ import logging
13
+ import requests
14
+ from urllib.parse import urlparse
15
+ from PIL import Image
16
+ import io
17
+ from typing import Optional
18
+ import aiohttp
19
+ import asyncio
20
+ from dotenv import load_dotenv
21
+
22
+ # Load environment variables
23
+ load_dotenv()
24
+
25
+ # Set up logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Set environment variables for matplotlib, gradio, and huggingface cache
30
+ os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
31
+ os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
32
+ os.environ['HF_HOME'] = '/tmp/huggingface'
33
+ # Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
34
+ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
35
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
36
+
37
+ # FastAPI app will be created after lifespan is defined
38
+
39
+
40
+
41
+ # Create directories with proper permissions
42
+ os.makedirs("outputs", exist_ok=True)
43
+ os.makedirs("/tmp/matplotlib", exist_ok=True)
44
+ os.makedirs("/tmp/huggingface", exist_ok=True)
45
+ os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
46
+ os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
47
+ os.makedirs("/tmp/huggingface/hub", exist_ok=True)
48
+
49
+ # NOTE: the /outputs static-file mount is registered further down, once the FastAPI app has been created
50
+
51
+
52
def get_video_url(output_path: str, base_url: str = "https://bravedims-ai-avatar-chat.hf.space") -> str:
    """Convert a local output file path into a publicly reachable URL.

    Only the file name of ``output_path`` is used; it is appended to
    ``<base_url>/outputs/``, which is the route the app mounts its static
    ``outputs`` directory on.

    Args:
        output_path: Local path of the generated file.
        base_url: Public origin of the deployment. Defaults to the
            HuggingFace Space this app was written for, so existing callers
            are unaffected; other deployments can pass their own origin.

    Returns:
        The public URL, or ``output_path`` unchanged if the URL could not be
        built (best-effort fallback).
    """
    from pathlib import Path

    try:
        filename = Path(output_path).name
        video_url = f"{base_url.rstrip('/')}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
66
+
67
+ # Pydantic models for request/response
68
# Request body for POST /generate.
# Exactly one audio source is expected: `text_to_speech` is checked first,
# then `audio_url`; if neither is set the generation code rejects the request.
class GenerateRequest(BaseModel):
    prompt: str  # Character / scene description passed to the video engine
    text_to_speech: Optional[str] = None  # Text to convert to speech
    audio_url: Optional[HttpUrl] = None  # Direct audio URL
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"  # Voice profile ID
    image_url: Optional[HttpUrl] = None  # Optional reference image to download
    guidance_scale: float = 5.0  # Forwarded unchanged to the generation backend
    audio_scale: float = 3.0  # Forwarded unchanged to the generation backend
    num_steps: int = 30  # Forwarded unchanged to the generation backend
    sp_size: int = 1  # Used as --nproc_per_node by the subprocess-based path
    tea_cache_l1_thresh: Optional[float] = None  # Only forwarded when truthy
79
+
80
# Response body for POST /generate.
class GenerateResponse(BaseModel):
    message: str  # Human-readable status line
    output_path: str  # URL or local path of the generated file
    processing_time: float  # Wall-clock seconds spent on the request
    audio_generated: bool = False  # True when audio was synthesised from text
    tts_method: Optional[str] = None  # Label of the TTS / audio source used
86
+
87
+ # Try to import TTS clients, but make them optional
88
+ try:
89
+ from advanced_tts_client import AdvancedTTSClient
90
+ ADVANCED_TTS_AVAILABLE = True
91
+ logger.info("SUCCESS: Advanced TTS client available")
92
+ except ImportError as e:
93
+ ADVANCED_TTS_AVAILABLE = False
94
+ logger.warning(f"WARNING: Advanced TTS client not available: {e}")
95
+
96
+ # Always import the robust fallback
97
+ try:
98
+ from robust_tts_client import RobustTTSClient
99
+ ROBUST_TTS_AVAILABLE = True
100
+ logger.info("SUCCESS: Robust TTS client available")
101
+ except ImportError as e:
102
+ ROBUST_TTS_AVAILABLE = False
103
+ logger.error(f"ERROR: Robust TTS client not available: {e}")
104
+
105
+ class TTSManager:
106
+ """Manages multiple TTS clients with fallback chain"""
107
+
108
+ def __init__(self):
109
+ # Initialize TTS clients based on availability
110
+ self.advanced_tts = None
111
+ self.robust_tts = None
112
+ self.clients_loaded = False
113
+
114
+ if ADVANCED_TTS_AVAILABLE:
115
+ try:
116
+ self.advanced_tts = AdvancedTTSClient()
117
+ logger.info("SUCCESS: Advanced TTS client initialized")
118
+ except Exception as e:
119
+ logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")
120
+
121
+ if ROBUST_TTS_AVAILABLE:
122
+ try:
123
+ self.robust_tts = RobustTTSClient()
124
+ logger.info("SUCCESS: Robust TTS client initialized")
125
+ except Exception as e:
126
+ logger.error(f"ERROR: Robust TTS client initialization failed: {e}")
127
+
128
+ if not self.advanced_tts and not self.robust_tts:
129
+ logger.error("ERROR: No TTS clients available!")
130
+
131
+ async def load_models(self):
132
+ """Load TTS models"""
133
+ try:
134
+ logger.info("Loading TTS models...")
135
+
136
+ # Try to load advanced TTS first
137
+ if self.advanced_tts:
138
+ try:
139
+ logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
140
+ success = await self.advanced_tts.load_models()
141
+ if success:
142
+ logger.info("SUCCESS: Advanced TTS models loaded successfully")
143
+ else:
144
+ logger.warning("WARNING: Advanced TTS models failed to load")
145
+ except Exception as e:
146
+ logger.warning(f"WARNING: Advanced TTS loading error: {e}")
147
+
148
+ # Always ensure robust TTS is available
149
+ if self.robust_tts:
150
+ try:
151
+ await self.robust_tts.load_model()
152
+ logger.info("SUCCESS: Robust TTS fallback ready")
153
+ except Exception as e:
154
+ logger.error(f"ERROR: Robust TTS loading failed: {e}")
155
+
156
+ self.clients_loaded = True
157
+ return True
158
+
159
+ except Exception as e:
160
+ logger.error(f"ERROR: TTS manager initialization failed: {e}")
161
+ return False
162
+
163
+ async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
164
+ """
165
+ Convert text to speech with fallback chain
166
+ Returns: (audio_file_path, method_used)
167
+ """
168
+ if not self.clients_loaded:
169
+ logger.info("TTS models not loaded, loading now...")
170
+ await self.load_models()
171
+
172
+ logger.info(f"Generating speech: {text[:50]}...")
173
+ logger.info(f"Voice ID: {voice_id}")
174
+
175
+ # Try Advanced TTS first (Facebook VITS / SpeechT5)
176
+ if self.advanced_tts:
177
+ try:
178
+ audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
179
+ return audio_path, "Facebook VITS/SpeechT5"
180
+ except Exception as advanced_error:
181
+ logger.warning(f"Advanced TTS failed: {advanced_error}")
182
+
183
+ # Fall back to robust TTS
184
+ if self.robust_tts:
185
+ try:
186
+ logger.info("Falling back to robust TTS...")
187
+ audio_path = await self.robust_tts.text_to_speech(text, voice_id)
188
+ return audio_path, "Robust TTS (Fallback)"
189
+ except Exception as robust_error:
190
+ logger.error(f"Robust TTS also failed: {robust_error}")
191
+
192
+ # If we get here, all methods failed
193
+ logger.error("All TTS methods failed!")
194
+ raise HTTPException(
195
+ status_code=500,
196
+ detail="All TTS methods failed. Please check system configuration."
197
+ )
198
+
199
+ async def get_available_voices(self):
200
+ """Get available voice configurations"""
201
+ try:
202
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
203
+ return await self.advanced_tts.get_available_voices()
204
+ except:
205
+ pass
206
+
207
+ # Return default voices if advanced TTS not available
208
+ return {
209
+ "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
210
+ "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
211
+ "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
212
+ "ErXwobaYiN019PkySvjV": "Male (Professional)",
213
+ "TxGEqnHWrfGW9XjX": "Male (Deep)",
214
+ "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
215
+ "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
216
+ }
217
+
218
+ def get_tts_info(self):
219
+ """Get TTS system information"""
220
+ info = {
221
+ "clients_loaded": self.clients_loaded,
222
+ "advanced_tts_available": self.advanced_tts is not None,
223
+ "robust_tts_available": self.robust_tts is not None,
224
+ "primary_method": "Robust TTS"
225
+ }
226
+
227
+ try:
228
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
229
+ advanced_info = self.advanced_tts.get_model_info()
230
+ info.update({
231
+ "advanced_tts_loaded": advanced_info.get("models_loaded", False),
232
+ "transformers_available": advanced_info.get("transformers_available", False),
233
+ "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
234
+ "device": advanced_info.get("device", "cpu"),
235
+ "vits_available": advanced_info.get("vits_available", False),
236
+ "speecht5_available": advanced_info.get("speecht5_available", False)
237
+ })
238
+ except Exception as e:
239
+ logger.debug(f"Could not get advanced TTS info: {e}")
240
+
241
+ return info
242
+
243
+ # Import the VIDEO-FOCUSED engine
244
+ try:
245
+ from omniavatar_video_engine import video_engine
246
+ VIDEO_ENGINE_AVAILABLE = True
247
+ logger.info("SUCCESS: OmniAvatar Video Engine available")
248
+ except ImportError as e:
249
+ VIDEO_ENGINE_AVAILABLE = False
250
+ logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
251
+
252
+ class OmniAvatarAPI:
253
+ def __init__(self):
254
+ self.model_loaded = False
255
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
256
+ self.tts_manager = TTSManager()
257
+ logger.info(f"Using device: {self.device}")
258
+ logger.info("Initialized with robust TTS system")
259
+
260
+ def load_model(self):
261
+ """Load the OmniAvatar model - now more flexible"""
262
+ try:
263
+ # Check if models are downloaded (but don't require them)
264
+ model_paths = [
265
+ "./pretrained_models/Wan2.1-T2V-14B",
266
+ "./pretrained_models/OmniAvatar-14B",
267
+ "./pretrained_models/wav2vec2-base-960h"
268
+ ]
269
+
270
+ missing_models = []
271
+ for path in model_paths:
272
+ if not os.path.exists(path):
273
+ missing_models.append(path)
274
+
275
+ if missing_models:
276
+ logger.warning("WARNING: Some OmniAvatar models not found:")
277
+ for model in missing_models:
278
+ logger.warning(f" - {model}")
279
+ logger.info("TIP: App will run in TTS-only mode (no video generation)")
280
+ logger.info("TIP: To enable full avatar generation, download the required models")
281
+
282
+ # Set as loaded but in limited mode
283
+ self.model_loaded = False # Video generation disabled
284
+ return True # But app can still run
285
+ else:
286
+ self.model_loaded = True
287
+ logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
288
+ return True
289
+
290
+ except Exception as e:
291
+ logger.error(f"Error checking models: {str(e)}")
292
+ logger.info("TIP: Continuing in TTS-only mode")
293
+ self.model_loaded = False
294
+ return True # Continue running
295
+
296
+ async def download_file(self, url: str, suffix: str = "") -> str:
297
+ """Download file from URL and save to temporary location"""
298
+ try:
299
+ async with aiohttp.ClientSession() as session:
300
+ async with session.get(str(url)) as response:
301
+ if response.status != 200:
302
+ raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
303
+
304
+ content = await response.read()
305
+
306
+ # Create temporary file
307
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
308
+ temp_file.write(content)
309
+ temp_file.close()
310
+
311
+ return temp_file.name
312
+
313
+ except aiohttp.ClientError as e:
314
+ logger.error(f"Network error downloading {url}: {e}")
315
+ raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
316
+ except Exception as e:
317
+ logger.error(f"Error downloading file from {url}: {e}")
318
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
319
+
320
+ def validate_audio_url(self, url: str) -> bool:
321
+ """Validate if URL is likely an audio file"""
322
+ try:
323
+ parsed = urlparse(url)
324
+ # Check for common audio file extensions
325
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
326
+ is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
327
+
328
+ return is_audio_ext or 'audio' in url.lower()
329
+ except:
330
+ return False
331
+
332
+ def validate_image_url(self, url: str) -> bool:
333
+ """Validate if URL is likely an image file"""
334
+ try:
335
+ parsed = urlparse(url)
336
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
337
+ return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
338
+ except:
339
+ return False
340
+
341
+ async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
342
+ """Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
343
+ import time
344
+ start_time = time.time()
345
+ audio_generated = False
346
+ method_used = "Unknown"
347
+
348
+ logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
349
+ logger.info(f"[INFO] Prompt: {request.prompt}")
350
+
351
+ if VIDEO_ENGINE_AVAILABLE:
352
+ try:
353
+ # PRIORITIZE VIDEO GENERATION
354
+ logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")
355
+
356
+ # Handle audio source
357
+ audio_path = None
358
+ if request.text_to_speech:
359
+ logger.info("[MIC] Generating audio from text...")
360
+ audio_path, method_used = await self.tts_manager.text_to_speech(
361
+ request.text_to_speech,
362
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
363
+ )
364
+ audio_generated = True
365
+ elif request.audio_url:
366
+ logger.info("?? Downloading audio from URL...")
367
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
368
+ method_used = "External Audio"
369
+ else:
370
+ raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")
371
+
372
+ # Handle image if provided
373
+ image_path = None
374
+ if request.image_url:
375
+ logger.info("[IMAGE] Downloading reference image...")
376
+ parsed = urlparse(str(request.image_url))
377
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
378
+ image_path = await self.download_file(str(request.image_url), ext)
379
+
380
+ # GENERATE VIDEO using OmniAvatar engine
381
+ logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
382
+ video_path, generation_time = video_engine.generate_avatar_video(
383
+ prompt=request.prompt,
384
+ audio_path=audio_path,
385
+ image_path=image_path,
386
+ guidance_scale=request.guidance_scale,
387
+ audio_scale=request.audio_scale,
388
+ num_steps=request.num_steps
389
+ )
390
+
391
+ processing_time = time.time() - start_time
392
+ logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
393
+
394
+ # Cleanup temporary files
395
+ if audio_path and os.path.exists(audio_path):
396
+ os.unlink(audio_path)
397
+ if image_path and os.path.exists(image_path):
398
+ os.unlink(image_path)
399
+
400
+ return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
401
+
402
+ except Exception as e:
403
+ logger.error(f"ERROR: Video generation failed: {e}")
404
+ # For a VIDEO generation app, we should NOT fall back to audio-only
405
+ # Instead, provide clear guidance
406
+ if "models" in str(e).lower():
407
+ raise HTTPException(
408
+ status_code=503,
409
+ detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
410
+ )
411
+ else:
412
+ raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
413
+
414
+ # If video engine not available, this is a critical error for a VIDEO app
415
+ raise HTTPException(
416
+ status_code=503,
417
+ detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
418
+ )
419
+
420
+ async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
421
+ """OLD TTS-ONLY METHOD - kept as backup reference
422
+ """Generate avatar video from prompt and audio/text - now handles missing models"""
423
+ import time
424
+ start_time = time.time()
425
+ audio_generated = False
426
+ tts_method = None
427
+
428
+ try:
429
+ # Check if video generation is available
430
+ if not self.model_loaded:
431
+ logger.info("??? Running in TTS-only mode (OmniAvatar models not available)")
432
+
433
+ # Only generate audio, no video
434
+ if request.text_to_speech:
435
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
436
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
437
+ request.text_to_speech,
438
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
439
+ )
440
+
441
+ # Return the audio file as the "output"
442
+ processing_time = time.time() - start_time
443
+ logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
444
+ return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
445
+ else:
446
+ raise HTTPException(
447
+ status_code=503,
448
+ detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
449
+ )
450
+
451
+ # Original video generation logic (when models are available)
452
+ # Determine audio source
453
+ audio_path = None
454
+
455
+ if request.text_to_speech:
456
+ # Generate speech from text using TTS manager
457
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
458
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
459
+ request.text_to_speech,
460
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
461
+ )
462
+ audio_generated = True
463
+
464
+ elif request.audio_url:
465
+ # Download audio from provided URL
466
+ logger.info(f"Downloading audio from URL: {request.audio_url}")
467
+ if not self.validate_audio_url(str(request.audio_url)):
468
+ logger.warning(f"Audio URL may not be valid: {request.audio_url}")
469
+
470
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
471
+ tts_method = "External Audio URL"
472
+
473
+ else:
474
+ raise HTTPException(
475
+ status_code=400,
476
+ detail="Either text_to_speech or audio_url must be provided"
477
+ )
478
+
479
+ # Download image if provided
480
+ image_path = None
481
+ if request.image_url:
482
+ logger.info(f"Downloading image from URL: {request.image_url}")
483
+ if not self.validate_image_url(str(request.image_url)):
484
+ logger.warning(f"Image URL may not be valid: {request.image_url}")
485
+
486
+ # Determine image extension from URL or default to .jpg
487
+ parsed = urlparse(str(request.image_url))
488
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
489
+ image_path = await self.download_file(str(request.image_url), ext)
490
+
491
+ # Create temporary input file for inference
492
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
493
+ if image_path:
494
+ input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
495
+ else:
496
+ input_line = f"{request.prompt}@@@@{audio_path}"
497
+ f.write(input_line)
498
+ temp_input_file = f.name
499
+
500
+ # Prepare inference command
501
+ cmd = [
502
+ "python", "-m", "torch.distributed.run",
503
+ "--standalone", f"--nproc_per_node={request.sp_size}",
504
+ "scripts/inference.py",
505
+ "--config", "configs/inference.yaml",
506
+ "--input_file", temp_input_file,
507
+ "--guidance_scale", str(request.guidance_scale),
508
+ "--audio_scale", str(request.audio_scale),
509
+ "--num_steps", str(request.num_steps)
510
+ ]
511
+
512
+ if request.tea_cache_l1_thresh:
513
+ cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
514
+
515
+ logger.info(f"Running inference with command: {' '.join(cmd)}")
516
+
517
+ # Run inference
518
+ result = subprocess.run(cmd, capture_output=True, text=True)
519
+
520
+ # Clean up temporary files
521
+ os.unlink(temp_input_file)
522
+ os.unlink(audio_path)
523
+ if image_path:
524
+ os.unlink(image_path)
525
+
526
+ if result.returncode != 0:
527
+ logger.error(f"Inference failed: {result.stderr}")
528
+ raise Exception(f"Inference failed: {result.stderr}")
529
+
530
+ # Find output video file
531
+ output_dir = "./outputs"
532
+ if os.path.exists(output_dir):
533
+ video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
534
+ if video_files:
535
+ # Return the most recent video file
536
+ video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
537
+ output_path = os.path.join(output_dir, video_files[0])
538
+ processing_time = time.time() - start_time
539
+ return output_path, processing_time, audio_generated, tts_method
540
+
541
+ raise Exception("No output video generated")
542
+
543
+ except Exception as e:
544
+ # Clean up any temporary files in case of error
545
+ try:
546
+ if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
547
+ os.unlink(audio_path)
548
+ if 'image_path' in locals() and image_path and os.path.exists(image_path):
549
+ os.unlink(image_path)
550
+ if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
551
+ os.unlink(temp_input_file)
552
+ except:
553
+ pass
554
+
555
+ logger.error(f"Generation error: {str(e)}")
556
+ raise HTTPException(status_code=500, detail=str(e))
557
+
558
+ # Initialize API
559
+ omni_api = OmniAvatarAPI()
560
+
561
+ # Use FastAPI lifespan instead of deprecated on_event
562
+ from contextlib import asynccontextmanager
563
+
564
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: model check and TTS loading before serving, a log line on shutdown."""
    # --- startup ---
    if not omni_api.load_model():
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    # A TTS loading failure is logged, never fatal.
    try:
        await omni_api.tts_manager.load_models()
        logger.info("SUCCESS: TTS models initialization completed")
    except Exception as tts_error:
        logger.error(f"ERROR: TTS initialization failed: {tts_error}")

    yield

    # --- shutdown ---
    logger.info("Application shutting down...")
582
+
583
+ # Create FastAPI app WITH lifespan parameter
584
+ app = FastAPI(
585
+ title="OmniAvatar-14B API with Advanced TTS",
586
+ version="1.0.0",
587
+ lifespan=lifespan
588
+ )
589
+
590
+ # Add CORS middleware
591
+ app.add_middleware(
592
+ CORSMiddleware,
593
+ allow_origins=["*"],
594
+ allow_credentials=True,
595
+ allow_methods=["*"],
596
+ allow_headers=["*"],
597
+ )
598
+
599
+ # Mount static files for serving generated videos
600
+ app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
601
+
602
@app.get("/health")
async def health_check():
    """Report service status, available features and TTS details.

    The entries returned by ``tts_manager.get_tts_info()`` are merged in last,
    so they override same-named keys set here (for example ``device``).
    """
    tts_info = omni_api.tts_manager.get_tts_info()
    video_ready = omni_api.model_loaded

    payload = {
        "status": "healthy",
        "model_loaded": video_ready,
        "video_generation_available": video_ready,
        "tts_only_mode": not video_ready,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": video_ready,
        "supports_audio_urls": video_ready,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
    }
    payload.update(tts_info)
    return payload
621
+
622
@app.get("/voices")
async def get_voices():
    """Return the voice profiles offered by the TTS manager.

    On failure the error is logged and returned in an ``{"error": ...}`` body
    instead of being raised.
    """
    try:
        return {"voices": await omni_api.tts_manager.get_available_voices()}
    except Exception as voices_error:
        logger.error(f"Error getting voices: {voices_error}")
        return {"error": str(voices_error)}
631
+
632
@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL.

    Logs the request inputs, delegates to ``omni_api.generate_avatar`` and
    wraps the result in a ``GenerateResponse``. ``HTTPException`` raised by
    the generator is passed through unchanged; any other error becomes a 500.
    """
    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        # Only the first 100 characters are logged to keep log lines short.
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.audio_url:
        logger.info(f"Audio URL: {request.audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        result = await omni_api.generate_avatar(request)
        output_path, processing_time, audio_generated, tts_method = result

        mode_suffix = "" if omni_api.model_loaded else " (TTS-only mode)"
        # A public URL is only built when the full models are loaded;
        # otherwise the raw output path is returned.
        if omni_api.model_loaded:
            reported_path = get_video_url(output_path)
        else:
            reported_path = output_path

        return GenerateResponse(
            message="Generation completed successfully" + mode_suffix,
            output_path=reported_path,
            processing_time=processing_time,
            audio_generated=audio_generated,
            tts_method=tts_method,
        )
    except HTTPException:
        raise
    except Exception as unexpected:
        logger.error(f"Unexpected error: {unexpected}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {unexpected}")
661
+
662
+ # Enhanced Gradio interface
663
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with robust TTS support.

    Builds a ``GenerateRequest`` from the form values and runs the async
    ``omni_api.generate_avatar`` to completion on a private event loop.

    Returns:
        The output path when the full models are loaded, otherwise a status
        string describing the generated audio. Every failure is reported as
        an ``"Error: ..."`` string; nothing is raised to Gradio.
    """
    try:
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Exactly one audio source is used; text-to-speech takes precedence.
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            if not omni_api.model_loaded:
                return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
            request_data["audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or audio URL"

        if image_url and image_url.strip():
            if not omni_api.model_loaded:
                return "Error: Image URL input requires full OmniAvatar models for video generation."
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)

        # Run the coroutine on a dedicated loop. The finally block guarantees
        # the loop is closed and unregistered even when generation raises, so
        # a closed loop is never left behind as this thread's current loop.
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(
                omni_api.generate_avatar(request)
            )
        finally:
            asyncio.set_event_loop(None)
            loop.close()

        success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
        print(success_message)

        if omni_api.model_loaded:
            return output_path
        return f"??? TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"

    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"
711
+
712
+ # Create Gradio interface
713
+ mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
714
+ description_extra = """
715
+ WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
716
+ To enable full video generation, the required model files need to be downloaded.
717
+ """ if not omni_api.model_loaded else ""
718
+
719
+ iface = gr.Interface(
720
+ fn=gradio_generate,
721
+ inputs=[
722
+ gr.Textbox(
723
+ label="Prompt",
724
+ placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
725
+ lines=2
726
+ ),
727
+ gr.Textbox(
728
+ label="Text to Speech",
729
+ placeholder="Enter text to convert to speech",
730
+ lines=3,
731
+ info="Will use best available TTS system (Advanced or Fallback)"
732
+ ),
733
+ gr.Textbox(
734
+ label="OR Audio URL",
735
+ placeholder="https://example.com/audio.mp3",
736
+ info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
737
+ ),
738
+ gr.Textbox(
739
+ label="Image URL (Optional)",
740
+ placeholder="https://example.com/image.jpg",
741
+ info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
742
+ ),
743
+ gr.Dropdown(
744
+ choices=[
745
+ "21m00Tcm4TlvDq8ikWAM",
746
+ "pNInz6obpgDQGcFmaJgB",
747
+ "EXAVITQu4vr4xnSDxMaL",
748
+ "ErXwobaYiN019PkySvjV",
749
+ "TxGEqnHWrfGW9XjX",
750
+ "yoZ06aMxZJJ28mfd3POQ",
751
+ "AZnzlk1XvdvUeBnXmlld"
752
+ ],
753
+ value="21m00Tcm4TlvDq8ikWAM",
754
+ label="Voice Profile",
755
+ info="Choose voice characteristics for TTS generation"
756
+ ),
757
+ gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
758
+ gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
759
+ gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
760
+ ],
761
+ outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
762
+ title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
763
+ description=f"""
764
+ Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
765
+
766
+ {description_extra}
767
+
768
+ **Robust TTS Architecture**
769
+ - **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
770
+ - **Fallback**: Robust tone generation for 100% reliability
771
+ - **Automatic**: Seamless switching between methods
772
+
773
+ **Features:**
774
+ - **Guaranteed Generation**: Always produces audio output
775
+ - **No Dependencies**: Works even without advanced models
776
+ - **High Availability**: Multiple fallback layers
777
+ - **Voice Profiles**: Multiple voice characteristics
778
+ - **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
779
+ - **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
780
+
781
+ **Usage:**
782
+ 1. Enter a character description in the prompt
783
+ 2. **Enter text for speech generation** (recommended in current mode)
784
+ 3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
785
+ 4. Choose voice profile and adjust parameters
786
+ 5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
787
+ """,
788
+ examples=[
789
+ [
790
+ "A professional teacher explaining a mathematical concept with clear gestures",
791
+ "Hello students! Today we're going to learn about calculus and derivatives.",
792
+ "",
793
+ "",
794
+ "21m00Tcm4TlvDq8ikWAM",
795
+ 5.0,
796
+ 3.5,
797
+ 30
798
+ ],
799
+ [
800
+ "A friendly presenter speaking confidently to an audience",
801
+ "Welcome everyone to our presentation on artificial intelligence!",
802
+ "",
803
+ "",
804
+ "pNInz6obpgDQGcFmaJgB",
805
+ 5.5,
806
+ 4.0,
807
+ 35
808
+ ]
809
+ ],
810
+ allow_flagging="never",
811
+ flagging_dir="/tmp/gradio_flagged"
812
+ )
813
+
814
+ # Mount Gradio app
815
+ app = gr.mount_gradio_app(app, iface, path="/gradio")
816
+
817
+ if __name__ == "__main__":
818
+ import uvicorn
819
+ uvicorn.run(app, host="0.0.0.0", port=7860)
820
+
821
+
822
+
823
+
824
+
825
+
826
+
827
+
828
+
app_temp.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import tempfile
4
+ import gradio as gr
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, HttpUrl
9
+ import subprocess
10
+ import json
11
+ from pathlib import Path
12
+ import logging
13
+ import requests
14
+ from urllib.parse import urlparse
15
+ from PIL import Image
16
+ import io
17
+ from typing import Optional
18
+ import aiohttp
19
+ import asyncio
20
+ from dotenv import load_dotenv
21
+
22
+ # Load environment variables
23
+ load_dotenv()
24
+
25
+ # Set up logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Set environment variables for matplotlib, gradio, and huggingface cache
30
+ os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
31
+ os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
32
+ os.environ['HF_HOME'] = '/tmp/huggingface'
33
+ # Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
34
+ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
35
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
36
+
37
+ # FastAPI app will be created after lifespan is defined
38
+
39
+
40
+
41
+ # Create directories with proper permissions
42
+ os.makedirs("outputs", exist_ok=True)
43
+ os.makedirs("/tmp/matplotlib", exist_ok=True)
44
+ os.makedirs("/tmp/huggingface", exist_ok=True)
45
+ os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
46
+ os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
47
+ os.makedirs("/tmp/huggingface/hub", exist_ok=True)
48
+
49
+ # Mount static files for serving generated videos
50
+
51
+
52
def get_video_url(output_path: str, base_url: str = "https://bravedims-ai-avatar-chat.hf.space") -> str:
    """Convert a local output file path to a URL under ``/outputs``.

    Args:
        output_path: Path of the generated file; only its file name is used.
        base_url: Public origin of the deployment. Defaults to the Space URL
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        ``{base_url}/outputs/{filename}``, or ``output_path`` unchanged when
        the URL cannot be built.
    """
    try:
        filename = Path(output_path).name
        video_url = f"{base_url}/outputs/{filename}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
66
+
67
+ # Pydantic models for request/response
68
class GenerateRequest(BaseModel):
    """Request body for avatar generation."""

    # Text prompt describing the character / behaviour.
    prompt: str
    # Audio source: either text to synthesise or a direct audio URL.
    text_to_speech: Optional[str] = None
    audio_url: Optional[HttpUrl] = None
    # Voice profile ID passed to the TTS clients.
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"
    # Optional reference image.
    image_url: Optional[HttpUrl] = None
    # Generation parameters forwarded to the video engine / inference script.
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    sp_size: int = 1
    tea_cache_l1_thresh: Optional[float] = None
79
+
80
class GenerateResponse(BaseModel):
    """Response body returned by the generation endpoint."""

    message: str
    # Path or URL of the generated output.
    output_path: str
    # Wall-clock duration of the request, in seconds.
    processing_time: float
    # True when the audio was synthesised from text.
    audio_generated: bool = False
    # Human-readable name of the audio/TTS method that was used.
    tts_method: Optional[str] = None
86
+
87
+ # Try to import TTS clients, but make them optional
88
+ try:
89
+ from advanced_tts_client import AdvancedTTSClient
90
+ ADVANCED_TTS_AVAILABLE = True
91
+ logger.info("SUCCESS: Advanced TTS client available")
92
+ except ImportError as e:
93
+ ADVANCED_TTS_AVAILABLE = False
94
+ logger.warning(f"WARNING: Advanced TTS client not available: {e}")
95
+
96
+ # Always import the robust fallback
97
+ try:
98
+ from robust_tts_client import RobustTTSClient
99
+ ROBUST_TTS_AVAILABLE = True
100
+ logger.info("SUCCESS: Robust TTS client available")
101
+ except ImportError as e:
102
+ ROBUST_TTS_AVAILABLE = False
103
+ logger.error(f"ERROR: Robust TTS client not available: {e}")
104
+
105
+ class TTSManager:
106
+ """Manages multiple TTS clients with fallback chain"""
107
+
108
+ def __init__(self):
109
+ # Initialize TTS clients based on availability
110
+ self.advanced_tts = None
111
+ self.robust_tts = None
112
+ self.clients_loaded = False
113
+
114
+ if ADVANCED_TTS_AVAILABLE:
115
+ try:
116
+ self.advanced_tts = AdvancedTTSClient()
117
+ logger.info("SUCCESS: Advanced TTS client initialized")
118
+ except Exception as e:
119
+ logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")
120
+
121
+ if ROBUST_TTS_AVAILABLE:
122
+ try:
123
+ self.robust_tts = RobustTTSClient()
124
+ logger.info("SUCCESS: Robust TTS client initialized")
125
+ except Exception as e:
126
+ logger.error(f"ERROR: Robust TTS client initialization failed: {e}")
127
+
128
+ if not self.advanced_tts and not self.robust_tts:
129
+ logger.error("ERROR: No TTS clients available!")
130
+
131
+ async def load_models(self):
132
+ """Load TTS models"""
133
+ try:
134
+ logger.info("Loading TTS models...")
135
+
136
+ # Try to load advanced TTS first
137
+ if self.advanced_tts:
138
+ try:
139
+ logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
140
+ success = await self.advanced_tts.load_models()
141
+ if success:
142
+ logger.info("SUCCESS: Advanced TTS models loaded successfully")
143
+ else:
144
+ logger.warning("WARNING: Advanced TTS models failed to load")
145
+ except Exception as e:
146
+ logger.warning(f"WARNING: Advanced TTS loading error: {e}")
147
+
148
+ # Always ensure robust TTS is available
149
+ if self.robust_tts:
150
+ try:
151
+ await self.robust_tts.load_model()
152
+ logger.info("SUCCESS: Robust TTS fallback ready")
153
+ except Exception as e:
154
+ logger.error(f"ERROR: Robust TTS loading failed: {e}")
155
+
156
+ self.clients_loaded = True
157
+ return True
158
+
159
+ except Exception as e:
160
+ logger.error(f"ERROR: TTS manager initialization failed: {e}")
161
+ return False
162
+
163
+ async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
164
+ """
165
+ Convert text to speech with fallback chain
166
+ Returns: (audio_file_path, method_used)
167
+ """
168
+ if not self.clients_loaded:
169
+ logger.info("TTS models not loaded, loading now...")
170
+ await self.load_models()
171
+
172
+ logger.info(f"Generating speech: {text[:50]}...")
173
+ logger.info(f"Voice ID: {voice_id}")
174
+
175
+ # Try Advanced TTS first (Facebook VITS / SpeechT5)
176
+ if self.advanced_tts:
177
+ try:
178
+ audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
179
+ return audio_path, "Facebook VITS/SpeechT5"
180
+ except Exception as advanced_error:
181
+ logger.warning(f"Advanced TTS failed: {advanced_error}")
182
+
183
+ # Fall back to robust TTS
184
+ if self.robust_tts:
185
+ try:
186
+ logger.info("Falling back to robust TTS...")
187
+ audio_path = await self.robust_tts.text_to_speech(text, voice_id)
188
+ return audio_path, "Robust TTS (Fallback)"
189
+ except Exception as robust_error:
190
+ logger.error(f"Robust TTS also failed: {robust_error}")
191
+
192
+ # If we get here, all methods failed
193
+ logger.error("All TTS methods failed!")
194
+ raise HTTPException(
195
+ status_code=500,
196
+ detail="All TTS methods failed. Please check system configuration."
197
+ )
198
+
199
+ async def get_available_voices(self):
200
+ """Get available voice configurations"""
201
+ try:
202
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
203
+ return await self.advanced_tts.get_available_voices()
204
+ except:
205
+ pass
206
+
207
+ # Return default voices if advanced TTS not available
208
+ return {
209
+ "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
210
+ "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
211
+ "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
212
+ "ErXwobaYiN019PkySvjV": "Male (Professional)",
213
+ "TxGEqnHWrfGW9XjX": "Male (Deep)",
214
+ "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
215
+ "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
216
+ }
217
+
218
+ def get_tts_info(self):
219
+ """Get TTS system information"""
220
+ info = {
221
+ "clients_loaded": self.clients_loaded,
222
+ "advanced_tts_available": self.advanced_tts is not None,
223
+ "robust_tts_available": self.robust_tts is not None,
224
+ "primary_method": "Robust TTS"
225
+ }
226
+
227
+ try:
228
+ if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
229
+ advanced_info = self.advanced_tts.get_model_info()
230
+ info.update({
231
+ "advanced_tts_loaded": advanced_info.get("models_loaded", False),
232
+ "transformers_available": advanced_info.get("transformers_available", False),
233
+ "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
234
+ "device": advanced_info.get("device", "cpu"),
235
+ "vits_available": advanced_info.get("vits_available", False),
236
+ "speecht5_available": advanced_info.get("speecht5_available", False)
237
+ })
238
+ except Exception as e:
239
+ logger.debug(f"Could not get advanced TTS info: {e}")
240
+
241
+ return info
242
+
243
+ # Import the VIDEO-FOCUSED engine
244
+ try:
245
+ from omniavatar_video_engine import video_engine
246
+ VIDEO_ENGINE_AVAILABLE = True
247
+ logger.info("SUCCESS: OmniAvatar Video Engine available")
248
+ except ImportError as e:
249
+ VIDEO_ENGINE_AVAILABLE = False
250
+ logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
251
+
252
+ class OmniAvatarAPI:
253
+ def __init__(self):
254
+ self.model_loaded = False
255
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
256
+ self.tts_manager = TTSManager()
257
+ logger.info(f"Using device: {self.device}")
258
+ logger.info("Initialized with robust TTS system")
259
+
260
+ def load_model(self):
261
+ """Load the OmniAvatar model - now more flexible"""
262
+ try:
263
+ # Check if models are downloaded (but don't require them)
264
+ model_paths = [
265
+ "./pretrained_models/Wan2.1-T2V-14B",
266
+ "./pretrained_models/OmniAvatar-14B",
267
+ "./pretrained_models/wav2vec2-base-960h"
268
+ ]
269
+
270
+ missing_models = []
271
+ for path in model_paths:
272
+ if not os.path.exists(path):
273
+ missing_models.append(path)
274
+
275
+ if missing_models:
276
+ logger.warning("WARNING: Some OmniAvatar models not found:")
277
+ for model in missing_models:
278
+ logger.warning(f" - {model}")
279
+ logger.info("TIP: App will run in TTS-only mode (no video generation)")
280
+ logger.info("TIP: To enable full avatar generation, download the required models")
281
+
282
+ # Set as loaded but in limited mode
283
+ self.model_loaded = False # Video generation disabled
284
+ return True # But app can still run
285
+ else:
286
+ self.model_loaded = True
287
+ logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
288
+ return True
289
+
290
+ except Exception as e:
291
+ logger.error(f"Error checking models: {str(e)}")
292
+ logger.info("TIP: Continuing in TTS-only mode")
293
+ self.model_loaded = False
294
+ return True # Continue running
295
+
296
+ async def download_file(self, url: str, suffix: str = "") -> str:
297
+ """Download file from URL and save to temporary location"""
298
+ try:
299
+ async with aiohttp.ClientSession() as session:
300
+ async with session.get(str(url)) as response:
301
+ if response.status != 200:
302
+ raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
303
+
304
+ content = await response.read()
305
+
306
+ # Create temporary file
307
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
308
+ temp_file.write(content)
309
+ temp_file.close()
310
+
311
+ return temp_file.name
312
+
313
+ except aiohttp.ClientError as e:
314
+ logger.error(f"Network error downloading {url}: {e}")
315
+ raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
316
+ except Exception as e:
317
+ logger.error(f"Error downloading file from {url}: {e}")
318
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
319
+
320
+ def validate_audio_url(self, url: str) -> bool:
321
+ """Validate if URL is likely an audio file"""
322
+ try:
323
+ parsed = urlparse(url)
324
+ # Check for common audio file extensions
325
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
326
+ is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
327
+
328
+ return is_audio_ext or 'audio' in url.lower()
329
+ except:
330
+ return False
331
+
332
+ def validate_image_url(self, url: str) -> bool:
333
+ """Validate if URL is likely an image file"""
334
+ try:
335
+ parsed = urlparse(url)
336
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
337
+ return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
338
+ except:
339
+ return False
340
+
341
+ async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
342
+ """Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
343
+ import time
344
+ start_time = time.time()
345
+ audio_generated = False
346
+ method_used = "Unknown"
347
+
348
+ logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
349
+ logger.info(f"[INFO] Prompt: {request.prompt}")
350
+
351
+ if VIDEO_ENGINE_AVAILABLE:
352
+ try:
353
+ # PRIORITIZE VIDEO GENERATION
354
+ logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")
355
+
356
+ # Handle audio source
357
+ audio_path = None
358
+ if request.text_to_speech:
359
+ logger.info("[MIC] Generating audio from text...")
360
+ audio_path, method_used = await self.tts_manager.text_to_speech(
361
+ request.text_to_speech,
362
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
363
+ )
364
+ audio_generated = True
365
+ elif request.audio_url:
366
+ logger.info("?? Downloading audio from URL...")
367
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
368
+ method_used = "External Audio"
369
+ else:
370
+ raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")
371
+
372
+ # Handle image if provided
373
+ image_path = None
374
+ if request.image_url:
375
+ logger.info("[IMAGE] Downloading reference image...")
376
+ parsed = urlparse(str(request.image_url))
377
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
378
+ image_path = await self.download_file(str(request.image_url), ext)
379
+
380
+ # GENERATE VIDEO using OmniAvatar engine
381
+ logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
382
+ video_path, generation_time = video_engine.generate_avatar_video(
383
+ prompt=request.prompt,
384
+ audio_path=audio_path,
385
+ image_path=image_path,
386
+ guidance_scale=request.guidance_scale,
387
+ audio_scale=request.audio_scale,
388
+ num_steps=request.num_steps
389
+ )
390
+
391
+ processing_time = time.time() - start_time
392
+ logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
393
+
394
+ # Cleanup temporary files
395
+ if audio_path and os.path.exists(audio_path):
396
+ os.unlink(audio_path)
397
+ if image_path and os.path.exists(image_path):
398
+ os.unlink(image_path)
399
+
400
+ return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
401
+
402
+ except Exception as e:
403
+ logger.error(f"ERROR: Video generation failed: {e}")
404
+ # For a VIDEO generation app, we should NOT fall back to audio-only
405
+ # Instead, provide clear guidance
406
+ if "models" in str(e).lower():
407
+ raise HTTPException(
408
+ status_code=503,
409
+ detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
410
+ )
411
+ else:
412
+ raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
413
+
414
+ # If video engine not available, this is a critical error for a VIDEO app
415
+ raise HTTPException(
416
+ status_code=503,
417
+ detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
418
+ )
419
+
420
+ async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
421
+ """OLD TTS-ONLY METHOD - kept as backup reference
422
+ """Generate avatar video from prompt and audio/text - now handles missing models"""
423
+ import time
424
+ start_time = time.time()
425
+ audio_generated = False
426
+ tts_method = None
427
+
428
+ try:
429
+ # Check if video generation is available
430
+ if not self.model_loaded:
431
+ logger.info("??? Running in TTS-only mode (OmniAvatar models not available)")
432
+
433
+ # Only generate audio, no video
434
+ if request.text_to_speech:
435
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
436
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
437
+ request.text_to_speech,
438
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
439
+ )
440
+
441
+ # Return the audio file as the "output"
442
+ processing_time = time.time() - start_time
443
+ logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
444
+ return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
445
+ else:
446
+ raise HTTPException(
447
+ status_code=503,
448
+ detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
449
+ )
450
+
451
+ # Original video generation logic (when models are available)
452
+ # Determine audio source
453
+ audio_path = None
454
+
455
+ if request.text_to_speech:
456
+ # Generate speech from text using TTS manager
457
+ logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
458
+ audio_path, tts_method = await self.tts_manager.text_to_speech(
459
+ request.text_to_speech,
460
+ request.voice_id or "21m00Tcm4TlvDq8ikWAM"
461
+ )
462
+ audio_generated = True
463
+
464
+ elif request.audio_url:
465
+ # Download audio from provided URL
466
+ logger.info(f"Downloading audio from URL: {request.audio_url}")
467
+ if not self.validate_audio_url(str(request.audio_url)):
468
+ logger.warning(f"Audio URL may not be valid: {request.audio_url}")
469
+
470
+ audio_path = await self.download_file(str(request.audio_url), ".mp3")
471
+ tts_method = "External Audio URL"
472
+
473
+ else:
474
+ raise HTTPException(
475
+ status_code=400,
476
+ detail="Either text_to_speech or audio_url must be provided"
477
+ )
478
+
479
+ # Download image if provided
480
+ image_path = None
481
+ if request.image_url:
482
+ logger.info(f"Downloading image from URL: {request.image_url}")
483
+ if not self.validate_image_url(str(request.image_url)):
484
+ logger.warning(f"Image URL may not be valid: {request.image_url}")
485
+
486
+ # Determine image extension from URL or default to .jpg
487
+ parsed = urlparse(str(request.image_url))
488
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
489
+ image_path = await self.download_file(str(request.image_url), ext)
490
+
491
+ # Create temporary input file for inference
492
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
493
+ if image_path:
494
+ input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
495
+ else:
496
+ input_line = f"{request.prompt}@@@@{audio_path}"
497
+ f.write(input_line)
498
+ temp_input_file = f.name
499
+
500
+ # Prepare inference command
501
+ cmd = [
502
+ "python", "-m", "torch.distributed.run",
503
+ "--standalone", f"--nproc_per_node={request.sp_size}",
504
+ "scripts/inference.py",
505
+ "--config", "configs/inference.yaml",
506
+ "--input_file", temp_input_file,
507
+ "--guidance_scale", str(request.guidance_scale),
508
+ "--audio_scale", str(request.audio_scale),
509
+ "--num_steps", str(request.num_steps)
510
+ ]
511
+
512
+ if request.tea_cache_l1_thresh:
513
+ cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
514
+
515
+ logger.info(f"Running inference with command: {' '.join(cmd)}")
516
+
517
+ # Run inference
518
+ result = subprocess.run(cmd, capture_output=True, text=True)
519
+
520
+ # Clean up temporary files
521
+ os.unlink(temp_input_file)
522
+ os.unlink(audio_path)
523
+ if image_path:
524
+ os.unlink(image_path)
525
+
526
+ if result.returncode != 0:
527
+ logger.error(f"Inference failed: {result.stderr}")
528
+ raise Exception(f"Inference failed: {result.stderr}")
529
+
530
+ # Find output video file
531
+ output_dir = "./outputs"
532
+ if os.path.exists(output_dir):
533
+ video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
534
+ if video_files:
535
+ # Return the most recent video file
536
+ video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
537
+ output_path = os.path.join(output_dir, video_files[0])
538
+ processing_time = time.time() - start_time
539
+ return output_path, processing_time, audio_generated, tts_method
540
+
541
+ raise Exception("No output video generated")
542
+
543
+ except Exception as e:
544
+ # Clean up any temporary files in case of error
545
+ try:
546
+ if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
547
+ os.unlink(audio_path)
548
+ if 'image_path' in locals() and image_path and os.path.exists(image_path):
549
+ os.unlink(image_path)
550
+ if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
551
+ os.unlink(temp_input_file)
552
+ except:
553
+ pass
554
+
555
+ logger.error(f"Generation error: {str(e)}")
556
+ raise HTTPException(status_code=500, detail=str(e))
557
+
558
# Shared API object used by the endpoints and the Gradio callback below
omni_api = OmniAvatarAPI()

# The lifespan context manager replaces FastAPI's deprecated on_event hooks
from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup loads the OmniAvatar model and then the TTS models. A failure of
    either step is only logged, so the application still starts.
    """
    # --- startup ---
    if not omni_api.load_model():
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    try:
        await omni_api.tts_manager.load_models()
        logger.info("SUCCESS: TTS models initialization completed")
    except Exception as tts_error:
        logger.error(f"ERROR: TTS initialization failed: {tts_error}")

    # Control returns to FastAPI here for the lifetime of the application
    yield

    # --- shutdown ---
    logger.info("Application shutting down...")
582
+
583
# Create the FastAPI app and hand it the lifespan context manager
app = FastAPI(
    title="OmniAvatar-14B API with Advanced TTS",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve generated videos as static files. StaticFiles checks the directory
# when it is constructed and raises if it is missing, which would abort the
# import on a fresh checkout, so make sure the directory exists first.
os.makedirs("outputs", exist_ok=True)
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
601
+
602
@app.get("/health")
async def health_check():
    """Report service status, loaded-model flags and TTS details."""
    tts_details = omni_api.tts_manager.get_tts_info()

    report = {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "video_generation_available": omni_api.model_loaded,
        "tts_only_mode": not omni_api.model_loaded,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": omni_api.model_loaded,
        "supports_audio_urls": omni_api.model_loaded,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
    }
    # Entries from the TTS manager are merged last, so they win on key clashes
    report.update(tts_details)
    return report
621
+
622
@app.get("/voices")
async def get_voices():
    """Return the voices known to the TTS manager.

    On failure the error is logged and returned as ``{"error": <message>}``
    instead of being raised.
    """
    try:
        available = await omni_api.tts_manager.get_available_voices()
        return {"voices": available}
    except Exception as voice_error:
        logger.error(f"Error getting voices: {voice_error}")
        return {"error": str(voice_error)}
631
+
632
@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate an avatar video from a prompt, text or audio, and an optional image URL.

    Raises:
        HTTPException: re-raised unchanged when the generator raises one;
            any other exception is wrapped in a 500 response.
    """
    # Collect the request summary first, then log it line by line
    summary = [f"Generating avatar with prompt: {request.prompt}"]
    if request.text_to_speech:
        summary.append(f"Text to speech: {request.text_to_speech[:100]}...")
        summary.append(f"Voice ID: {request.voice_id}")
    if request.audio_url:
        summary.append(f"Audio URL: {request.audio_url}")
    if request.image_url:
        summary.append(f"Image URL: {request.image_url}")
    for entry in summary:
        logger.info(entry)

    try:
        result = await omni_api.generate_avatar(request)
        output_path, processing_time, audio_generated, tts_method = result

        message = "Generation completed successfully"
        if not omni_api.model_loaded:
            message += " (TTS-only mode)"

        # Only a generated video is turned into a URL; in TTS-only mode the
        # path from the generator is returned as-is.
        if omni_api.model_loaded:
            location = get_video_url(output_path)
        else:
            location = output_path

        return GenerateResponse(
            message=message,
            output_path=location,
            processing_time=processing_time,
            audio_generated=audio_generated,
            tts_method=tts_method,
        )
    except HTTPException:
        raise
    except Exception as unexpected:
        logger.error(f"Unexpected error: {unexpected}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {unexpected}")
661
+
662
+ # Enhanced Gradio interface
663
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Synchronous Gradio callback wrapping ``omni_api.generate_avatar``.

    Builds a ``GenerateRequest`` from the form fields, runs the async
    generation to completion and returns the value for the output component.

    Args:
        prompt: Description of the character behavior.
        text_to_speech: Text to synthesize; takes precedence over ``audio_url``.
        audio_url: URL of an existing audio file (needs the full models).
        image_url: Optional reference image URL (needs the full models).
        voice_id: Voice profile identifier; a default is used when empty.
        guidance_scale: Forwarded to the request unchanged.
        audio_scale: Forwarded to the request unchanged.
        num_steps: Number of steps; converted with ``int()``.

    Returns:
        The output path when ``omni_api.model_loaded`` is true, otherwise a
        status message naming the generated audio file. Validation problems
        and any exception are reported as a string starting with ``"Error: "``;
        nothing is raised to Gradio.
    """
    try:
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps),
        }

        # Exactly one audio source is used; text-to-speech wins if both are given
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            if omni_api.model_loaded:
                request_data["audio_url"] = audio_url
            else:
                return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
        else:
            return "Error: Please provide either text to speech or audio URL"

        if image_url and image_url.strip():
            if omni_api.model_loaded:
                request_data["image_url"] = image_url
            else:
                return "Error: Image URL input requires full OmniAvatar models for video generation."

        request = GenerateRequest(**request_data)

        # asyncio.run creates a fresh event loop and always closes it, even
        # when generation raises. The previous manual new_event_loop /
        # run_until_complete / close sequence leaked the loop on failure and
        # left a closed loop installed as the thread's current loop.
        output_path, processing_time, audio_generated, tts_method = asyncio.run(
            omni_api.generate_avatar(request)
        )

        success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
        print(success_message)

        if omni_api.model_loaded:
            return output_path
        return f"??? TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"

    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"
711
+
712
# Build the Gradio interface; labels and help texts depend on whether the
# OmniAvatar model is loaded or only TTS is available.
_tts_only = not omni_api.model_loaded

if _tts_only:
    mode_info = " (TTS-Only Mode)"
    description_extra = """
WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
To enable full video generation, the required model files need to be downloaded.
"""
else:
    mode_info = ""
    description_extra = ""

# Voice profile identifiers offered in the dropdown
_voice_ids = [
    "21m00Tcm4TlvDq8ikWAM",
    "pNInz6obpgDQGcFmaJgB",
    "EXAVITQu4vr4xnSDxMaL",
    "ErXwobaYiN019PkySvjV",
    "TxGEqnHWrfGW9XjX",
    "yoZ06aMxZJJ28mfd3POQ",
    "AZnzlk1XvdvUeBnXmlld",
]

iface = gr.Interface(
    fn=gradio_generate,
    # Component order must match the positional parameters of gradio_generate
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
            lines=2,
        ),
        gr.Textbox(
            label="Text to Speech",
            placeholder="Enter text to convert to speech",
            lines=3,
            info="Will use best available TTS system (Advanced or Fallback)",
        ),
        gr.Textbox(
            label="OR Audio URL",
            placeholder="https://example.com/audio.mp3",
            info="Direct URL to audio file (requires full models)" if _tts_only else "Direct URL to audio file",
        ),
        gr.Textbox(
            label="Image URL (Optional)",
            placeholder="https://example.com/image.jpg",
            info="Direct URL to reference image (requires full models)" if _tts_only else "Direct URL to reference image",
        ),
        gr.Dropdown(
            choices=_voice_ids,
            value="21m00Tcm4TlvDq8ikWAM",
            label="Voice Profile",
            info="Choose voice characteristics for TTS generation",
        ),
        gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
        gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended"),
    ],
    # A text box replaces the video player when no video can be produced
    outputs=gr.Textbox(label="TTS Output") if _tts_only else gr.Video(label="Generated Avatar Video"),
    title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
    description=f"""
Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.

{description_extra}

**Robust TTS Architecture**
- **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
- **Fallback**: Robust tone generation for 100% reliability
- **Automatic**: Seamless switching between methods

**Features:**
- **Guaranteed Generation**: Always produces audio output
- **No Dependencies**: Works even without advanced models
- **High Availability**: Multiple fallback layers
- **Voice Profiles**: Multiple voice characteristics
- **Audio URL Support**: Use external audio files {"(full models required)" if _tts_only else ""}
- **Image URL Support**: Reference images for characters {"(full models required)" if _tts_only else ""}

**Usage:**
1. Enter a character description in the prompt
2. **Enter text for speech generation** (recommended in current mode)
3. {"Optionally add reference image/audio URLs (requires full models)" if _tts_only else "Optionally add reference image URL and choose audio source"}
4. Choose voice profile and adjust parameters
5. Generate your {"audio" if _tts_only else "avatar video"}!
""",
    # Each example row lists values in the same order as `inputs`
    examples=[
        [
            "A professional teacher explaining a mathematical concept with clear gestures",
            "Hello students! Today we're going to learn about calculus and derivatives.",
            "",
            "",
            "21m00Tcm4TlvDq8ikWAM",
            5.0,
            3.5,
            30,
        ],
        [
            "A friendly presenter speaking confidently to an audience",
            "Welcome everyone to our presentation on artificial intelligence!",
            "",
            "",
            "pNInz6obpgDQGcFmaJgB",
            5.5,
            4.0,
            35,
        ],
    ],
    allow_flagging="never",
    flagging_dir="/tmp/gradio_flagged",
)
813
+
814
# Attach the Gradio UI to the FastAPI application under /gradio
app = gr.mount_gradio_app(app=app, blocks=iface, path="/gradio")

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
820
+
821
+
822
+
823
+
824
+
825
+
826
+
827
+