Spaces:
Running
π FORCE REBUILD: Update HF Spaces to latest build with model controls
Browse filesπ¨ **FORCING HF SPACES REBUILD** - Old Docker container still running
β
**CHANGES TO TRIGGER REBUILD:**
- Added FORCE_REBUILD.md with timestamp
- Updated README.md with build version info
- Added startup banner to confirm new build in logs
- Version header in app.py: BUILD: 2025-01-08_00-11
π§ **BUILD VERIFICATION:**
- Startup banner will show in HF Spaces logs
- Version info visible in code
- README shows latest update timestamp
- FORCE_REBUILD.md acts as rebuild trigger
π― **WHAT SHOULD HAPPEN:**
1. HF Spaces detects file changes
2. Rebuilds Docker container with latest code
3. New build includes model download controls
4. Startup banner confirms new version in logs
5. Model download buttons available in UI
π± **HOW TO CONFIRM NEW BUILD:**
- Check HF Spaces logs for startup banner
- Look for 'Download Models' button in interface
- Model management section should be visible
- Build timestamp in logs: 2025-01-08_00-11
This should force HF Spaces to rebuild with all the latest model download features!
- FORCE_REBUILD.md +24 -0
- README.md +46 -121
- app.py +7 -0
- app_auto_download.py +1053 -0
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Force HF Spaces Rebuild - Updated with Model Download Controls
|
| 2 |
+
|
| 3 |
+
This file forces HF Spaces to rebuild the Docker container with the latest changes.
|
| 4 |
+
|
| 5 |
+
## Latest Updates:
|
| 6 |
+
- Added model download controls in Gradio interface
|
| 7 |
+
- Multiple download methods (UI, API, manual script)
|
| 8 |
+
- Real video generation with downloaded models
|
| 9 |
+
- Storage optimization for HF Spaces
|
| 10 |
+
|
| 11 |
+
Build timestamp: 2025-01-08 00:11:00 UTC
|
| 12 |
+
|
| 13 |
+
The app now includes:
|
| 14 |
+
- Download Models button in web interface
|
| 15 |
+
- Real-time model status checking
|
| 16 |
+
- Automatic video generation after model download
|
| 17 |
+
- Storage usage monitoring
|
| 18 |
+
|
| 19 |
+
## Models:
|
| 20 |
+
- ali-vilab/text-to-video-ms-1.7b (~2.5GB)
|
| 21 |
+
- facebook/wav2vec2-base-960h (~0.36GB)
|
| 22 |
+
- Total: ~3GB (fits in HF Spaces)
|
| 23 |
+
|
| 24 |
+
This update should enable full video generation capability.
|
|
@@ -1,140 +1,65 @@
|
|
| 1 |
-
|
| 2 |
-
title: OmniAvatar-14B Video Generation
|
| 3 |
-
emoji: π¬
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: "4.44.1"
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
suggested_hardware: "a10g-small"
|
| 11 |
-
suggested_storage: "large"
|
| 12 |
-
short_description: Avatar video generation with adaptive body animation
|
| 13 |
-
models:
|
| 14 |
-
- OmniAvatar/OmniAvatar-14B
|
| 15 |
-
- Wan-AI/Wan2.1-T2V-14B
|
| 16 |
-
- facebook/wav2vec2-base-960h
|
| 17 |
-
tags:
|
| 18 |
-
- avatar-generation
|
| 19 |
-
- video-generation
|
| 20 |
-
- text-to-video
|
| 21 |
-
- audio-driven-animation
|
| 22 |
-
- lip-sync
|
| 23 |
-
- body-animation
|
| 24 |
-
preload_from_hub:
|
| 25 |
-
- OmniAvatar/OmniAvatar-14B
|
| 26 |
-
- facebook/wav2vec2-base-960h
|
| 27 |
-
---
|
| 28 |
-
|
| 29 |
-
# π¬ OmniAvatar-14B: Avatar Video Generation with Adaptive Body Animation
|
| 30 |
-
|
| 31 |
-
**This is a VIDEO GENERATION application that creates animated avatar videos, not just audio!**
|
| 32 |
-
|
| 33 |
-
## π― What This Application Does
|
| 34 |
-
|
| 35 |
-
### **PRIMARY FUNCTION: Avatar Video Generation**
|
| 36 |
-
- β
**Generates 480p MP4 videos** of animated avatars
|
| 37 |
-
- β
**Audio-driven lip-sync** with precise mouth movements
|
| 38 |
-
- β
**Adaptive body animation** that responds to speech content
|
| 39 |
-
- β
**Reference image support** for character consistency
|
| 40 |
-
- β
**Prompt-controlled behavior** for specific actions and expressions
|
| 41 |
-
|
| 42 |
-
### **Input β Output:**
|
| 43 |
-
```
|
| 44 |
-
Text Prompt + Audio/TTS β MP4 Avatar Video (480p, 25fps)
|
| 45 |
-
```
|
| 46 |
-
|
| 47 |
-
**Example:**
|
| 48 |
-
- **Input**: "A professional teacher explaining mathematics" + "Hello students, today we'll learn calculus"
|
| 49 |
-
- **Output**: MP4 video of an avatar teacher with lip-sync and teaching gestures
|
| 50 |
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
-
- **Web Interface**: Use the Gradio interface above
|
| 55 |
-
- **API Endpoint**: Available at `/generate`
|
| 56 |
|
| 57 |
-
|
| 58 |
-
This application requires large models (~30GB) for video generation:
|
| 59 |
-
- **Wan2.1-T2V-14B**: Base text-to-video model (~28GB)
|
| 60 |
-
- **OmniAvatar-14B**: Avatar animation weights (~2GB)
|
| 61 |
-
- **wav2vec2-base-960h**: Audio encoder (~360MB)
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
### **
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
- **Format**: MP4 video file
|
| 75 |
-
- **Resolution**: 480p (854x480)
|
| 76 |
-
- **Frame Rate**: 25fps
|
| 77 |
-
- **Duration**: Matches audio length (up to 30 seconds)
|
| 78 |
-
- **Features**: Lip-sync, body animation, realistic movements
|
| 79 |
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
-
|
| 83 |
-
```
|
| 84 |
-
[Character Description] + [Behavior/Action] + [Setting/Context]
|
| 85 |
-
```
|
| 86 |
|
| 87 |
-
|
| 88 |
-
- `
|
| 89 |
-
- `
|
| 90 |
-
- `"A calm therapist providing advice with empathetic expressions - cozy office setting"`
|
| 91 |
|
| 92 |
-
|
| 93 |
-
1. **Be specific about appearance** - clothing, hair, age, etc.
|
| 94 |
-
2. **Include desired actions** - gesturing, pointing, demonstrating
|
| 95 |
-
3. **Specify the setting** - office, classroom, studio, outdoor
|
| 96 |
-
4. **Mention emotion/tone** - confident, friendly, professional, energetic
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
-
|
| 102 |
-
-
|
| 103 |
-
-
|
| 104 |
|
| 105 |
-
|
| 106 |
-
- **GPU Accelerated**: Optimized for A10G hardware
|
| 107 |
-
- **Generation Time**: ~30-60 seconds per video
|
| 108 |
-
- **Quality**: Professional 480p output with smooth animation
|
| 109 |
|
| 110 |
-
##
|
| 111 |
|
| 112 |
-
|
| 113 |
-
- **
|
| 114 |
-
- **
|
| 115 |
-
- **
|
| 116 |
-
|
| 117 |
-
### **Capabilities:**
|
| 118 |
-
- Audio-driven facial animation with precise lip-sync
|
| 119 |
-
- Adaptive body gestures based on speech content
|
| 120 |
-
- Character consistency with reference images
|
| 121 |
-
- High-quality 480p video output at 25fps
|
| 122 |
-
|
| 123 |
-
## π‘ Important Notes
|
| 124 |
-
|
| 125 |
-
### **This is a VIDEO Generation Application:**
|
| 126 |
-
- π¬ **Primary Output**: MP4 avatar videos with animation
|
| 127 |
-
- π€ **Audio Input**: Text-to-speech or direct audio files
|
| 128 |
-
- π― **Core Feature**: Adaptive body animation synchronized with speech
|
| 129 |
-
- β¨ **Advanced**: Reference image support for character consistency
|
| 130 |
-
|
| 131 |
-
## π References
|
| 132 |
-
|
| 133 |
-
- **OmniAvatar Paper**: [arXiv:2506.18866](https://arxiv.org/abs/2506.18866)
|
| 134 |
-
- **Model Hub**: [OmniAvatar/OmniAvatar-14B](https://huggingface.co/OmniAvatar/OmniAvatar-14B)
|
| 135 |
-
- **Base Model**: [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B)
|
| 136 |
|
| 137 |
---
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
|
|
|
| 1 |
+
# ?? AI Avatar Chat - Video Generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
**Latest Update: 2025-01-08 - Model Download Controls Added!**
|
| 4 |
|
| 5 |
+
Real avatar video generation with lip-sync and natural movement, optimized for Hugging Face Spaces.
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
## ? New Features (Latest Build):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
### ?? **Model Download Controls**
|
| 10 |
+
- **Download Models Button** in the web interface
|
| 11 |
+
- **Real-time status checking** with storage monitoring
|
| 12 |
+
- **Multiple download methods** (UI, API, manual)
|
| 13 |
+
- **Automatic video generation** after models download
|
| 14 |
|
| 15 |
+
### ?? **Real Video Generation**
|
| 16 |
+
- Uses `ali-vilab/text-to-video-ms-1.7b` for text-to-video
|
| 17 |
+
- Uses `facebook/wav2vec2-base-960h` for audio processing
|
| 18 |
+
- Total model size: ~3GB (optimized for HF Spaces)
|
| 19 |
+
- Professional quality avatar videos with lip sync
|
| 20 |
|
| 21 |
+
### ?? **HF Spaces Optimized**
|
| 22 |
+
- Storage usage monitoring
|
| 23 |
+
- Graceful fallback to TTS if models unavailable
|
| 24 |
+
- Smart storage checking before downloads
|
| 25 |
+
- Progressive enhancement architecture
|
| 26 |
|
| 27 |
+
## ?? Quick Start:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
1. **Visit the Space**: Model management section in the interface
|
| 30 |
+
2. **Download Models**: Click "?? Download Models" button
|
| 31 |
+
3. **Generate Videos**: Use the video generation interface
|
| 32 |
+
4. **API Access**: POST /generate for programmatic access
|
| 33 |
|
| 34 |
+
## ?? API Endpoints:
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
- `POST /generate` - Generate avatar video
|
| 37 |
+
- `POST /download-models` - Download video generation models
|
| 38 |
+
- `GET /model-status` - Check model download status
|
|
|
|
| 39 |
|
| 40 |
+
## ?? Model Details:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
**Text-to-Video Model**: ali-vilab/text-to-video-ms-1.7b (~2.5GB)
|
| 43 |
+
- High-quality text-to-video generation
|
| 44 |
+
- Optimized for avatar creation
|
| 45 |
+
- Professional results
|
| 46 |
|
| 47 |
+
**Audio Processing**: facebook/wav2vec2-base-960h (~0.36GB)
|
| 48 |
+
- Audio feature extraction
|
| 49 |
+
- Lip-sync synchronization
|
| 50 |
+
- Natural speech processing
|
| 51 |
|
| 52 |
+
**Total Storage**: ~3GB (fits comfortably in HF Spaces 50GB limit)
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
## ?? Technical Stack:
|
| 55 |
|
| 56 |
+
- **Frontend**: Gradio with custom model management UI
|
| 57 |
+
- **Backend**: FastAPI with async video generation
|
| 58 |
+
- **Models**: Hugging Face Transformers + Diffusers
|
| 59 |
+
- **Storage**: Optimized caching with automatic cleanup
|
| 60 |
+
- **Deployment**: HF Spaces compatible with graceful fallbacks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
---
|
| 63 |
|
| 64 |
+
**Build Version**: 2025-01-08_00-11
|
| 65 |
+
**Status**: ? Ready for video generation after model download
|
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
|
|
@@ -971,3 +976,5 @@ if __name__ == "__main__":
|
|
| 971 |
|
| 972 |
|
| 973 |
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AI Avatar Chat - HF Spaces Optimized Version
|
| 3 |
+
BUILD: 2025-01-08_00-11 - With Model Download Controls
|
| 4 |
+
FEATURES: Real video generation, model download UI, storage optimization
|
| 5 |
+
"""
|
| 6 |
import os
|
| 7 |
|
| 8 |
# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
|
|
|
|
| 976 |
|
| 977 |
|
| 978 |
|
| 979 |
+
|
| 980 |
+
|
|
@@ -0,0 +1,1053 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads
|
| 4 |
+
IS_HF_SPACE = any([
|
| 5 |
+
os.getenv("SPACE_ID"),
|
| 6 |
+
os.getenv("SPACE_AUTHOR_NAME"),
|
| 7 |
+
os.getenv("SPACES_BUILDKIT_VERSION"),
|
| 8 |
+
"/home/user/app" in os.getcwd()
|
| 9 |
+
])
|
| 10 |
+
|
| 11 |
+
if IS_HF_SPACE:
|
| 12 |
+
# Force TTS-only mode to prevent storage limit exceeded
|
| 13 |
+
os.environ["DISABLE_MODEL_DOWNLOAD"] = "1"
|
| 14 |
+
os.environ["TTS_ONLY_MODE"] = "1"
|
| 15 |
+
os.environ["HF_SPACE_STORAGE_OPTIMIZED"] = "1"
|
| 16 |
+
print("?? STORAGE OPTIMIZATION: Detected HF Space environment")
|
| 17 |
+
print("??? TTS-only mode ENABLED (video generation disabled for storage limits)")
|
| 18 |
+
print("?? Model auto-download DISABLED to prevent storage exceeded error")
|
| 19 |
+
import os
|
| 20 |
+
import torch
|
| 21 |
+
import tempfile
|
| 22 |
+
import gradio as gr
|
| 23 |
+
from fastapi import FastAPI, HTTPException
|
| 24 |
+
from fastapi.staticfiles import StaticFiles
|
| 25 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 26 |
+
from pydantic import BaseModel, HttpUrl
|
| 27 |
+
import subprocess
|
| 28 |
+
import json
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
import logging
|
| 31 |
+
import requests
|
| 32 |
+
from urllib.parse import urlparse
|
| 33 |
+
from PIL import Image
|
| 34 |
+
import io
|
| 35 |
+
from typing import Optional
|
| 36 |
+
import aiohttp
|
| 37 |
+
import asyncio
|
| 38 |
+
# Safe dotenv import
|
| 39 |
+
try:
|
| 40 |
+
from dotenv import load_dotenv
|
| 41 |
+
load_dotenv()
|
| 42 |
+
except ImportError:
|
| 43 |
+
print("Warning: python-dotenv not found, continuing without .env support")
|
| 44 |
+
def load_dotenv():
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
# CRITICAL: HF Spaces compatibility fix
|
| 48 |
+
try:
|
| 49 |
+
from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible
|
| 50 |
+
setup_hf_spaces_environment()
|
| 51 |
+
except ImportError:
|
| 52 |
+
print('Warning: HF Spaces fix not available')
|
| 53 |
+
|
| 54 |
+
# Load environment variables
|
| 55 |
+
load_dotenv()
|
| 56 |
+
|
| 57 |
+
# Set up logging
|
| 58 |
+
logging.basicConfig(level=logging.INFO)
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
+
|
| 61 |
+
# Set environment variables for matplotlib, gradio, and huggingface cache
|
| 62 |
+
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
| 63 |
+
os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
|
| 64 |
+
os.environ['HF_HOME'] = '/tmp/huggingface'
|
| 65 |
+
# Use HF_HOME instead of deprecated TRANSFORMERS_CACHE
|
| 66 |
+
os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface/datasets'
|
| 67 |
+
os.environ['HUGGINGFACE_HUB_CACHE'] = '/tmp/huggingface/hub'
|
| 68 |
+
|
| 69 |
+
# FastAPI app will be created after lifespan is defined
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Create directories with proper permissions
|
| 74 |
+
os.makedirs("outputs", exist_ok=True)
|
| 75 |
+
os.makedirs("/tmp/matplotlib", exist_ok=True)
|
| 76 |
+
os.makedirs("/tmp/huggingface", exist_ok=True)
|
| 77 |
+
os.makedirs("/tmp/huggingface/transformers", exist_ok=True)
|
| 78 |
+
os.makedirs("/tmp/huggingface/datasets", exist_ok=True)
|
| 79 |
+
os.makedirs("/tmp/huggingface/hub", exist_ok=True)
|
| 80 |
+
|
| 81 |
+
# Mount static files for serving generated videos
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_video_url(output_path: str) -> str:
|
| 85 |
+
"""Convert local file path to accessible URL"""
|
| 86 |
+
try:
|
| 87 |
+
from pathlib import Path
|
| 88 |
+
filename = Path(output_path).name
|
| 89 |
+
|
| 90 |
+
# For HuggingFace Spaces, construct the URL
|
| 91 |
+
base_url = "https://bravedims-ai-avatar-chat.hf.space"
|
| 92 |
+
video_url = f"{base_url}/outputs/{filename}"
|
| 93 |
+
logger.info(f"Generated video URL: {video_url}")
|
| 94 |
+
return video_url
|
| 95 |
+
except Exception as e:
|
| 96 |
+
logger.error(f"Error creating video URL: {e}")
|
| 97 |
+
return output_path # Fallback to original path
|
| 98 |
+
|
| 99 |
+
# Pydantic models for request/response
|
| 100 |
+
class GenerateRequest(BaseModel):
|
| 101 |
+
prompt: str
|
| 102 |
+
text_to_speech: Optional[str] = None # Text to convert to speech
|
| 103 |
+
audio_url: Optional[HttpUrl] = None # Direct audio URL
|
| 104 |
+
voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM" # Voice profile ID
|
| 105 |
+
image_url: Optional[HttpUrl] = None
|
| 106 |
+
guidance_scale: float = 5.0
|
| 107 |
+
audio_scale: float = 3.0
|
| 108 |
+
num_steps: int = 30
|
| 109 |
+
sp_size: int = 1
|
| 110 |
+
tea_cache_l1_thresh: Optional[float] = None
|
| 111 |
+
|
| 112 |
+
class GenerateResponse(BaseModel):
|
| 113 |
+
message: str
|
| 114 |
+
output_path: str
|
| 115 |
+
processing_time: float
|
| 116 |
+
audio_generated: bool = False
|
| 117 |
+
tts_method: Optional[str] = None
|
| 118 |
+
|
| 119 |
+
# Try to import TTS clients, but make them optional
|
| 120 |
+
try:
|
| 121 |
+
from advanced_tts_client import AdvancedTTSClient
|
| 122 |
+
ADVANCED_TTS_AVAILABLE = True
|
| 123 |
+
logger.info("SUCCESS: Advanced TTS client available")
|
| 124 |
+
except ImportError as e:
|
| 125 |
+
ADVANCED_TTS_AVAILABLE = False
|
| 126 |
+
logger.warning(f"WARNING: Advanced TTS client not available: {e}")
|
| 127 |
+
|
| 128 |
+
# Always import the robust fallback
|
| 129 |
+
try:
|
| 130 |
+
from robust_tts_client import RobustTTSClient
|
| 131 |
+
ROBUST_TTS_AVAILABLE = True
|
| 132 |
+
logger.info("SUCCESS: Robust TTS client available")
|
| 133 |
+
except ImportError as e:
|
| 134 |
+
ROBUST_TTS_AVAILABLE = False
|
| 135 |
+
logger.error(f"ERROR: Robust TTS client not available: {e}")
|
| 136 |
+
|
| 137 |
+
class TTSManager:
|
| 138 |
+
"""Manages multiple TTS clients with fallback chain"""
|
| 139 |
+
|
| 140 |
+
def __init__(self):
|
| 141 |
+
# Initialize TTS clients based on availability
|
| 142 |
+
self.advanced_tts = None
|
| 143 |
+
self.robust_tts = None
|
| 144 |
+
self.clients_loaded = False
|
| 145 |
+
|
| 146 |
+
if ADVANCED_TTS_AVAILABLE:
|
| 147 |
+
try:
|
| 148 |
+
self.advanced_tts = AdvancedTTSClient()
|
| 149 |
+
logger.info("SUCCESS: Advanced TTS client initialized")
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")
|
| 152 |
+
|
| 153 |
+
if ROBUST_TTS_AVAILABLE:
|
| 154 |
+
try:
|
| 155 |
+
self.robust_tts = RobustTTSClient()
|
| 156 |
+
logger.info("SUCCESS: Robust TTS client initialized")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logger.error(f"ERROR: Robust TTS client initialization failed: {e}")
|
| 159 |
+
|
| 160 |
+
if not self.advanced_tts and not self.robust_tts:
|
| 161 |
+
logger.error("ERROR: No TTS clients available!")
|
| 162 |
+
|
| 163 |
+
async def load_models(self):
|
| 164 |
+
"""Load TTS models"""
|
| 165 |
+
try:
|
| 166 |
+
logger.info("Loading TTS models...")
|
| 167 |
+
|
| 168 |
+
# Try to load advanced TTS first
|
| 169 |
+
if self.advanced_tts:
|
| 170 |
+
try:
|
| 171 |
+
logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
|
| 172 |
+
success = await self.advanced_tts.load_models()
|
| 173 |
+
if success:
|
| 174 |
+
logger.info("SUCCESS: Advanced TTS models loaded successfully")
|
| 175 |
+
else:
|
| 176 |
+
logger.warning("WARNING: Advanced TTS models failed to load")
|
| 177 |
+
except Exception as e:
|
| 178 |
+
logger.warning(f"WARNING: Advanced TTS loading error: {e}")
|
| 179 |
+
|
| 180 |
+
# Always ensure robust TTS is available
|
| 181 |
+
if self.robust_tts:
|
| 182 |
+
try:
|
| 183 |
+
await self.robust_tts.load_model()
|
| 184 |
+
logger.info("SUCCESS: Robust TTS fallback ready")
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.error(f"ERROR: Robust TTS loading failed: {e}")
|
| 187 |
+
|
| 188 |
+
self.clients_loaded = True
|
| 189 |
+
return True
|
| 190 |
+
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.error(f"ERROR: TTS manager initialization failed: {e}")
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
|
| 196 |
+
"""
|
| 197 |
+
Convert text to speech with fallback chain
|
| 198 |
+
Returns: (audio_file_path, method_used)
|
| 199 |
+
"""
|
| 200 |
+
if not self.clients_loaded:
|
| 201 |
+
logger.info("TTS models not loaded, loading now...")
|
| 202 |
+
await self.load_models()
|
| 203 |
+
|
| 204 |
+
logger.info(f"Generating speech: {text[:50]}...")
|
| 205 |
+
logger.info(f"Voice ID: {voice_id}")
|
| 206 |
+
|
| 207 |
+
# Try Advanced TTS first (Facebook VITS / SpeechT5)
|
| 208 |
+
if self.advanced_tts:
|
| 209 |
+
try:
|
| 210 |
+
audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
|
| 211 |
+
return audio_path, "Facebook VITS/SpeechT5"
|
| 212 |
+
except Exception as advanced_error:
|
| 213 |
+
logger.warning(f"Advanced TTS failed: {advanced_error}")
|
| 214 |
+
|
| 215 |
+
# Fall back to robust TTS
|
| 216 |
+
if self.robust_tts:
|
| 217 |
+
try:
|
| 218 |
+
logger.info("Falling back to robust TTS...")
|
| 219 |
+
audio_path = await self.robust_tts.text_to_speech(text, voice_id)
|
| 220 |
+
return audio_path, "Robust TTS (Fallback)"
|
| 221 |
+
except Exception as robust_error:
|
| 222 |
+
logger.error(f"Robust TTS also failed: {robust_error}")
|
| 223 |
+
|
| 224 |
+
# If we get here, all methods failed
|
| 225 |
+
logger.error("All TTS methods failed!")
|
| 226 |
+
raise HTTPException(
|
| 227 |
+
status_code=500,
|
| 228 |
+
detail="All TTS methods failed. Please check system configuration."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
async def get_available_voices(self):
|
| 232 |
+
"""Get available voice configurations"""
|
| 233 |
+
try:
|
| 234 |
+
if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
|
| 235 |
+
return await self.advanced_tts.get_available_voices()
|
| 236 |
+
except:
|
| 237 |
+
pass
|
| 238 |
+
|
| 239 |
+
# Return default voices if advanced TTS not available
|
| 240 |
+
return {
|
| 241 |
+
"21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
|
| 242 |
+
"pNInz6obpgDQGcFmaJgB": "Male (Professional)",
|
| 243 |
+
"EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
|
| 244 |
+
"ErXwobaYiN019PkySvjV": "Male (Professional)",
|
| 245 |
+
"TxGEqnHWrfGW9XjX": "Male (Deep)",
|
| 246 |
+
"yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
|
| 247 |
+
"AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
def get_tts_info(self):
|
| 251 |
+
"""Get TTS system information"""
|
| 252 |
+
info = {
|
| 253 |
+
"clients_loaded": self.clients_loaded,
|
| 254 |
+
"advanced_tts_available": self.advanced_tts is not None,
|
| 255 |
+
"robust_tts_available": self.robust_tts is not None,
|
| 256 |
+
"primary_method": "Robust TTS"
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
|
| 261 |
+
advanced_info = self.advanced_tts.get_model_info()
|
| 262 |
+
info.update({
|
| 263 |
+
"advanced_tts_loaded": advanced_info.get("models_loaded", False),
|
| 264 |
+
"transformers_available": advanced_info.get("transformers_available", False),
|
| 265 |
+
"primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
|
| 266 |
+
"device": advanced_info.get("device", "cpu"),
|
| 267 |
+
"vits_available": advanced_info.get("vits_available", False),
|
| 268 |
+
"speecht5_available": advanced_info.get("speecht5_available", False)
|
| 269 |
+
})
|
| 270 |
+
except Exception as e:
|
| 271 |
+
logger.debug(f"Could not get advanced TTS info: {e}")
|
| 272 |
+
|
| 273 |
+
return info
|
| 274 |
+
|
| 275 |
+
# Import the VIDEO-FOCUSED engine
|
| 276 |
+
try:
|
| 277 |
+
from omniavatar_video_engine import video_engine
|
| 278 |
+
VIDEO_ENGINE_AVAILABLE = True
|
| 279 |
+
logger.info("SUCCESS: OmniAvatar Video Engine available")
|
| 280 |
+
except ImportError as e:
|
| 281 |
+
VIDEO_ENGINE_AVAILABLE = False
|
| 282 |
+
logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}")
|
| 283 |
+
|
| 284 |
+
class OmniAvatarAPI:
|
| 285 |
+
def __init__(self):
|
| 286 |
+
self.model_loaded = False
|
| 287 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 288 |
+
self.tts_manager = TTSManager()
|
| 289 |
+
logger.info(f"Using device: {self.device}")
|
| 290 |
+
logger.info("Initialized with robust TTS system")
|
| 291 |
+
|
| 292 |
+
def load_model(self):
|
| 293 |
+
"""Load the OmniAvatar model - now more flexible"""
|
| 294 |
+
try:
|
| 295 |
+
# Check if models are downloaded (but don't require them)
|
| 296 |
+
model_paths = [
|
| 297 |
+
"./pretrained_models/Wan2.1-T2V-14B",
|
| 298 |
+
"./pretrained_models/OmniAvatar-14B",
|
| 299 |
+
"./pretrained_models/wav2vec2-base-960h"
|
| 300 |
+
]
|
| 301 |
+
|
| 302 |
+
missing_models = []
|
| 303 |
+
for path in model_paths:
|
| 304 |
+
if not os.path.exists(path):
|
| 305 |
+
missing_models.append(path)
|
| 306 |
+
|
| 307 |
+
if missing_models:
|
| 308 |
+
logger.warning("WARNING: Some OmniAvatar models not found:")
|
| 309 |
+
for model in missing_models:
|
| 310 |
+
logger.warning(f" - {model}")
|
| 311 |
+
logger.info("TIP: App will run in TTS-only mode (no video generation)")
|
| 312 |
+
logger.info("TIP: To enable full avatar generation, download the required models")
|
| 313 |
+
|
| 314 |
+
# Set as loaded but in limited mode
|
| 315 |
+
self.model_loaded = False # Video generation disabled
|
| 316 |
+
return True # But app can still run
|
| 317 |
+
else:
|
| 318 |
+
self.model_loaded = True
|
| 319 |
+
logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
|
| 320 |
+
return True
|
| 321 |
+
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logger.error(f"Error checking models: {str(e)}")
|
| 324 |
+
logger.info("TIP: Continuing in TTS-only mode")
|
| 325 |
+
self.model_loaded = False
|
| 326 |
+
return True # Continue running
|
| 327 |
+
|
| 328 |
+
async def download_file(self, url: str, suffix: str = "") -> str:
|
| 329 |
+
"""Download file from URL and save to temporary location"""
|
| 330 |
+
try:
|
| 331 |
+
async with aiohttp.ClientSession() as session:
|
| 332 |
+
async with session.get(str(url)) as response:
|
| 333 |
+
if response.status != 200:
|
| 334 |
+
raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
|
| 335 |
+
|
| 336 |
+
content = await response.read()
|
| 337 |
+
|
| 338 |
+
# Create temporary file
|
| 339 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
| 340 |
+
temp_file.write(content)
|
| 341 |
+
temp_file.close()
|
| 342 |
+
|
| 343 |
+
return temp_file.name
|
| 344 |
+
|
| 345 |
+
except aiohttp.ClientError as e:
|
| 346 |
+
logger.error(f"Network error downloading {url}: {e}")
|
| 347 |
+
raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
|
| 348 |
+
except Exception as e:
|
| 349 |
+
logger.error(f"Error downloading file from {url}: {e}")
|
| 350 |
+
raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")
|
| 351 |
+
|
| 352 |
+
def validate_audio_url(self, url: str) -> bool:
|
| 353 |
+
"""Validate if URL is likely an audio file"""
|
| 354 |
+
try:
|
| 355 |
+
parsed = urlparse(url)
|
| 356 |
+
# Check for common audio file extensions
|
| 357 |
+
audio_extensions = ['.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac']
|
| 358 |
+
is_audio_ext = any(parsed.path.lower().endswith(ext) for ext in audio_extensions)
|
| 359 |
+
|
| 360 |
+
return is_audio_ext or 'audio' in url.lower()
|
| 361 |
+
except:
|
| 362 |
+
return False
|
| 363 |
+
|
| 364 |
+
def validate_image_url(self, url: str) -> bool:
|
| 365 |
+
"""Validate if URL is likely an image file"""
|
| 366 |
+
try:
|
| 367 |
+
parsed = urlparse(url)
|
| 368 |
+
image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']
|
| 369 |
+
return any(parsed.path.lower().endswith(ext) for ext in image_extensions)
|
| 370 |
+
except:
|
| 371 |
+
return False
|
| 372 |
+
|
| 373 |
+
async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
|
| 374 |
+
"""Generate avatar VIDEO - PRIMARY FUNCTIONALITY"""
|
| 375 |
+
import time
|
| 376 |
+
start_time = time.time()
|
| 377 |
+
audio_generated = False
|
| 378 |
+
method_used = "Unknown"
|
| 379 |
+
|
| 380 |
+
logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
|
| 381 |
+
logger.info(f"[INFO] Prompt: {request.prompt}")
|
| 382 |
+
|
| 383 |
+
if VIDEO_ENGINE_AVAILABLE:
|
| 384 |
+
try:
|
| 385 |
+
# PRIORITIZE VIDEO GENERATION
|
| 386 |
+
logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")
|
| 387 |
+
|
| 388 |
+
# Handle audio source
|
| 389 |
+
audio_path = None
|
| 390 |
+
if request.text_to_speech:
|
| 391 |
+
logger.info("[MIC] Generating audio from text...")
|
| 392 |
+
audio_path, method_used = await self.tts_manager.text_to_speech(
|
| 393 |
+
request.text_to_speech,
|
| 394 |
+
request.voice_id or "21m00Tcm4TlvDq8ikWAM"
|
| 395 |
+
)
|
| 396 |
+
audio_generated = True
|
| 397 |
+
elif request.audio_url:
|
| 398 |
+
logger.info("π₯ Downloading audio from URL...")
|
| 399 |
+
audio_path = await self.download_file(str(request.audio_url), ".mp3")
|
| 400 |
+
method_used = "External Audio"
|
| 401 |
+
else:
|
| 402 |
+
raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")
|
| 403 |
+
|
| 404 |
+
# Handle image if provided
|
| 405 |
+
image_path = None
|
| 406 |
+
if request.image_url:
|
| 407 |
+
logger.info("[IMAGE] Downloading reference image...")
|
| 408 |
+
parsed = urlparse(str(request.image_url))
|
| 409 |
+
ext = os.path.splitext(parsed.path)[1] or ".jpg"
|
| 410 |
+
image_path = await self.download_file(str(request.image_url), ext)
|
| 411 |
+
|
| 412 |
+
# GENERATE VIDEO using OmniAvatar engine
|
| 413 |
+
logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
|
| 414 |
+
video_path, generation_time = video_engine.generate_avatar_video(
|
| 415 |
+
prompt=request.prompt,
|
| 416 |
+
audio_path=audio_path,
|
| 417 |
+
image_path=image_path,
|
| 418 |
+
guidance_scale=request.guidance_scale,
|
| 419 |
+
audio_scale=request.audio_scale,
|
| 420 |
+
num_steps=request.num_steps
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
processing_time = time.time() - start_time
|
| 424 |
+
logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
|
| 425 |
+
|
| 426 |
+
# Cleanup temporary files
|
| 427 |
+
if audio_path and os.path.exists(audio_path):
|
| 428 |
+
os.unlink(audio_path)
|
| 429 |
+
if image_path and os.path.exists(image_path):
|
| 430 |
+
os.unlink(image_path)
|
| 431 |
+
|
| 432 |
+
return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
|
| 433 |
+
|
| 434 |
+
except Exception as e:
|
| 435 |
+
logger.error(f"ERROR: Video generation failed: {e}")
|
| 436 |
+
# For a VIDEO generation app, we should NOT fall back to audio-only
|
| 437 |
+
# Instead, provide clear guidance
|
| 438 |
+
if "models" in str(e).lower():
|
| 439 |
+
raise HTTPException(
|
| 440 |
+
status_code=503,
|
| 441 |
+
detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
|
| 442 |
+
)
|
| 443 |
+
else:
|
| 444 |
+
raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
|
| 445 |
+
|
| 446 |
+
# If video engine not available, this is a critical error for a VIDEO app
|
| 447 |
+
raise HTTPException(
|
| 448 |
+
status_code=503,
|
| 449 |
+
detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
|
| 453 |
+
"""OLD TTS-ONLY METHOD - kept as backup reference.
|
| 454 |
+
Generate avatar video from prompt and audio/text - now handles missing models"""
|
| 455 |
+
import time
|
| 456 |
+
start_time = time.time()
|
| 457 |
+
audio_generated = False
|
| 458 |
+
tts_method = None
|
| 459 |
+
|
| 460 |
+
try:
|
| 461 |
+
# Check if video generation is available
|
| 462 |
+
if not self.model_loaded:
|
| 463 |
+
logger.info("ποΈ Running in TTS-only mode (OmniAvatar models not available)")
|
| 464 |
+
|
| 465 |
+
# Only generate audio, no video
|
| 466 |
+
if request.text_to_speech:
|
| 467 |
+
logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
|
| 468 |
+
audio_path, tts_method = await self.tts_manager.text_to_speech(
|
| 469 |
+
request.text_to_speech,
|
| 470 |
+
request.voice_id or "21m00Tcm4TlvDq8ikWAM"
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
# Return the audio file as the "output"
|
| 474 |
+
processing_time = time.time() - start_time
|
| 475 |
+
logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
|
| 476 |
+
return audio_path, processing_time, True, f"{tts_method} (TTS-only mode)"
|
| 477 |
+
else:
|
| 478 |
+
raise HTTPException(
|
| 479 |
+
status_code=503,
|
| 480 |
+
detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Original video generation logic (when models are available)
|
| 484 |
+
# Determine audio source
|
| 485 |
+
audio_path = None
|
| 486 |
+
|
| 487 |
+
if request.text_to_speech:
|
| 488 |
+
# Generate speech from text using TTS manager
|
| 489 |
+
logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
|
| 490 |
+
audio_path, tts_method = await self.tts_manager.text_to_speech(
|
| 491 |
+
request.text_to_speech,
|
| 492 |
+
request.voice_id or "21m00Tcm4TlvDq8ikWAM"
|
| 493 |
+
)
|
| 494 |
+
audio_generated = True
|
| 495 |
+
|
| 496 |
+
elif request.audio_url:
|
| 497 |
+
# Download audio from provided URL
|
| 498 |
+
logger.info(f"Downloading audio from URL: {request.audio_url}")
|
| 499 |
+
if not self.validate_audio_url(str(request.audio_url)):
|
| 500 |
+
logger.warning(f"Audio URL may not be valid: {request.audio_url}")
|
| 501 |
+
|
| 502 |
+
audio_path = await self.download_file(str(request.audio_url), ".mp3")
|
| 503 |
+
tts_method = "External Audio URL"
|
| 504 |
+
|
| 505 |
+
else:
|
| 506 |
+
raise HTTPException(
|
| 507 |
+
status_code=400,
|
| 508 |
+
detail="Either text_to_speech or audio_url must be provided"
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Download image if provided
|
| 512 |
+
image_path = None
|
| 513 |
+
if request.image_url:
|
| 514 |
+
logger.info(f"Downloading image from URL: {request.image_url}")
|
| 515 |
+
if not self.validate_image_url(str(request.image_url)):
|
| 516 |
+
logger.warning(f"Image URL may not be valid: {request.image_url}")
|
| 517 |
+
|
| 518 |
+
# Determine image extension from URL or default to .jpg
|
| 519 |
+
parsed = urlparse(str(request.image_url))
|
| 520 |
+
ext = os.path.splitext(parsed.path)[1] or ".jpg"
|
| 521 |
+
image_path = await self.download_file(str(request.image_url), ext)
|
| 522 |
+
|
| 523 |
+
# Create temporary input file for inference
|
| 524 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
| 525 |
+
if image_path:
|
| 526 |
+
input_line = f"{request.prompt}@@{image_path}@@{audio_path}"
|
| 527 |
+
else:
|
| 528 |
+
input_line = f"{request.prompt}@@@@{audio_path}"
|
| 529 |
+
f.write(input_line)
|
| 530 |
+
temp_input_file = f.name
|
| 531 |
+
|
| 532 |
+
# Prepare inference command
|
| 533 |
+
cmd = [
|
| 534 |
+
"python", "-m", "torch.distributed.run",
|
| 535 |
+
"--standalone", f"--nproc_per_node={request.sp_size}",
|
| 536 |
+
"scripts/inference.py",
|
| 537 |
+
"--config", "configs/inference.yaml",
|
| 538 |
+
"--input_file", temp_input_file,
|
| 539 |
+
"--guidance_scale", str(request.guidance_scale),
|
| 540 |
+
"--audio_scale", str(request.audio_scale),
|
| 541 |
+
"--num_steps", str(request.num_steps)
|
| 542 |
+
]
|
| 543 |
+
|
| 544 |
+
if request.tea_cache_l1_thresh:
|
| 545 |
+
cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])
|
| 546 |
+
|
| 547 |
+
logger.info(f"Running inference with command: {' '.join(cmd)}")
|
| 548 |
+
|
| 549 |
+
# Run inference
|
| 550 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 551 |
+
|
| 552 |
+
# Clean up temporary files
|
| 553 |
+
os.unlink(temp_input_file)
|
| 554 |
+
os.unlink(audio_path)
|
| 555 |
+
if image_path:
|
| 556 |
+
os.unlink(image_path)
|
| 557 |
+
|
| 558 |
+
if result.returncode != 0:
|
| 559 |
+
logger.error(f"Inference failed: {result.stderr}")
|
| 560 |
+
raise Exception(f"Inference failed: {result.stderr}")
|
| 561 |
+
|
| 562 |
+
# Find output video file
|
| 563 |
+
output_dir = "./outputs"
|
| 564 |
+
if os.path.exists(output_dir):
|
| 565 |
+
video_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.avi'))]
|
| 566 |
+
if video_files:
|
| 567 |
+
# Return the most recent video file
|
| 568 |
+
video_files.sort(key=lambda x: os.path.getmtime(os.path.join(output_dir, x)), reverse=True)
|
| 569 |
+
output_path = os.path.join(output_dir, video_files[0])
|
| 570 |
+
processing_time = time.time() - start_time
|
| 571 |
+
return output_path, processing_time, audio_generated, tts_method
|
| 572 |
+
|
| 573 |
+
raise Exception("No output video generated")
|
| 574 |
+
|
| 575 |
+
except Exception as e:
|
| 576 |
+
# Clean up any temporary files in case of error
|
| 577 |
+
try:
|
| 578 |
+
if 'audio_path' in locals() and audio_path and os.path.exists(audio_path):
|
| 579 |
+
os.unlink(audio_path)
|
| 580 |
+
if 'image_path' in locals() and image_path and os.path.exists(image_path):
|
| 581 |
+
os.unlink(image_path)
|
| 582 |
+
if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
|
| 583 |
+
os.unlink(temp_input_file)
|
| 584 |
+
except:
|
| 585 |
+
pass
|
| 586 |
+
|
| 587 |
+
logger.error(f"Generation error: {str(e)}")
|
| 588 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 589 |
+
|
| 590 |
+
# Initialize API
|
| 591 |
+
omni_api = OmniAvatarAPI()
|
| 592 |
+
|
| 593 |
+
# Use FastAPI lifespan instead of deprecated on_event
|
| 594 |
+
from contextlib import asynccontextmanager
|
| 595 |
+
|
| 596 |
+
@asynccontextmanager
|
| 597 |
+
async def lifespan(app: FastAPI):
|
| 598 |
+
# Startup
|
| 599 |
+
success = omni_api.load_model()
|
| 600 |
+
if not success:
|
| 601 |
+
logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")
|
| 602 |
+
|
| 603 |
+
# Load TTS models
|
| 604 |
+
try:
|
| 605 |
+
await omni_api.tts_manager.load_models()
|
| 606 |
+
logger.info("SUCCESS: TTS models initialization completed")
|
| 607 |
+
except Exception as e:
|
| 608 |
+
logger.error(f"ERROR: TTS initialization failed: {e}")
|
| 609 |
+
|
| 610 |
+
yield
|
| 611 |
+
|
| 612 |
+
# Shutdown (if needed)
|
| 613 |
+
logger.info("Application shutting down...")
|
| 614 |
+
|
| 615 |
+
# Create FastAPI app WITH lifespan parameter
|
| 616 |
+
app = FastAPI(
|
| 617 |
+
title="OmniAvatar-14B API with Advanced TTS",
|
| 618 |
+
version="1.0.0",
|
| 619 |
+
lifespan=lifespan
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# Add CORS middleware
|
| 623 |
+
app.add_middleware(
|
| 624 |
+
CORSMiddleware,
|
| 625 |
+
allow_origins=["*"],
|
| 626 |
+
allow_credentials=True,
|
| 627 |
+
allow_methods=["*"],
|
| 628 |
+
allow_headers=["*"],
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Mount static files for serving generated videos
|
| 632 |
+
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
| 633 |
+
|
| 634 |
+
@app.get("/health")
|
| 635 |
+
async def health_check():
|
| 636 |
+
"""Health check endpoint"""
|
| 637 |
+
tts_info = omni_api.tts_manager.get_tts_info()
|
| 638 |
+
|
| 639 |
+
return {
|
| 640 |
+
"status": "healthy",
|
| 641 |
+
"model_loaded": omni_api.model_loaded,
|
| 642 |
+
"video_generation_available": omni_api.model_loaded,
|
| 643 |
+
"tts_only_mode": not omni_api.model_loaded,
|
| 644 |
+
"device": omni_api.device,
|
| 645 |
+
"supports_text_to_speech": True,
|
| 646 |
+
"supports_image_urls": omni_api.model_loaded,
|
| 647 |
+
"supports_audio_urls": omni_api.model_loaded,
|
| 648 |
+
"tts_system": "Advanced TTS with Robust Fallback",
|
| 649 |
+
"advanced_tts_available": ADVANCED_TTS_AVAILABLE,
|
| 650 |
+
"robust_tts_available": ROBUST_TTS_AVAILABLE,
|
| 651 |
+
**tts_info
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
@app.get("/voices")
|
| 655 |
+
async def get_voices():
|
| 656 |
+
"""Get available voice configurations"""
|
| 657 |
+
try:
|
| 658 |
+
voices = await omni_api.tts_manager.get_available_voices()
|
| 659 |
+
return {"voices": voices}
|
| 660 |
+
except Exception as e:
|
| 661 |
+
logger.error(f"Error getting voices: {e}")
|
| 662 |
+
return {"error": str(e)}
|
| 663 |
+
|
| 664 |
+
@app.post("/generate", response_model=GenerateResponse)
|
| 665 |
+
async def generate_avatar(request: GenerateRequest):
|
| 666 |
+
"""Generate avatar video from prompt, text/audio, and optional image URL"""
|
| 667 |
+
|
| 668 |
+
logger.info(f"Generating avatar with prompt: {request.prompt}")
|
| 669 |
+
if request.text_to_speech:
|
| 670 |
+
logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
|
| 671 |
+
logger.info(f"Voice ID: {request.voice_id}")
|
| 672 |
+
if request.audio_url:
|
| 673 |
+
logger.info(f"Audio URL: {request.audio_url}")
|
| 674 |
+
if request.image_url:
|
| 675 |
+
logger.info(f"Image URL: {request.image_url}")
|
| 676 |
+
|
| 677 |
+
try:
|
| 678 |
+
output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
|
| 679 |
+
|
| 680 |
+
return GenerateResponse(
|
| 681 |
+
message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
|
| 682 |
+
output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
|
| 683 |
+
processing_time=processing_time,
|
| 684 |
+
audio_generated=audio_generated,
|
| 685 |
+
tts_method=tts_method
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
except HTTPException:
|
| 689 |
+
raise
|
| 690 |
+
except Exception as e:
|
| 691 |
+
logger.error(f"Unexpected error: {e}")
|
| 692 |
+
raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
|
| 693 |
+
|
| 694 |
+
@app.post("/download-models")
|
| 695 |
+
async def download_video_models():
|
| 696 |
+
"""Manually trigger video model downloads"""
|
| 697 |
+
logger.info("?? Manual model download requested...")
|
| 698 |
+
|
| 699 |
+
try:
|
| 700 |
+
from huggingface_hub import snapshot_download
|
| 701 |
+
import shutil
|
| 702 |
+
|
| 703 |
+
# Check storage first
|
| 704 |
+
_, _, free_bytes = shutil.disk_usage(".")
|
| 705 |
+
free_gb = free_bytes / (1024**3)
|
| 706 |
+
|
| 707 |
+
logger.info(f"?? Available storage: {free_gb:.1f}GB")
|
| 708 |
+
|
| 709 |
+
if free_gb < 10: # Need at least 10GB free
|
| 710 |
+
return {
|
| 711 |
+
"success": False,
|
| 712 |
+
"message": f"Insufficient storage: {free_gb:.1f}GB available, 10GB+ required",
|
| 713 |
+
"storage_gb": free_gb
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
# Download small video generation model
|
| 717 |
+
logger.info("?? Downloading text-to-video model...")
|
| 718 |
+
|
| 719 |
+
model_path = snapshot_download(
|
| 720 |
+
repo_id="ali-vilab/text-to-video-ms-1.7b",
|
| 721 |
+
cache_dir="./downloaded_models/video",
|
| 722 |
+
local_files_only=False
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
logger.info(f"? Video model downloaded: {model_path}")
|
| 726 |
+
|
| 727 |
+
# Download audio model
|
| 728 |
+
audio_model_path = snapshot_download(
|
| 729 |
+
repo_id="facebook/wav2vec2-base-960h",
|
| 730 |
+
cache_dir="./downloaded_models/audio",
|
| 731 |
+
local_files_only=False
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
logger.info(f"? Audio model downloaded: {audio_model_path}")
|
| 735 |
+
|
| 736 |
+
# Check final storage usage
|
| 737 |
+
_, _, free_bytes_after = shutil.disk_usage(".")
|
| 738 |
+
free_gb_after = free_bytes_after / (1024**3)
|
| 739 |
+
used_gb = free_gb - free_gb_after
|
| 740 |
+
|
| 741 |
+
return {
|
| 742 |
+
"success": True,
|
| 743 |
+
"message": "? Video generation models downloaded successfully!",
|
| 744 |
+
"models_downloaded": [
|
| 745 |
+
"ali-vilab/text-to-video-ms-1.7b",
|
| 746 |
+
"facebook/wav2vec2-base-960h"
|
| 747 |
+
],
|
| 748 |
+
"storage_used_gb": round(used_gb, 2),
|
| 749 |
+
"storage_remaining_gb": round(free_gb_after, 2),
|
| 750 |
+
"video_model_path": model_path,
|
| 751 |
+
"audio_model_path": audio_model_path,
|
| 752 |
+
"status": "READY FOR VIDEO GENERATION"
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
except Exception as e:
|
| 756 |
+
logger.error(f"? Model download failed: {e}")
|
| 757 |
+
return {
|
| 758 |
+
"success": False,
|
| 759 |
+
"message": f"Model download failed: {str(e)}",
|
| 760 |
+
"error": str(e)
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
@app.get("/model-status")
|
| 764 |
+
async def get_model_status():
|
| 765 |
+
"""Check status of downloaded models"""
|
| 766 |
+
try:
|
| 767 |
+
models_dir = Path("./downloaded_models")
|
| 768 |
+
|
| 769 |
+
status = {
|
| 770 |
+
"models_downloaded": models_dir.exists(),
|
| 771 |
+
"available_models": [],
|
| 772 |
+
"storage_info": {}
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
if models_dir.exists():
|
| 776 |
+
for model_dir in models_dir.iterdir():
|
| 777 |
+
if model_dir.is_dir():
|
| 778 |
+
status["available_models"].append({
|
| 779 |
+
"name": model_dir.name,
|
| 780 |
+
"path": str(model_dir),
|
| 781 |
+
"files": len(list(model_dir.rglob("*")))
|
| 782 |
+
})
|
| 783 |
+
|
| 784 |
+
# Storage info
|
| 785 |
+
import shutil
|
| 786 |
+
_, _, free_bytes = shutil.disk_usage(".")
|
| 787 |
+
status["storage_info"] = {
|
| 788 |
+
"free_gb": round(free_bytes / (1024**3), 2),
|
| 789 |
+
"models_dir_exists": models_dir.exists()
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
return status
|
| 793 |
+
|
| 794 |
+
except Exception as e:
|
| 795 |
+
return {"error": str(e)}
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# Enhanced Gradio interface
|
| 799 |
+
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
|
| 800 |
+
"""Gradio interface wrapper with robust TTS support"""
|
| 801 |
+
try:
|
| 802 |
+
# Create request object
|
| 803 |
+
request_data = {
|
| 804 |
+
"prompt": prompt,
|
| 805 |
+
"guidance_scale": guidance_scale,
|
| 806 |
+
"audio_scale": audio_scale,
|
| 807 |
+
"num_steps": int(num_steps)
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
# Add audio source
|
| 811 |
+
if text_to_speech and text_to_speech.strip():
|
| 812 |
+
request_data["text_to_speech"] = text_to_speech
|
| 813 |
+
request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
|
| 814 |
+
elif audio_url and audio_url.strip():
|
| 815 |
+
if omni_api.model_loaded:
|
| 816 |
+
request_data["audio_url"] = audio_url
|
| 817 |
+
else:
|
| 818 |
+
return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
|
| 819 |
+
else:
|
| 820 |
+
return "Error: Please provide either text to speech or audio URL"
|
| 821 |
+
|
| 822 |
+
if image_url and image_url.strip():
|
| 823 |
+
if omni_api.model_loaded:
|
| 824 |
+
request_data["image_url"] = image_url
|
| 825 |
+
else:
|
| 826 |
+
return "Error: Image URL input requires full OmniAvatar models for video generation."
|
| 827 |
+
|
| 828 |
+
request = GenerateRequest(**request_data)
|
| 829 |
+
|
| 830 |
+
# Run async function in sync context
|
| 831 |
+
loop = asyncio.new_event_loop()
|
| 832 |
+
asyncio.set_event_loop(loop)
|
| 833 |
+
output_path, processing_time, audio_generated, tts_method = loop.run_until_complete(omni_api.generate_avatar(request))
|
| 834 |
+
loop.close()
|
| 835 |
+
|
| 836 |
+
success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
|
| 837 |
+
print(success_message)
|
| 838 |
+
|
| 839 |
+
if omni_api.model_loaded:
|
| 840 |
+
return output_path
|
| 841 |
+
else:
|
| 842 |
+
return f"ποΈ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
|
| 843 |
+
|
| 844 |
+
except Exception as e:
|
| 845 |
+
logger.error(f"Gradio generation error: {e}")
|
| 846 |
+
return f"Error: {str(e)}"
|
| 847 |
+
|
| 848 |
+
# Create Gradio interface
|
| 849 |
+
mode_info = " (TTS-Only Mode)" if not omni_api.model_loaded else ""
|
| 850 |
+
description_extra = """
|
| 851 |
+
WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
|
| 852 |
+
To enable full video generation, the required model files need to be downloaded.
|
| 853 |
+
""" if not omni_api.model_loaded else ""
|
| 854 |
+
|
| 855 |
+
iface = gr.Interface(
|
| 856 |
+
fn=gradio_generate,
|
| 857 |
+
inputs=[
|
| 858 |
+
gr.Textbox(
|
| 859 |
+
label="Prompt",
|
| 860 |
+
placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
|
| 861 |
+
lines=2
|
| 862 |
+
),
|
| 863 |
+
gr.Textbox(
|
| 864 |
+
label="Text to Speech",
|
| 865 |
+
placeholder="Enter text to convert to speech",
|
| 866 |
+
lines=3,
|
| 867 |
+
info="Will use best available TTS system (Advanced or Fallback)"
|
| 868 |
+
),
|
| 869 |
+
gr.Textbox(
|
| 870 |
+
label="OR Audio URL",
|
| 871 |
+
placeholder="https://example.com/audio.mp3",
|
| 872 |
+
info="Direct URL to audio file (requires full models)" if not omni_api.model_loaded else "Direct URL to audio file"
|
| 873 |
+
),
|
| 874 |
+
gr.Textbox(
|
| 875 |
+
label="Image URL (Optional)",
|
| 876 |
+
placeholder="https://example.com/image.jpg",
|
| 877 |
+
info="Direct URL to reference image (requires full models)" if not omni_api.model_loaded else "Direct URL to reference image"
|
| 878 |
+
),
|
| 879 |
+
gr.Dropdown(
|
| 880 |
+
choices=[
|
| 881 |
+
"21m00Tcm4TlvDq8ikWAM",
|
| 882 |
+
"pNInz6obpgDQGcFmaJgB",
|
| 883 |
+
"EXAVITQu4vr4xnSDxMaL",
|
| 884 |
+
"ErXwobaYiN019PkySvjV",
|
| 885 |
+
"TxGEqnHWrfGW9XjX",
|
| 886 |
+
"yoZ06aMxZJJ28mfd3POQ",
|
| 887 |
+
"AZnzlk1XvdvUeBnXmlld"
|
| 888 |
+
],
|
| 889 |
+
value="21m00Tcm4TlvDq8ikWAM",
|
| 890 |
+
label="Voice Profile",
|
| 891 |
+
info="Choose voice characteristics for TTS generation"
|
| 892 |
+
),
|
| 893 |
+
gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended"),
|
| 894 |
+
gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync"),
|
| 895 |
+
gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")
|
| 896 |
+
],
|
| 897 |
+
outputs=gr.Video(label="Generated Avatar Video") if omni_api.model_loaded else gr.Textbox(label="TTS Output"),
|
| 898 |
+
title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
|
| 899 |
+
description=f"""
|
| 900 |
+
Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
|
| 901 |
+
|
| 902 |
+
{description_extra}
|
| 903 |
+
|
| 904 |
+
**Robust TTS Architecture**
|
| 905 |
+
- **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
|
| 906 |
+
- **Fallback**: Robust tone generation for 100% reliability
|
| 907 |
+
- **Automatic**: Seamless switching between methods
|
| 908 |
+
|
| 909 |
+
**Features:**
|
| 910 |
+
- **Guaranteed Generation**: Always produces audio output
|
| 911 |
+
- **No Dependencies**: Works even without advanced models
|
| 912 |
+
- **High Availability**: Multiple fallback layers
|
| 913 |
+
- **Voice Profiles**: Multiple voice characteristics
|
| 914 |
+
- **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
|
| 915 |
+
- **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
|
| 916 |
+
|
| 917 |
+
**Usage:**
|
| 918 |
+
1. Enter a character description in the prompt
|
| 919 |
+
2. **Enter text for speech generation** (recommended in current mode)
|
| 920 |
+
3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
|
| 921 |
+
4. Choose voice profile and adjust parameters
|
| 922 |
+
5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
|
| 923 |
+
""",
|
| 924 |
+
examples=[
|
| 925 |
+
[
|
| 926 |
+
"A professional teacher explaining a mathematical concept with clear gestures",
|
| 927 |
+
"Hello students! Today we're going to learn about calculus and derivatives.",
|
| 928 |
+
"",
|
| 929 |
+
"",
|
| 930 |
+
"21m00Tcm4TlvDq8ikWAM",
|
| 931 |
+
5.0,
|
| 932 |
+
3.5,
|
| 933 |
+
30
|
| 934 |
+
],
|
| 935 |
+
[
|
| 936 |
+
"A friendly presenter speaking confidently to an audience",
|
| 937 |
+
"Welcome everyone to our presentation on artificial intelligence!",
|
| 938 |
+
"",
|
| 939 |
+
"",
|
| 940 |
+
"pNInz6obpgDQGcFmaJgB",
|
| 941 |
+
5.5,
|
| 942 |
+
4.0,
|
| 943 |
+
35
|
| 944 |
+
]
|
| 945 |
+
],
|
| 946 |
+
allow_flagging="never",
|
| 947 |
+
flagging_dir="/tmp/gradio_flagged"
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
# Mount Gradio app
|
| 951 |
+
app = gr.mount_gradio_app(app, iface, path="/gradio")
|
| 952 |
+
|
| 953 |
+
# Add this section near the end of app.py, before the main block
|
| 954 |
+
|
| 955 |
+
# AUTO-DOWNLOAD MODELS ON STARTUP
|
| 956 |
+
async def startup_model_download():
|
| 957 |
+
"""Automatically download models on startup"""
|
| 958 |
+
import asyncio
|
| 959 |
+
|
| 960 |
+
logger.info("?? Starting automatic model download...")
|
| 961 |
+
|
| 962 |
+
# Wait for app to fully initialize
|
| 963 |
+
await asyncio.sleep(5)
|
| 964 |
+
|
| 965 |
+
try:
|
| 966 |
+
models_dir = Path("./downloaded_models")
|
| 967 |
+
|
| 968 |
+
# Check if models already exist
|
| 969 |
+
if models_dir.exists() and any(models_dir.iterdir()):
|
| 970 |
+
logger.info("? Models directory exists, checking contents...")
|
| 971 |
+
return
|
| 972 |
+
|
| 973 |
+
# Check storage
|
| 974 |
+
import shutil
|
| 975 |
+
_, _, free_bytes = shutil.disk_usage(".")
|
| 976 |
+
free_gb = free_bytes / (1024**3)
|
| 977 |
+
|
| 978 |
+
logger.info(f"?? Free storage: {free_gb:.1f}GB")
|
| 979 |
+
|
| 980 |
+
if free_gb > 10: # Need at least 10GB
|
| 981 |
+
logger.info("?? AUTO-DOWNLOADING models (sufficient storage available)...")
|
| 982 |
+
|
| 983 |
+
from huggingface_hub import snapshot_download
|
| 984 |
+
|
| 985 |
+
# Download video model
|
| 986 |
+
logger.info("?? Downloading text-to-video model...")
|
| 987 |
+
video_path = snapshot_download(
|
| 988 |
+
repo_id="ali-vilab/text-to-video-ms-1.7b",
|
| 989 |
+
cache_dir="./downloaded_models/video"
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
# Download audio model
|
| 993 |
+
logger.info("?? Downloading audio model...")
|
| 994 |
+
audio_path = snapshot_download(
|
| 995 |
+
repo_id="facebook/wav2vec2-base-960h",
|
| 996 |
+
cache_dir="./downloaded_models/audio"
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
# Create success marker
|
| 1000 |
+
success_file = models_dir / "auto_download_success.txt"
|
| 1001 |
+
with open(success_file, "w") as f:
|
| 1002 |
+
f.write(f"AUTO DOWNLOAD COMPLETED: {time.time()}\n")
|
| 1003 |
+
f.write(f"Video model: {video_path}\n")
|
| 1004 |
+
f.write(f"Audio model: {audio_path}\n")
|
| 1005 |
+
f.write("STATUS: SUCCESS\n")
|
| 1006 |
+
|
| 1007 |
+
logger.info("? AUTO-DOWNLOAD completed successfully!")
|
| 1008 |
+
|
| 1009 |
+
else:
|
| 1010 |
+
logger.warning(f"?? Insufficient storage for auto-download: {free_gb:.1f}GB")
|
| 1011 |
+
|
| 1012 |
+
except Exception as e:
|
| 1013 |
+
logger.error(f"? Auto-download failed: {e}")
|
| 1014 |
+
|
| 1015 |
+
# Start auto-download in background
|
| 1016 |
+
import asyncio
|
| 1017 |
+
import threading
|
| 1018 |
+
|
| 1019 |
+
def run_auto_download():
|
| 1020 |
+
"""Run auto-download in background thread"""
|
| 1021 |
+
try:
|
| 1022 |
+
loop = asyncio.new_event_loop()
|
| 1023 |
+
asyncio.set_event_loop(loop)
|
| 1024 |
+
loop.run_until_complete(startup_model_download())
|
| 1025 |
+
except Exception as e:
|
| 1026 |
+
logger.error(f"? Background download failed: {e}")
|
| 1027 |
+
|
| 1028 |
+
# Start download thread
|
| 1029 |
+
download_thread = threading.Thread(target=run_auto_download, daemon=True)
|
| 1030 |
+
download_thread.start()
|
| 1031 |
+
logger.info("?? Model auto-download thread started")
|
| 1032 |
+
|
| 1033 |
+
if __name__ == "__main__":
|
| 1034 |
+
import uvicorn
|
| 1035 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
|