Tacoswithhorchata commited on
Commit
14d3449
·
0 Parent(s):

Initial commit with essential files

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [theme]
2
+ base="dark"
3
+ primaryColor="#865bf1"
4
+ font="monospace"
README.md ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Lip Sync
2
+
3
+ ![Screenshot 2024-01-22 at 03-03-09 app · Streamlit](https://github.com/Aml-Hassan-Abd-El-hamid/AI-Lip-Sync/assets/66205928/d35f379f-f1ca-46e5-a113-bfa3f3c4c2f9)
4
+
5
+ An AI-powered application that synchronizes lip movements with audio input, built with Wav2Lip and Streamlit.
6
+
7
+ ## Features
8
+
9
+ - **Multiple Avatar Options**: Choose from built-in avatars or upload your own image/video
10
+ - **Audio Input Flexibility**: Record audio directly or upload WAV/MP3 files
11
+ - **Quality Assessment**: Automatic analysis of video and audio quality with recommendations
12
+ - **GPU Acceleration**: Optimized for Apple Silicon (M1/M2) GPUs
13
+ - **Two Animation Modes**: Fast (lips only) or Slow (full face animation)
14
+ - **Video Trimming**: Trim the output video to remove unwanted portions
15
+
16
+ ## Quick Setup Guide
17
+
18
+ ### Prerequisites
19
+
20
+ - Python 3.9+
21
+ - ffmpeg (for audio processing)
22
+ - Git LFS (optional, for handling large model files)
23
+
24
+ ### Installation
25
+
26
+ 1. Clone the repository:
27
+ ```bash
28
+ git clone https://github.com/yourusername/ai-lip-sync-app.git
29
+ cd ai-lip-sync-app
30
+ ```
31
+
32
+ 2. Create and activate a virtual environment:
33
+ ```bash
34
+ python -m venv .venv
35
+ # On macOS/Linux
36
+ source .venv/bin/activate
37
+ # On Windows
38
+ .venv\Scripts\activate
39
+ ```
40
+
41
+ 3. Install Python dependencies:
42
+ ```bash
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ 4. Install system dependencies:
47
+ ```bash
48
+ # On Ubuntu/Debian
49
+ sudo apt-get update
50
+ sudo apt-get install $(cat packages.txt)
51
+
52
+ # On macOS with Homebrew
53
+ brew install ffmpeg
54
+ ```
55
+
56
+ 5. Run the application:
57
+ ```bash
58
+ python -m streamlit run app.py
59
+ ```
60
+
61
+ > **Note**: If you encounter a "streamlit: command not found" error, always use `python -m streamlit run app.py` instead of `streamlit run app.py`
62
+
63
+ The application will automatically download the required model files on first run.
64
+
65
+ ## Usage Guide
66
+
67
+ 1. **Choose Avatar Source**:
68
+ - Select from built-in avatars or upload your own image/video
69
+ - For best results, use clear frontal face images/videos
70
+
71
+ 2. **Provide Audio**:
72
+ - Record directly using your microphone
73
+ - Upload WAV or MP3 files
74
+
75
+ 3. **Quality Assessment**:
76
+ - The app will automatically analyze your uploaded video and audio
77
+ - Review the quality analysis and recommendations
78
+ - Make adjustments if needed for better results
79
+
80
+ 4. **Generate Animation**:
81
+ - Choose "Fast animate" for quicker processing (lips only)
82
+ - Choose "Slower animate" for more realistic results (full face)
83
+
84
+ 5. **View and Edit Results**:
85
+ - The generated video will appear in the app
86
+ - Use the trim feature to remove unwanted portions from the start or end
87
+ - Download the original or trimmed version to your computer
88
+
89
+ ## Video Trimming Feature
90
+
91
+ The app now includes a video trimming capability:
92
+
93
+ - After generating a lip-sync video, you'll see trimming options below the result
94
+ - Use the sliders to select the start and end times for your trimmed video
95
+ - Click "Trim Video" to create a shortened version
96
+ - Both original and trimmed videos can be downloaded directly from the app
97
+
98
+ ## Quality Assessment Feature
99
+
100
+ The app now includes automatic quality assessment for uploaded videos and audio:
101
+
102
+ ### Video Analysis:
103
+ - Resolution check (higher resolution = better results)
104
+ - Face detection (confirms a face is present and properly sized)
105
+ - Frame rate analysis
106
+ - Overall quality score with specific recommendations
107
+
108
+ ### Audio Analysis:
109
+ - Speech detection (confirms speech is present)
110
+ - Volume level assessment
111
+ - Silence detection
112
+ - Overall quality score with specific recommendations
113
+
114
+ ## Troubleshooting
115
+
116
+ - **"No face detected" error**: Ensure your video has a clear, well-lit frontal face
117
+ - **Poor lip sync results**: Try using higher quality audio with clear speech
118
+ - **Performance issues**: For large videos, try the "Fast animate" option or use a smaller video clip
119
+ - **Memory errors**: Close other applications to free up memory, or use a machine with more RAM
120
+
121
+ ## Technical Details
122
+
123
+ The project is built on the Wav2Lip model with several optimizations:
124
+ - Apple Silicon (M1/M2) GPU acceleration using MPS backend
125
+ - Automatic video resolution scaling for large videos
126
+ - Memory optimizations for processing longer videos
127
+ - Quality assessment using OpenCV and librosa
128
+
129
+ ## Original Project Background
130
+
131
+ The project started as a part of an interview process with some company, I received an email with the following task:
132
+
133
+ Assignment Object:<br>
134
+ &emsp;&emsp;&emsp;&emsp;Your task is to develop a lip-syncing model using machine learning
135
+ techniques. It takes an input image and audio and then generates a video
136
+ where the image appears to lip sync with the provided audio. You have to
137
+ develop this task using python3.
138
+
139
+ Requirements:<br>
140
+ &emsp;&emsp;&emsp;&emsp;● Avatar / Image : Get one AI-generated avatar, the avatar may be for a<br>
141
+ &emsp;&emsp;&emsp;&emsp;man, woman, old man, old lady or a child. Ensure that the avatar is<br>
142
+ &emsp;&emsp;&emsp;&emsp;created by artificial intelligence and does not represent real<br>
143
+ &emsp;&emsp;&emsp;&emsp;individuals.<br>
144
+ &emsp;&emsp;&emsp;&emsp;● Audio : Provide two distinct and clear audio recordings—one in Arabic<br>
145
+ &emsp;&emsp;&emsp;&emsp;and the other in English. The duration of each audio clip should be<br>
146
+ &emsp;&emsp;&emsp;&emsp;no less than 30 seconds and no more than 1 minute.<br>
147
+ &emsp;&emsp;&emsp;&emsp;● Lip-sync model: Develop a lip-syncing model to synchronise the lip<br>
148
+ &emsp;&emsp;&emsp;&emsp;movements of the chosen avatar with the provided audio. Ensure the<br>
149
+ &emsp;&emsp;&emsp;&emsp;model demonstrates proficiency in accurately aligning lip motions<br>
150
+ &emsp;&emsp;&emsp;&emsp;with the spoken words in both Arabic and English.<br>
151
+ &emsp;&emsp;&emsp;&emsp;Hint : You can refer to state of the art models in lip-syncing.<br>
152
+
153
+ I was given about 96 hours to accomplish this task, I spent the first 12 hours sick with a very bad flu and no proper internet connection so I had 84 hours!<br>
154
+ After submitting the task on time, I took more time to deploy the project on Streamlight, as I thought it was a fun project and would be a nice addition to my CV:)
155
+
156
+ Given the provided hint from the company, "You can refer to state-of-the-art models in lip-syncing.", I started looking into the available open-source pre-trained model that can accomplish this task and most available resources pointed towards **Wav2Lip**. I found a couple of interesting tutorials for that model that I will share below.
157
+
158
+ ### How to run the application locally:<br>
159
+
160
+ 1- clone the repo to your local machine.<br>
161
+ 2- open your terminal inside the project folder and run the following command: `pip install -r requirements.txt` and then run this command `sudo xargs -a packages.txt apt-get install` to install the needed modules and packages.<br>
162
+ 3- open your terminal inside the project folder and run the following command: `streamlit run app.py` to run the streamlit application.<br>
163
+
164
+ ### Things I changed in the wav2lip and why:<br>
165
+
166
+ In order to work with and deploy the wav2lip model I had to make the following changes:<br>
167
+ 1- Changed the `_build_mel_basis()` function in `audio.py`, I had to do that to be able to work with `librosa>=0.10.0` package, check this [issue](https://github.com/Rudrabha/Wav2Lip/issues/550) for more details.<br>
168
+ 2- Changed the `main()` function at the `inferance.py` to directly take an output from the `app.py` instead of using the command line arguments.<br>
169
+ 3- I took the `load_model(path)` function and added it to `app.py` and added `@st.cache_data` in order to only load the model once, instead of using it multiple times, I also modified it<br>
170
+ 4- Deleted the unnecessary files like the checkpoints to make the Streamlit website deployment easier.<br>
171
+ 5- Since I'm using Streamlit for deployment and Streamlit Cloud doesn't support GPU, I had to change the device to work with `cpu` instead of `cuda`.<br>
172
+ 6- I made other minor changes like changing the path to a file or modifying import statements.
173
+
174
+ ### Issues I had with Streamlit, during the deployment:
175
+
176
+ This part is a documentation for me, just in case, I need to face an issue in the future and also could be helpful for any poor soul who would have to work with Streamlit:
177
+
178
+ 1-
179
+ ```
180
+ Error downloading object: wav2lip/checkpoints/wav2lip_gan.pth (ca9ab7b): Smudge error: Error downloading wav2lip/checkpoints/wav2lip_gan.pth (ca9ab7b7b812c0e80a6e70a5977c545a1e8a365a6c49d5e533023c034d7ac3d8): batch request: [email protected]: Permission denied (publickey).: exit status 255
181
+
182
+ Errors logged to /mount/src/ai-lip-sync/.git/lfs/logs/20240121T212252.496674
183
+ ```
184
+ This essentially Streamlit telling you that it can't handle that big file, upload it to Google Drive, and then load it using Python code later, and no `git lfs` won't solve the problem :)<br>
185
+ A ground rule that I learned here is: that the lighter you make your app, the better and faster it is to deploy it.<br>
186
+ I opened a topic with that issue on the Streamlit forum, right [here](https://discuss.streamlit.io/t/file-upload-fails-with-error-downloading-object-wav2lip-checkpoints-wav2lip-gan-pth-ca9ab7b/60261)<br>
187
+
188
+ 2- Other issues that I faced a lot were dependency issues -lots of them- and that was mostly due to the fact that I depended on `pipreqs` to write down my `requirements.txt`, that `pipreqs` missed up my modules, it added unneeded ones and missed others, unfortunately, it took me some time to discover that and really slowed me down.
189
+
190
+ 3-
191
+ ```
192
+ ImportError: libGL.so.1: cannot open shared object file: No such file or directory
193
+ ```
194
+ I faced that problem during importing `cv2` -`openCv`- and the solution was to install `libgl1-mesa-dev` and some other packages using `apt`, you can't just add such packages to the `requirements.txt`, you need to create a file named `packages.txt` to do so.
195
+
196
+ 4- Streamlit can't handle heavy processing, I discovered that when I tried to deploy the `slow animation` button to process video input alongside recording to get more accurate lip-syncing, the application failed directly when I used that button -and I tried to use it twice :)-, and that kinda make sense as Streamlit doesn't have a GPU or even a high ram space -I don't have a good GPU but I have about 64GB ram which was enough to run that function locally- and to solve that issue, I initiated another branch to contain the deployment version that doesn't have the `slow animation` button and used that branch for deployment while kept the main branch containing that button.
197
+
198
+ **Pushing the checkpoints files:**<br>
199
+
200
+ Given the size of those kind of files, There are 2 ways to handle that.
201
+
202
+ At the start, I had to use git lfs, here's how to do it:<br>
203
+
204
+ 1- Follow the installation instructions that are suitable for your system from [here](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) <br>
205
+ 2- Use the command `git lfs track "*.pth"` to let git lfs know that those are your big files.<br>
206
+ 3- When pushing from the command line -I usually use VS code but it usually doesn't work with big files like `.pth` files- you need to generate a personal access token, to do so, follow the instructions from [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token), and then copy the token<br>
207
+ 4- When pushing the file from the terminal you will be asked to pass a password, don't pass your GitHub profile password, instead pass your personal access token that you got from step 3.
208
+
209
+ But then Streamlit wasn't capable of even pulling the repo! so I uploaded the model checkpoints and some other files to Google Drive, put them in a public folder, and then used a module called gdown to download those folders when needed! here's a [link](https://github.com/wkentaro/gdown) to that gdown, it's straightforward to use and install.
210
+
211
+
212
+ **Video preview of the application:**<br>
213
+
214
+ **fast animation version**<br>
215
+ Notice how only the lips are moving.
216
+
217
+ English version:
218
+
219
+ https://github.com/Aml-Hassan-Abd-El-hamid/AI-Lip-Sync/assets/66205928/36577ccb-5ec6-4bb4-b7ff-44bb52a4f984
220
+
221
+ Arabic version:
222
+
223
+
224
+
225
+ https://github.com/Aml-Hassan-Abd-El-hamid/ai-lip-sync-app/assets/66205928/4346aa6d-ea4e-400e-9124-1cce06b049df
226
+
227
+
228
+
229
+ **slower animation version**<br>
230
+ Notice how the eye and the whole face are moving instead of only the lips.<br>
231
+
232
+ Unfortunately, Streamlit can't handle the computational power that the slower animation version requires and that's why I made it only available on the offline version, which means that you need to run the application locally to try that version.
233
+
234
+ English version:
235
+
236
+ https://github.com/Aml-Hassan-Abd-El-hamid/AI-Lip-Sync/assets/66205928/26740856-52e5-4fe7-868d-3b9341e97064
237
+
238
+ Arabic version:
239
+
240
+
241
+
242
+ https://github.com/Aml-Hassan-Abd-El-hamid/ai-lip-sync-app/assets/66205928/ba97daca-b30d-4179-9387-a382abbca3ba
243
+
244
+
245
+
246
+ The only difference between the fast and slow versions of animation here is the fact that the fast version passes only a photo while the slow version passes a video instead.
app.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from streamlit_image_select import image_select
4
+ import torch
5
+ from streamlit_mic_recorder import mic_recorder
6
+ from wav2lip import inference
7
+ from wav2lip.models import Wav2Lip
8
+ import gdown
9
+ import warnings
10
+ import cv2
11
+ import numpy as np
12
+ import librosa
13
+ from pathlib import Path
14
+ import subprocess
15
+ import time
16
+ from PIL import Image
17
+ import matplotlib.pyplot as plt
18
+ import sys
19
+ import threading
20
+ import concurrent.futures
21
+
22
+ # Suppress warnings
23
+ warnings.filterwarnings('ignore')
24
+
25
+ # More comprehensive fix for Streamlit file watcher issues with PyTorch
26
+ os.environ['STREAMLIT_WATCH_IGNORE'] = 'torch'
27
+ if 'torch' in sys.modules:
28
+ sys.modules['torch'].__path__ = type('', (), {'_path': []})()
29
+
30
+ # Check if MPS (Apple Silicon GPU) is available, otherwise use CPU
31
+ if torch.backends.mps.is_available():
32
+ device = 'mps'
33
+ # Enable memory optimization for Apple Silicon
34
+ torch.mps.empty_cache()
35
+ # Set the memory format to optimize for M2 Max
36
+ torch._C._set_cudnn_benchmark(True)
37
+ st.success("Using Apple M2 Max GPU for acceleration with optimized settings!")
38
+ else:
39
+ device = 'cpu'
40
+ st.warning("Using CPU for inference (slower). GPU acceleration not available.")
41
+
42
+ print(f"Using {device} for inference.")
43
+
44
+ # Add functions to analyze video and audio quality
45
+ def analyze_video_quality(file_path):
46
+ """Analyze video quality and detect faces for better user guidance"""
47
+ try:
48
+ # Open the video file
49
+ video = cv2.VideoCapture(file_path)
50
+
51
+ # Get video properties
52
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
53
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
54
+ fps = video.get(cv2.CAP_PROP_FPS)
55
+ frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
56
+ duration = frame_count / fps if fps > 0 else 0
57
+
58
+ # Read a frame for face detection
59
+ success, frame = video.read()
60
+ if not success:
61
+ return {
62
+ "resolution": f"{width}x{height}",
63
+ "fps": fps,
64
+ "duration": f"{duration:.1f} seconds",
65
+ "quality": "Unknown",
66
+ "face_detected": False,
67
+ "message": "Could not analyze video content."
68
+ }
69
+
70
+ # Detect faces using OpenCV's face detector
71
+ face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
72
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
73
+ faces = face_cascade.detectMultiScale(gray, 1.1, 4)
74
+
75
+ # Determine quality score based on resolution and face detection
76
+ quality_score = 0
77
+
78
+ # Resolution assessment
79
+ if width >= 1920 or height >= 1080: # 1080p or higher
80
+ resolution_quality = "Excellent"
81
+ quality_score += 3
82
+ elif width >= 1280 or height >= 720: # 720p
83
+ resolution_quality = "Good"
84
+ quality_score += 2
85
+ elif width >= 640 or height >= 480: # 480p
86
+ resolution_quality = "Fair"
87
+ quality_score += 1
88
+ else:
89
+ resolution_quality = "Low"
90
+
91
+ # Overall quality assessment
92
+ face_detected = len(faces) > 0
93
+
94
+ if face_detected:
95
+ quality_score += 2
96
+ face_message = "Face detected! ✅"
97
+
98
+ # Check face size relative to frame
99
+ for (x, y, w, h) in faces:
100
+ face_area_ratio = (w * h) / (width * height)
101
+ if face_area_ratio > 0.1: # Face takes up at least 10% of frame
102
+ quality_score += 1
103
+ face_size = "Good face size"
104
+ else:
105
+ face_size = "Face may be too small"
106
+ else:
107
+ face_message = "No face detected! ⚠️ Lip sync results may be poor."
108
+ face_size = "N/A"
109
+
110
+ # Determine overall quality
111
+ if quality_score >= 5:
112
+ quality = "Excellent"
113
+ elif quality_score >= 3:
114
+ quality = "Good"
115
+ elif quality_score >= 1:
116
+ quality = "Fair"
117
+ else:
118
+ quality = "Poor"
119
+
120
+ # Release video resource
121
+ video.release()
122
+
123
+ return {
124
+ "resolution": f"{width}x{height}",
125
+ "fps": f"{fps:.1f}",
126
+ "duration": f"{duration:.1f} seconds",
127
+ "quality": quality,
128
+ "resolution_quality": resolution_quality,
129
+ "face_detected": face_detected,
130
+ "face_message": face_message,
131
+ "face_size": face_size,
132
+ "message": get_video_recommendation(quality, face_detected, width, height)
133
+ }
134
+
135
+ except Exception as e:
136
+ return {
137
+ "quality": "Error",
138
+ "message": f"Could not analyze video: {str(e)}"
139
+ }
140
+
141
+ def analyze_audio_quality(file_path):
142
+ """Analyze audio quality for better user guidance"""
143
+ try:
144
+ # Load audio file using librosa
145
+ y, sr = librosa.load(file_path, sr=None)
146
+
147
+ # Get duration
148
+ duration = librosa.get_duration(y=y, sr=sr)
149
+
150
+ # Calculate audio features
151
+ rms = librosa.feature.rms(y=y)[0]
152
+ mean_volume = np.mean(rms)
153
+
154
+ # Simple speech detection (using energy levels)
155
+ has_speech = np.max(rms) > 0.05
156
+
157
+ # Check for silence periods
158
+ silence_threshold = 0.01
159
+ silence_percentage = np.mean(rms < silence_threshold) * 100
160
+
161
+ # Calculate quality score
162
+ quality_score = 0
163
+
164
+ # Volume assessment
165
+ if 0.05 <= mean_volume <= 0.2:
166
+ volume_quality = "Good volume levels"
167
+ quality_score += 2
168
+ elif mean_volume > 0.2:
169
+ volume_quality = "Audio might be too loud"
170
+ quality_score += 1
171
+ else:
172
+ volume_quality = "Audio might be too quiet"
173
+
174
+ # Speech detection
175
+ if has_speech:
176
+ speech_quality = "Speech detected ✅"
177
+ quality_score += 2
178
+ else:
179
+ speech_quality = "Speech may not be clear ⚠️"
180
+
181
+ # Silence assessment (some silence is normal)
182
+ if silence_percentage < 40:
183
+ silence_quality = "Good speech-to-silence ratio"
184
+ quality_score += 1
185
+ else:
186
+ silence_quality = "Too much silence detected"
187
+
188
+ # Determine overall quality
189
+ if quality_score >= 4:
190
+ quality = "Excellent"
191
+ elif quality_score >= 2:
192
+ quality = "Good"
193
+ elif quality_score >= 1:
194
+ quality = "Fair"
195
+ else:
196
+ quality = "Poor"
197
+
198
+ return {
199
+ "duration": f"{duration:.1f} seconds",
200
+ "quality": quality,
201
+ "volume_quality": volume_quality,
202
+ "speech_quality": speech_quality,
203
+ "silence_quality": silence_quality,
204
+ "message": get_audio_recommendation(quality, has_speech, mean_volume, silence_percentage)
205
+ }
206
+
207
+ except Exception as e:
208
+ return {
209
+ "quality": "Error",
210
+ "message": f"Could not analyze audio: {str(e)}"
211
+ }
212
+
213
+ def get_video_recommendation(quality, face_detected, width, height):
214
+ """Get recommendations based on video quality"""
215
+ if not face_detected:
216
+ return "⚠️ No face detected. For best results, use a video with a clear, well-lit face looking toward the camera."
217
+
218
+ if quality == "Poor":
219
+ return "⚠️ Low quality video. Consider using a higher resolution video with better lighting and a clearly visible face."
220
+
221
+ if width < 640 or height < 480:
222
+ return "⚠️ Video resolution is low. For better results, use a video with at least 480p resolution."
223
+
224
+ if quality == "Excellent":
225
+ return "✅ Great video quality! This should work well for lip syncing."
226
+
227
+ return "✅ Video quality is acceptable for lip syncing."
228
+
229
+ def get_audio_recommendation(quality, has_speech, volume, silence_percentage):
230
+ """Get recommendations based on audio quality"""
231
+ if not has_speech:
232
+ return "⚠️ Speech may not be clearly detected. For best results, use audio with clear speech."
233
+
234
+ if quality == "Poor":
235
+ return "⚠️ Low quality audio. Consider using clearer audio with consistent volume levels."
236
+
237
+ if volume < 0.01:
238
+ return "⚠️ Audio volume is very low. This may result in poor lip sync."
239
+
240
+ if volume > 0.3:
241
+ return "⚠️ Audio volume is very high. This may cause distortion in lip sync."
242
+
243
+ if silence_percentage > 50:
244
+ return "⚠️ Audio contains a lot of silence. Lip sync will only work during speech sections."
245
+
246
+ if quality == "Excellent":
247
+ return "✅ Great audio quality! This should work well for lip syncing."
248
+
249
+ return "✅ Audio quality is acceptable for lip syncing."
250
+
251
+ #@st.cache_data is used to only load the model once
252
+ #@st.cache_data
253
+ @st.cache_resource
254
+ def load_model(path):
255
+ st.write("Please wait for the model to be loaded or it will cause an error")
256
+ wav2lip_checkpoints_url = "https://drive.google.com/drive/folders/1Sy5SHRmI3zgg2RJaOttNsN3iJS9VVkbg?usp=sharing"
257
+ if not os.path.exists(path):
258
+ gdown.download_folder(wav2lip_checkpoints_url, quiet=True, use_cookies=False)
259
+ st.write("Please wait")
260
+ model = Wav2Lip()
261
+ print("Load checkpoint from: {}".format(path))
262
+
263
+ # Optimize model loading for M2 Max
264
+ if device == 'mps':
265
+ # Clear cache before loading model
266
+ torch.mps.empty_cache()
267
+
268
+ # Load model with device mapping
269
+ checkpoint = torch.load(path, map_location=torch.device(device))
270
+ s = checkpoint["state_dict"]
271
+ new_s = {}
272
+ for k, v in s.items():
273
+ new_s[k.replace('module.', '')] = v
274
+ model.load_state_dict(new_s)
275
+ model = model.to(device)
276
+
277
+ # Set model to evaluation mode and optimize for inference
278
+ model.eval()
279
+ if device == 'mps':
280
+ # Attempt to optimize the model for inference
281
+ try:
282
+ # Use torch's inference mode for optimized inference
283
+ torch._C._jit_set_profiling_executor(False)
284
+ torch._C._jit_set_profiling_mode(False)
285
+ print("Applied M2 Max optimizations")
286
+ except:
287
+ print("Could not apply all M2 Max optimizations")
288
+
289
+ st.write(f"Model loaded successfully on {device} with optimized settings for M2 Max!")
290
+ return model
291
+ @st.cache_resource
292
+ def load_avatar_videos_for_slow_animation(path):
293
+ if not os.path.exists(path):
294
+ try:
295
+ os.makedirs(path, exist_ok=True)
296
+ print(f"Created directory: {path}")
297
+
298
+ avatar_videos_url = "https://drive.google.com/drive/folders/1h9pkU5wenrS2vmKqXBfFmrg-1hYw5s4q?usp=sharing"
299
+ print(f"Downloading avatar videos from: {avatar_videos_url}")
300
+ gdown.download_folder(avatar_videos_url, quiet=False, use_cookies=False)
301
+ print(f"Avatar videos downloaded successfully to: {path}")
302
+ except Exception as e:
303
+ print(f"Error downloading avatar videos: {str(e)}")
304
+ # Create default empty videos if download fails
305
+ for avatar_file in ["avatar1.mp4", "avatar2.mp4", "avatar3.mp4"]:
306
+ video_path = os.path.join(path, avatar_file)
307
+ if not os.path.exists(video_path):
308
+ print(f"Creating empty video file: {video_path}")
309
+ # Get the matching image
310
+ img_key = f"avatars_images/{os.path.splitext(avatar_file)[0]}" + (".jpg" if avatar_file != "avatar3.mp4" else ".png")
311
+ try:
312
+ # Create a video from the image
313
+ img = cv2.imread(img_key)
314
+ if img is not None:
315
+ # Create a short 5-second video from the image
316
+ print(f"Creating video from image: {img_key}")
317
+ height, width = img.shape[:2]
318
+ output_video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))
319
+ for _ in range(150): # 5 seconds at 30 fps
320
+ output_video.write(img)
321
+ output_video.release()
322
+ else:
323
+ print(f"Could not read image: {img_key}")
324
+ except Exception as e:
325
+ print(f"Error creating video from image: {str(e)}")
326
+ else:
327
+ print(f"Avatar videos directory already exists: {path}")
328
+ # Check if files exist in the directory
329
+ files = os.listdir(path)
330
+ if not files:
331
+ print(f"No files found in {path}, directory exists but is empty")
332
+ else:
333
+ print(f"Found {len(files)} files in {path}: {', '.join(files)}")
334
+
335
+
336
+
337
+ image_video_map = {
338
+ "avatars_images/avatar1.jpg":"avatars_videos/avatar1.mp4",
339
+ "avatars_images/avatar2.jpg":"avatars_videos/avatar2.mp4",
340
+ "avatars_images/avatar3.png":"avatars_videos/avatar3.mp4"
341
+ }
342
+ def streamlit_look():
343
+ """
344
+ Modest front-end code:)
345
+ """
346
+ data={}
347
+ st.title("Welcome to AI Lip Sync :)")
348
+
349
+ # Add a brief app description
350
+ st.markdown("""
351
+ This app uses AI to synchronize a person's lip movements with any audio file.
352
+ You can choose from built-in avatars or upload your own image/video, then provide audio
353
+ to create realistic lip-synced videos. Powered by Wav2Lip and optimized for Apple Silicon.
354
+ """)
355
+
356
+ # Add a guidelines section with an expander for best practices
357
+ with st.expander("📋 Guidelines & Best Practices (Click to expand)", expanded=False):
358
+ st.markdown("""
359
+ ### Guidelines for Best Results
360
+
361
+ #### Audio and Video Length
362
+ - Audio and video don't need to be exactly the same length
363
+ - If audio is shorter than video: Only the matching portion will be lip-synced
364
+ - If audio is longer than video: Audio will be trimmed to match video length
365
+
366
+ #### Face Quality
367
+ - Clear, well-lit frontal views of faces work best
368
+ - Faces should take up a reasonable portion of the frame
369
+ - Avoid extreme angles, heavy shadows, or partial face views
370
+
371
+ #### Audio Quality
372
+ - Clear speech with minimal background noise works best
373
+ - Consistent audio volume improves synchronization
374
+ - Supported formats: WAV, MP3
375
+
376
+ #### Video Quality
377
+ - Stable videos with minimal camera movement
378
+ - The person's mouth should be clearly visible
379
+ - Videos at 480p or higher resolution work best
380
+ - Very high-resolution videos will be automatically downscaled
381
+
382
+ #### Processing Tips
383
+ - Shorter videos process faster and often give better results
384
+ - "Fast animation" only moves the lips (quicker processing)
385
+ - "Slow animation" animates the full face (better quality, slower)
386
+ - Your M2 Max GPU will significantly speed up processing
387
+ """)
388
+
389
+ # Option to choose between built-in avatars or upload a custom one
390
+ avatar_source = st.radio("Choose avatar source:", ["Upload my own image/video", "Use built-in avatars"])
391
+
392
+ if avatar_source == "Use built-in avatars":
393
+ st.write("Please choose your avatar from the following options:")
394
+ avatar_img = image_select("",
395
+ ["avatars_images/avatar1.jpg",
396
+ "avatars_images/avatar2.jpg",
397
+ "avatars_images/avatar3.png",
398
+ ])
399
+ data["imge_path"] = avatar_img
400
+ else:
401
+ st.write("Upload an image or video file for your avatar:")
402
+ uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"], key="avatar_uploader")
403
+
404
+ if uploaded_file is not None:
405
+ # Save the uploaded file
406
+ file_path = os.path.join("uploads", uploaded_file.name)
407
+ os.makedirs("uploads", exist_ok=True)
408
+
409
+ with open(file_path, "wb") as f:
410
+ f.write(uploaded_file.getvalue())
411
+
412
+ # Set the file path as image path
413
+ data["imge_path"] = file_path
414
+ st.success(f"File uploaded successfully: {uploaded_file.name}")
415
+
416
+ # Preview the uploaded image/video
417
+ if uploaded_file.name.endswith(('.jpg', '.jpeg', '.png')):
418
+ st.image(file_path, caption="Uploaded Image")
419
+ elif uploaded_file.name.endswith('.mp4'):
420
+ st.video(file_path)
421
+
422
+ # Analyze video quality for MP4 files
423
+ with st.spinner("Analyzing video quality..."):
424
+ video_analysis = analyze_video_quality(file_path)
425
+
426
+ # Display video quality analysis in a nice box
427
+ with st.expander("📊 Video Quality Analysis", expanded=True):
428
+ col1, col2 = st.columns(2)
429
+
430
+ with col1:
431
+ st.markdown(f"**Resolution:** {video_analysis['resolution']}")
432
+ st.markdown(f"**FPS:** {video_analysis['fps']}")
433
+ st.markdown(f"**Duration:** {video_analysis['duration']}")
434
+
435
+ with col2:
436
+ quality_color = {
437
+ "Excellent": "green",
438
+ "Good": "lightgreen",
439
+ "Fair": "orange",
440
+ "Poor": "red",
441
+ "Error": "red"
442
+ }.get(video_analysis['quality'], "gray")
443
+
444
+ st.markdown(f"**Quality:** <span style='color:{quality_color};font-weight:bold'>{video_analysis['quality']}</span>", unsafe_allow_html=True)
445
+ st.markdown(f"**Face Detection:** {'✅ Detected' if video_analysis.get('face_detected', False) else '❌ Not detected'}")
446
+
447
+ # Display the recommendation
448
+ st.info(video_analysis['message'])
449
+
450
+ # Option to choose between mic recording or upload audio file
451
+ audio_source = st.radio("Choose audio source:", ["Upload audio file", "Record with microphone"])
452
+
453
+ if audio_source == "Record with microphone":
454
+ audio = mic_recorder(
455
+ start_prompt="Start recording",
456
+ stop_prompt="Stop recording",
457
+ just_once=False,
458
+ use_container_width=False,
459
+ callback=None,
460
+ args=(),
461
+ kwargs={},
462
+ key=None)
463
+
464
+ if audio:
465
+ st.audio(audio["bytes"])
466
+ data["audio"] = audio["bytes"]
467
+ else:
468
+ st.write("Upload an audio file:")
469
+ uploaded_audio = st.file_uploader("Choose an audio file", type=["wav", "mp3"], key="audio_uploader")
470
+
471
+ if uploaded_audio is not None:
472
+ # Save the uploaded audio file
473
+ audio_path = os.path.join("uploads", uploaded_audio.name)
474
+ os.makedirs("uploads", exist_ok=True)
475
+
476
+ with open(audio_path, "wb") as f:
477
+ f.write(uploaded_audio.getvalue())
478
+
479
+ # Preview the uploaded audio
480
+ st.audio(audio_path)
481
+
482
+ # Read the file into bytes for consistency with microphone recording
483
+ with open(audio_path, "rb") as f:
484
+ audio_bytes = f.read()
485
+
486
+ data["audio"] = audio_bytes
487
+ st.success(f"Audio file uploaded successfully: {uploaded_audio.name}")
488
+
489
+ # Analyze audio quality
490
+ with st.spinner("Analyzing audio quality..."):
491
+ audio_analysis = analyze_audio_quality(audio_path)
492
+
493
+ # Display audio quality analysis in a nice box
494
+ with st.expander("🎵 Audio Quality Analysis", expanded=True):
495
+ col1, col2 = st.columns(2)
496
+
497
+ with col1:
498
+ st.markdown(f"**Duration:** {audio_analysis['duration']}")
499
+ st.markdown(f"**Volume:** {audio_analysis['volume_quality']}")
500
+
501
+ with col2:
502
+ quality_color = {
503
+ "Excellent": "green",
504
+ "Good": "lightgreen",
505
+ "Fair": "orange",
506
+ "Poor": "red",
507
+ "Error": "red"
508
+ }.get(audio_analysis['quality'], "gray")
509
+
510
+ st.markdown(f"**Quality:** <span style='color:{quality_color};font-weight:bold'>{audio_analysis['quality']}</span>", unsafe_allow_html=True)
511
+ st.markdown(f"**Speech:** {audio_analysis['speech_quality']}")
512
+
513
+ # Display the recommendation
514
+ st.info(audio_analysis['message'])
515
+
516
+ return data
517
+
518
+ def main():
519
+ # Initialize session state to track processing status
520
+ if 'processed' not in st.session_state:
521
+ st.session_state.processed = False
522
+
523
+ data = streamlit_look()
524
+
525
+ # Add debug information
526
+ st.write("Debug info:")
527
+ if "imge_path" in data:
528
+ st.write(f"Image/Video path: {data['imge_path']}")
529
+ else:
530
+ st.write("No image/video selected yet")
531
+
532
+ if "audio" in data:
533
+ st.write("Audio file selected ✓")
534
+ else:
535
+ st.write("No audio selected yet")
536
+
537
+ # Only proceed if we have both image/video and audio data
538
+ if "imge_path" in data and "audio" in data:
539
+ st.write("This app will automatically save your audio when you click animate.")
540
+ save_record = st.button("save record manually")
541
+ st.write("With fast animation only the lips of the avatar will move, and it will take probably less than a minute for a record of about 30 seconds, but with slow animation choice, the full face of the avatar will move and it will take about 30 minutes for a record of about 30 seconds to get ready.")
542
+ model = load_model("wav2lip_checkpoints/wav2lip_gan.pth")
543
+
544
+ # Check for duration mismatches between video and audio
545
+ if data["imge_path"].endswith('.mp4'):
546
+ # Save audio to temp file for analysis
547
+ if not os.path.exists('record.wav'):
548
+ with open('record.wav', mode='wb') as f:
549
+ f.write(data["audio"])
550
+
551
+ # Get durations
552
+ video_duration = get_video_duration(data["imge_path"])
553
+ audio_duration = get_audio_duration('record.wav')
554
+
555
+ # Check for significant duration mismatch (more than 2 seconds difference)
556
+ if abs(video_duration - audio_duration) > 2:
557
+ st.warning(f"⚠️ Duration mismatch detected: Video is {video_duration:.1f}s and Audio is {audio_duration:.1f}s")
558
+
559
+ # Create a tab for handling duration mismatches
560
+ with st.expander("Duration Mismatch Options (Click to expand)", expanded=True):
561
+ st.info("The video and audio have different durations. Choose an option below:")
562
+
563
+ if video_duration > audio_duration:
564
+ if st.button("Trim Video to Match Audio Duration"):
565
+ # Update duration values to match
566
+ output_path = 'uploads/trimmed_input_video.mp4'
567
+ with st.spinner(f"Trimming video from {video_duration:.1f}s to {audio_duration:.1f}s..."):
568
+ success = trim_video(data["imge_path"], output_path, 0, audio_duration)
569
+
570
+ if success:
571
+ st.success("Video trimmed to match audio duration!")
572
+ # Update the image path to use the trimmed video
573
+ data["imge_path"] = output_path
574
+ st.video(output_path)
575
+ else: # audio_duration > video_duration
576
+ if st.button("Trim Audio to Match Video Duration"):
577
+ # Update duration values to match
578
+ output_path = 'uploads/trimmed_input_audio.wav'
579
+ with st.spinner(f"Trimming audio from {audio_duration:.1f}s to {video_duration:.1f}s..."):
580
+ success = trim_audio('record.wav', output_path, 0, video_duration)
581
+
582
+ if success:
583
+ st.success("Audio trimmed to match video duration!")
584
+ # Update the audio data with the trimmed audio
585
+ with open(output_path, "rb") as f:
586
+ data["audio"] = f.read()
587
+ # Save the trimmed audio as record.wav
588
+ with open('record.wav', mode='wb') as f:
589
+ f.write(data["audio"])
590
+ st.audio(output_path)
591
+
592
+ # Animation buttons
593
+ fast_animate = st.button("fast animate")
594
+ slower_animate = st.button("slower animate")
595
+
596
+ # Function to save the audio record
597
+ def save_audio_record():
598
+ if os.path.exists('record.wav'):
599
+ os.remove('record.wav')
600
+ with open('record.wav', mode='wb') as f:
601
+ f.write(data["audio"])
602
+ st.write("Audio record saved!")
603
+
604
+ if save_record:
605
+ save_audio_record()
606
+
607
+ # Show previously generated results if they exist and we're not generating new ones
608
+ if os.path.exists('wav2lip/results/result_voice.mp4') and st.session_state.processed and not (fast_animate or slower_animate):
609
+ st.video('wav2lip/results/result_voice.mp4')
610
+ display_trim_options('wav2lip/results/result_voice.mp4')
611
+
612
+ if fast_animate:
613
+ # Automatically save the record before animation
614
+ save_audio_record()
615
+
616
+ progress_placeholder = st.empty()
617
+ status_placeholder = st.empty()
618
+
619
+ progress_bar = progress_placeholder.progress(0, text="Processing: 0% complete")
620
+ status_placeholder.info("Preparing to process...")
621
+
622
+ # Call the inference function inside a try block with progress updates at key points
623
+ try:
624
+ # Initialize a progress tracker
625
+ progress_steps = [
626
+ (0, "Starting processing..."),
627
+ (15, "Step 1/4: Loading and analyzing video frames"),
628
+ (30, "Step 2/4: Performing face detection (this may take a while for long videos)"),
629
+ (60, "Step 3/4: Generating lip-synced frames"),
630
+ (80, "Step 4/4: Creating final video with audio"),
631
+ (100, "Processing complete!")
632
+ ]
633
+ current_step = 0
634
+
635
+ # Redirect stdout to capture progress information
636
+ import io
637
+ sys.stdout = io.StringIO()
638
+
639
+ # Update progress for the initial step
640
+ progress, message = progress_steps[current_step]
641
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
642
+ status_placeholder.info(message)
643
+ current_step += 1
644
+
645
+ # Run the inference in a background thread
646
+ with concurrent.futures.ThreadPoolExecutor() as executor:
647
+ # Start the inference process
648
+ future = executor.submit(inference.main, data["imge_path"], "record.wav", model)
649
+
650
+ # Monitor the output for progress indicators
651
+ while not future.done():
652
+ captured_output = sys.stdout.getvalue()
653
+
654
+ # Check for progress indicators and update UI
655
+ if current_step < len(progress_steps):
656
+ # Check for stage 1 completion: frames read
657
+ if current_step == 1 and "Number of frames available for inference" in captured_output:
658
+ progress, message = progress_steps[current_step]
659
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
660
+ status_placeholder.info(message)
661
+ current_step += 1
662
+ # Check for stage 2 completion: face detection
663
+ elif current_step == 2 and "Face detection completed successfully" in captured_output:
664
+ progress, message = progress_steps[current_step]
665
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
666
+ status_placeholder.info(message)
667
+ current_step += 1
668
+ # Check for stage 3 completion: ffmpeg started
669
+ elif current_step == 3 and "ffmpeg" in captured_output:
670
+ progress, message = progress_steps[current_step]
671
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
672
+ status_placeholder.info(message)
673
+ current_step += 1
674
+
675
+ # Sleep to avoid excessive CPU usage
676
+ time.sleep(0.5)
677
+
678
+ try:
679
+ # Get the result or propagate exceptions
680
+ future.result()
681
+
682
+ # Show completion
683
+ progress, message = progress_steps[-1]
684
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
685
+ status_placeholder.success("Lip sync complete! Your video is ready.")
686
+ except Exception as e:
687
+ raise e
688
+
689
+ # Restore stdout
690
+ sys.stdout = sys.__stdout__
691
+
692
+ if os.path.exists('wav2lip/results/result_voice.mp4'):
693
+ st.video('wav2lip/results/result_voice.mp4')
694
+ display_trim_options('wav2lip/results/result_voice.mp4')
695
+ # Set processed flag to True after successful processing
696
+ st.session_state.processed = True
697
+
698
+ except Exception as e:
699
+ # Restore stdout in case of error
700
+ sys.stdout = sys.__stdout__
701
+
702
+ progress_placeholder.empty()
703
+ status_placeholder.error(f"Error during processing: {str(e)}")
704
+ st.error("Failed to generate video. Please try again or use a different image/audio.")
705
+
706
+ if slower_animate:
707
+ # Automatically save the record before animation
708
+ save_audio_record()
709
+
710
+ progress_placeholder = st.empty()
711
+ status_placeholder = st.empty()
712
+
713
+ progress_bar = progress_placeholder.progress(0, text="Processing: 0% complete")
714
+ status_placeholder.info("Preparing to process...")
715
+
716
+ # Derive the video path from the selected avatar
717
+ if data["imge_path"].endswith('.mp4'):
718
+ video_path = data["imge_path"]
719
+ else:
720
+ # Get the avatar video path for the selected avatar
721
+ avatar_list = load_avatar_videos_for_slow_animation("./data/avatars/samples")
722
+ video_path = avatar_list[available_avatars_for_slow.index(avatar_choice)]
723
+
724
+ try:
725
+ # Initialize a progress tracker
726
+ progress_steps = [
727
+ (0, "Starting processing..."),
728
+ (15, "Step 1/4: Loading and analyzing video frames"),
729
+ (30, "Step 2/4: Performing face detection (this may take a while for long videos)"),
730
+ (60, "Step 3/4: Generating lip-synced frames with full-face animation"),
731
+ (80, "Step 4/4: Creating final video with audio"),
732
+ (100, "Processing complete!")
733
+ ]
734
+ current_step = 0
735
+
736
+ # Redirect stdout to capture progress information
737
+ import io
738
+ sys.stdout = io.StringIO()
739
+
740
+ # Update progress for the initial step
741
+ progress, message = progress_steps[current_step]
742
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
743
+ status_placeholder.info(message)
744
+ current_step += 1
745
+
746
+ # Run the inference in a background thread
747
+ with concurrent.futures.ThreadPoolExecutor() as executor:
748
+ # Start the inference process
749
+ future = executor.submit(inference.main, video_path, "record.wav", model, slow_mode=True)
750
+
751
+ # Monitor the output for progress indicators
752
+ while not future.done():
753
+ captured_output = sys.stdout.getvalue()
754
+
755
+ # Check for progress indicators and update UI
756
+ if current_step < len(progress_steps):
757
+ # Check for stage 1 completion: frames read
758
+ if current_step == 1 and "Number of frames available for inference" in captured_output:
759
+ progress, message = progress_steps[current_step]
760
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
761
+ status_placeholder.info(message)
762
+ current_step += 1
763
+ # Check for stage 2 completion: face detection
764
+ elif current_step == 2 and "Face detection completed successfully" in captured_output:
765
+ progress, message = progress_steps[current_step]
766
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
767
+ status_placeholder.info(message)
768
+ current_step += 1
769
+ # Check for stage 3 completion: ffmpeg started
770
+ elif current_step == 3 and "ffmpeg" in captured_output:
771
+ progress, message = progress_steps[current_step]
772
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
773
+ status_placeholder.info(message)
774
+ current_step += 1
775
+
776
+ # Sleep to avoid excessive CPU usage
777
+ time.sleep(0.5)
778
+
779
+ try:
780
+ # Get the result or propagate exceptions
781
+ future.result()
782
+
783
+ # Show completion
784
+ progress, message = progress_steps[-1]
785
+ progress_bar.progress(progress, text=f"Processing: {progress}% complete")
786
+ status_placeholder.success("Lip sync complete! Your video is ready.")
787
+ except Exception as e:
788
+ raise e
789
+
790
+ # Restore stdout
791
+ sys.stdout = sys.__stdout__
792
+
793
+ if os.path.exists('wav2lip/results/result_voice.mp4'):
794
+ st.video('wav2lip/results/result_voice.mp4')
795
+ display_trim_options('wav2lip/results/result_voice.mp4')
796
+ # Set processed flag to True after successful processing
797
+ st.session_state.processed = True
798
+ except Exception as e:
799
+ # Restore stdout in case of error
800
+ sys.stdout = sys.__stdout__
801
+
802
+ progress_placeholder.empty()
803
+ status_placeholder.error(f"Error during processing: {str(e)}")
804
+ st.error("Failed to generate video. Please try again or use a different video/audio.")
805
+ else:
806
+ if "imge_path" not in data and "audio" not in data:
807
+ st.warning("Please upload both an image/video AND provide audio to continue.")
808
+ elif "imge_path" not in data:
809
+ st.warning("Please select or upload an image/video to continue.")
810
+ else:
811
+ st.warning("Please provide audio to continue.")
812
+
813
+ # Function to display trim options and handle video trimming
814
+ def display_trim_options(video_path):
815
+ """Display options to trim the video and handle the trimming process"""
816
+ st.subheader("Video Processing Options")
817
+
818
+ # Check if the video exists first
819
+ if not os.path.exists(video_path):
820
+ st.error(f"Video file not found at {video_path}. Try running the animation again.")
821
+ return
822
+
823
+ # Add tabs for different operations
824
+ download_tab, trim_tab = st.tabs(["Download Original", "Trim Video"])
825
+
826
+ with download_tab:
827
+ st.write("Download the original generated video:")
828
+ try:
829
+ st.video(video_path)
830
+ st.download_button(
831
+ label="Download Original Video",
832
+ data=open(video_path, 'rb').read(),
833
+ file_name="original_lip_sync_video.mp4",
834
+ mime="video/mp4"
835
+ )
836
+ except Exception as e:
837
+ st.error(f"Error loading video: {str(e)}")
838
+
839
+ with trim_tab:
840
+ st.write("You can trim the generated video to remove unwanted parts from the beginning or end.")
841
+
842
+ duration = get_video_duration(video_path)
843
+ if duration <= 0:
844
+ st.error("Could not determine video duration")
845
+ return
846
+
847
+ # Display video duration
848
+ st.write(f"Video duration: {duration:.2f} seconds")
849
+
850
+ # Create a slider for selecting start and end times
851
+ col1, col2 = st.columns(2)
852
+
853
+ with col1:
854
+ start_time = st.slider("Start time (seconds)",
855
+ min_value=0.0,
856
+ max_value=float(duration),
857
+ value=0.0,
858
+ step=0.1)
859
+ st.write(f"Start at: {start_time:.1f}s")
860
+
861
+ with col2:
862
+ end_time = st.slider("End time (seconds)",
863
+ min_value=0.0,
864
+ max_value=float(duration),
865
+ value=float(duration),
866
+ step=0.1)
867
+ st.write(f"End at: {end_time:.1f}s")
868
+
869
+ # Display trim duration
870
+ trim_duration = end_time - start_time
871
+ st.info(f"Trimmed video duration will be: {trim_duration:.1f} seconds")
872
+
873
+ # Validate the selected range
874
+ if start_time >= end_time:
875
+ st.error("Start time must be less than end time")
876
+ return
877
+
878
+ # Button to perform trimming
879
+ if st.button("Trim Video"):
880
+ # Generate output path
881
+ output_path = 'wav2lip/results/trimmed_video.mp4'
882
+
883
+ # Show progress
884
+ with st.spinner("Trimming video..."):
885
+ success = trim_video(video_path, output_path, start_time, end_time)
886
+
887
+ if success:
888
+ st.success("Video trimmed successfully!")
889
+ try:
890
+ st.video(output_path)
891
+
892
+ # Add download button for trimmed video
893
+ st.download_button(
894
+ label="Download Trimmed Video",
895
+ data=open(output_path, 'rb').read(),
896
+ file_name="trimmed_lip_sync_video.mp4",
897
+ mime="video/mp4"
898
+ )
899
+ except Exception as e:
900
+ st.error(f"Error displaying trimmed video: {str(e)}")
901
+ else:
902
+ st.error("Failed to trim video. Try again with different timing parameters.")
903
+
904
+ # Function to trim video using ffmpeg
905
+ def trim_video(input_path, output_path, start_time, end_time):
906
+ """
907
+ Trim a video using ffmpeg from start_time to end_time.
908
+
909
+ Args:
910
+ input_path: Path to the input video
911
+ output_path: Path to save the trimmed video
912
+ start_time: Start time in seconds
913
+ end_time: End time in seconds
914
+
915
+ Returns:
916
+ bool: True if successful, False otherwise
917
+ """
918
+ try:
919
+ # Check if input file exists
920
+ if not os.path.exists(input_path):
921
+ st.error(f"Input video not found: {input_path}")
922
+ return False
923
+
924
+ # Format the command - use -ss before -i for faster seeking
925
+ # Add quotes around file paths to handle spaces and special characters
926
+ command = f'ffmpeg -y -ss {start_time} -i "{input_path}" -to {end_time} -c:v copy -c:a copy "{output_path}"'
927
+
928
+ # Use subprocess.run for better error handling
929
+ result = subprocess.run(
930
+ command,
931
+ shell=True,
932
+ stdout=subprocess.PIPE,
933
+ stderr=subprocess.PIPE,
934
+ text=True
935
+ )
936
+
937
+ if result.returncode != 0:
938
+ st.error(f"FFMPEG error: {result.stderr}")
939
+ return False
940
+
941
+ # Verify the output file exists and has a size greater than 0
942
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
943
+ return True
944
+ else:
945
+ st.error("Output file was not created correctly")
946
+ return False
947
+
948
+ except Exception as e:
949
+ st.error(f"Error trimming video: {str(e)}")
950
+ return False
951
+
952
+ # Function to get video duration
953
+ def get_video_duration(video_path):
954
+ """Get the duration of a video file in seconds"""
955
+ try:
956
+ video = cv2.VideoCapture(video_path)
957
+ fps = video.get(cv2.CAP_PROP_FPS)
958
+ frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
959
+ video.release()
960
+
961
+ duration = frame_count / fps if fps > 0 else 0
962
+ return duration
963
+ except Exception as e:
964
+ st.error(f"Error getting video duration: {str(e)}")
965
+ return 0
966
+
967
+ # Function to get audio duration
968
+ def get_audio_duration(audio_path):
969
+ """Get the duration of an audio file in seconds"""
970
+ try:
971
+ y, sr = librosa.load(audio_path, sr=None)
972
+ duration = librosa.get_duration(y=y, sr=sr)
973
+ return duration
974
+ except Exception as e:
975
+ st.error(f"Error getting audio duration: {str(e)}")
976
+ return 0
977
+
978
+ # Function to trim audio file
979
+ def trim_audio(input_path, output_path, start_time, end_time):
980
+ """Trim an audio file to the specified start and end times"""
981
+ try:
982
+ # Command to trim audio using ffmpeg
983
+ command = f'ffmpeg -y -i "{input_path}" -ss {start_time} -to {end_time} -c copy "{output_path}"'
984
+
985
+ # Execute the command
986
+ subprocess.call(command, shell=True)
987
+
988
+ # Check if output file exists
989
+ if os.path.exists(output_path):
990
+ return True
991
+ else:
992
+ st.error("Output audio file was not created correctly")
993
+ return False
994
+
995
+ except Exception as e:
996
+ st.error(f"Error trimming audio: {str(e)}")
997
+ return False
998
+
999
+ if __name__ == "__main__":
1000
+ main()
avatars_images/avatar1.jpg ADDED
avatars_images/avatar2.jpg ADDED
avatars_images/avatar3.png ADDED
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ python3-opencv
2
+ libgl1-mesa-dev
3
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.26.3
2
+ scipy==1.12.0
3
+ iou==0.1.0
4
+ librosa==0.10.1
5
+ opencv_contrib_python==4.9.0.80
6
+ streamlit==1.31.0
7
+ streamlit_image_select==0.6.0
8
+ streamlit_mic_recorder==0.0.4
9
+ torch==2.2.1
10
+ tqdm==4.64.1
11
+ gdown
12
+ matplotlib==3.10.1
wav2lip/audio.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ # import tensorflow as tf
5
+ from scipy import signal
6
+ from scipy.io import wavfile
7
+ from .hparams import hparams as hp
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ librosa.output.write_wav(path, wav, sr=sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ def get_hop_size():
31
+ hop_size = hp.hop_size
32
+ if hop_size is None:
33
+ assert hp.frame_shift_ms is not None
34
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
35
+ return hop_size
36
+
37
+ def linearspectrogram(wav):
38
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
39
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
40
+
41
+ if hp.signal_normalization:
42
+ return _normalize(S)
43
+ return S
44
+
45
+ def melspectrogram(wav):
46
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
47
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
48
+
49
+ if hp.signal_normalization:
50
+ return _normalize(S)
51
+ return S
52
+
53
+ def _lws_processor():
54
+ import lws
55
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
56
+
57
+ def _stft(y):
58
+ if hp.use_lws:
59
+ return _lws_processor(hp).stft(y).T
60
+ else:
61
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
62
+
63
+ ##########################################################
64
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
65
+ def num_frames(length, fsize, fshift):
66
+ """Compute number of time frames of spectrogram
67
+ """
68
+ pad = (fsize - fshift)
69
+ if length % fshift == 0:
70
+ M = (length + pad * 2 - fsize) // fshift + 1
71
+ else:
72
+ M = (length + pad * 2 - fsize) // fshift + 2
73
+ return M
74
+
75
+
76
+ def pad_lr(x, fsize, fshift):
77
+ """Compute left and right padding
78
+ """
79
+ M = num_frames(len(x), fsize, fshift)
80
+ pad = (fsize - fshift)
81
+ T = len(x) + 2 * pad
82
+ r = (M - 1) * fshift + fsize - T
83
+ return pad, pad + r
84
+ ##########################################################
85
+ #Librosa correct padding
86
+ def librosa_pad_lr(x, fsize, fshift):
87
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
88
+
89
+ # Conversions
90
+ _mel_basis = None
91
+
92
+ def _linear_to_mel(spectogram):
93
+ global _mel_basis
94
+ if _mel_basis is None:
95
+ _mel_basis = _build_mel_basis()
96
+ return np.dot(_mel_basis, spectogram)
97
+
98
+
99
+ def _build_mel_basis():
100
+ assert hp.fmax <= hp.sample_rate // 2
101
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
102
+ fmin=hp.fmin, fmax=hp.fmax)
103
+ def _amp_to_db(x):
104
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
105
+ return 20 * np.log10(np.maximum(min_level, x))
106
+
107
+ def _db_to_amp(x):
108
+ return np.power(10.0, (x) * 0.05)
109
+
110
+ def _normalize(S):
111
+ if hp.allow_clipping_in_normalization:
112
+ if hp.symmetric_mels:
113
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
114
+ -hp.max_abs_value, hp.max_abs_value)
115
+ else:
116
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
117
+
118
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
119
+ if hp.symmetric_mels:
120
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
121
+ else:
122
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
123
+
124
+ def _denormalize(D):
125
+ if hp.allow_clipping_in_normalization:
126
+ if hp.symmetric_mels:
127
+ return (((np.clip(D, -hp.max_abs_value,
128
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
129
+ + hp.min_level_db)
130
+ else:
131
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
132
+
133
+ if hp.symmetric_mels:
134
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
135
+ else:
136
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
wav2lip/face_detection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
wav2lip/face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = '[email protected]'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize
wav2lip/face_detection/api.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ from .detection import sfd
9
+ try:
10
+ import urllib.request as request_file
11
+ except BaseException:
12
+ import urllib as request_file
13
+
14
+ from .models import FAN, ResNetDepth
15
+ from .utils import *
16
+
17
+
18
+ class LandmarksType(Enum):
19
+ """Enum class defining the type of landmarks to detect.
20
+
21
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
22
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
23
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
24
+
25
+ """
26
+ _2D = 1
27
+ _2halfD = 2
28
+ _3D = 3
29
+
30
+
31
+ class NetworkSize(Enum):
32
+ # TINY = 1
33
+ # SMALL = 2
34
+ # MEDIUM = 3
35
+ LARGE = 4
36
+
37
+ def __new__(cls, value):
38
+ member = object.__new__(cls)
39
+ member._value_ = value
40
+ return member
41
+
42
+ def __int__(self):
43
+ return self.value
44
+
45
+ ROOT = os.path.dirname(os.path.abspath(__file__))
46
+
47
+ class FaceAlignment:
48
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
49
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
50
+ self.device = device
51
+ self.flip_input = flip_input
52
+ self.landmarks_type = landmarks_type
53
+ self.verbose = verbose
54
+
55
+ network_size = int(network_size)
56
+
57
+ if 'cuda' in device or 'mps' in device:
58
+ torch.backends.cudnn.benchmark = True
59
+ if 'mps' in device and verbose:
60
+ print("Using Apple Silicon GPU (MPS) for face detection.")
61
+
62
+ # Get the face detector
63
+ #face_detector_module = __import__('from .detection. import' + face_detector,
64
+ # globals(), locals(), [face_detector], 0)
65
+ #self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
66
+ try:
67
+ self.face_detector = sfd.FaceDetector(device=device, verbose=verbose)
68
+ except Exception as e:
69
+ print(f"Error initializing face detector: {e}")
70
+ print("Falling back to CPU for face detection.")
71
+ # If detection fails on GPU (MPS/CUDA), fall back to CPU
72
+ self.device = 'cpu'
73
+ self.face_detector = sfd.FaceDetector(device='cpu', verbose=verbose)
74
+
75
+ def get_detections_for_batch(self, images):
76
+ """
77
+ Returns a list of bounding boxes for each image in the batch.
78
+ If no face is detected, returns None for that image.
79
+ """
80
+ try:
81
+ # Convert to RGB for face detection
82
+ images = images.copy()
83
+ if images.shape[-1] == 3:
84
+ images = images[..., ::-1] # BGR to RGB
85
+
86
+ # Get face detections
87
+ detected_faces = self.face_detector.detect_from_batch(images)
88
+
89
+ results = []
90
+ for i, d in enumerate(detected_faces):
91
+ if len(d) == 0:
92
+ # No face detected
93
+ results.append(None)
94
+ continue
95
+
96
+ # Use the first (highest confidence) face
97
+ d = d[0]
98
+ # Ensure values are valid
99
+ d = np.clip(d, 0, None)
100
+
101
+ # Extract coordinates
102
+ try:
103
+ x1, y1, x2, y2 = map(int, d[:-1])
104
+ # Sanity check on coordinates
105
+ if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
106
+ print(f"Invalid face coordinates: {(x1, y1, x2, y2)}")
107
+ results.append(None)
108
+ else:
109
+ results.append((x1, y1, x2, y2))
110
+ except Exception as e:
111
+ print(f"Error processing detection: {str(e)}")
112
+ results.append(None)
113
+
114
+ return results
115
+
116
+ except Exception as e:
117
+ print(f"Error in batch face detection: {str(e)}")
118
+ # Return None for all images
119
+ return [None] * len(images)
wav2lip/face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import FaceDetector
wav2lip/face_detection/detection/core.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+ elif 'mps' in device:
27
+ logger = logging.getLogger(__name__)
28
+ logger.info("Detection running on Apple Silicon GPU (MPS).")
29
+
30
+ if 'cpu' not in device and 'cuda' not in device and 'mps' not in device:
31
+ logger = logging.getLogger(__name__)
32
+ if verbose:
33
+ logger.error("Expected values for device are: {cpu, cuda, mps} but got: %s", device)
34
+ raise ValueError(f"Invalid device type: {device}. Expected 'cpu', 'cuda', or 'mps'.")
35
+
36
+ def detect_from_image(self, tensor_or_path):
37
+ """Detects faces in a given image.
38
+
39
+ This function detects the faces present in a provided BGR(usually)
40
+ image. The input can be either the image itself or the path to it.
41
+
42
+ Arguments:
43
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
44
+ to an image or the image itself.
45
+
46
+ Example::
47
+
48
+ >>> path_to_image = 'data/image_01.jpg'
49
+ ... detected_faces = detect_from_image(path_to_image)
50
+ [A list of bounding boxes (x1, y1, x2, y2)]
51
+ >>> image = cv2.imread(path_to_image)
52
+ ... detected_faces = detect_from_image(image)
53
+ [A list of bounding boxes (x1, y1, x2, y2)]
54
+
55
+ """
56
+ raise NotImplementedError
57
+
58
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
59
+ """Detects faces from all the images present in a given directory.
60
+
61
+ Arguments:
62
+ path {string} -- a string containing a path that points to the folder containing the images
63
+
64
+ Keyword Arguments:
65
+ extensions {list} -- list of string containing the extensions to be
66
+ consider in the following format: ``.extension_name`` (default:
67
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
68
+ folder recursively (default: {False}) show_progress_bar {bool} --
69
+ display a progressbar (default: {True})
70
+
71
+ Example:
72
+ >>> directory = 'data'
73
+ ... detected_faces = detect_from_directory(directory)
74
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
75
+
76
+ """
77
+ if self.verbose:
78
+ logger = logging.getLogger(__name__)
79
+
80
+ if len(extensions) == 0:
81
+ if self.verbose:
82
+ logger.error("Expected at list one extension, but none was received.")
83
+ raise ValueError
84
+
85
+ if self.verbose:
86
+ logger.info("Constructing the list of images.")
87
+ additional_pattern = '/**/*' if recursive else '/*'
88
+ files = []
89
+ for extension in extensions:
90
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
91
+
92
+ if self.verbose:
93
+ logger.info("Finished searching for images. %s images found", len(files))
94
+ logger.info("Preparing to run the detection.")
95
+
96
+ predictions = {}
97
+ for image_path in tqdm(files, disable=not show_progress_bar):
98
+ if self.verbose:
99
+ logger.info("Running the face detector on image: %s", image_path)
100
+ predictions[image_path] = self.detect_from_image(image_path)
101
+
102
+ if self.verbose:
103
+ logger.info("The detector was successfully run on all %s images", len(files))
104
+
105
+ return predictions
106
+
107
+ @property
108
+ def reference_scale(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_x_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @property
116
+ def reference_y_shift(self):
117
+ raise NotImplementedError
118
+
119
+ @staticmethod
120
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
121
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
122
+
123
+ Arguments:
124
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
125
+ """
126
+ if isinstance(tensor_or_path, str):
127
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
128
+ elif torch.is_tensor(tensor_or_path):
129
+ # Call cpu in case its coming from cuda
130
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
131
+ elif isinstance(tensor_or_path, np.ndarray):
132
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
133
+ else:
134
+ raise TypeError
wav2lip/face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sfd_detector import SFDDetector as FaceDetector
wav2lip/face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
wav2lip/face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device or 'mps' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device or 'mps' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
wav2lip/face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
wav2lip/face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ try:
22
+ if not os.path.isfile(path_to_detector):
23
+ model_weights = load_url(models_urls['s3fd'])
24
+ else:
25
+ # For MPS (Apple Silicon), we need to load to CPU first
26
+ if 'mps' in device:
27
+ model_weights = torch.load(path_to_detector, map_location='cpu')
28
+ else:
29
+ model_weights = torch.load(path_to_detector, map_location=device)
30
+
31
+ self.face_detector = s3fd()
32
+ self.face_detector.load_state_dict(model_weights)
33
+ self.face_detector.to(device)
34
+ self.face_detector.eval()
35
+
36
+ if verbose:
37
+ print(f"Face detector loaded successfully and moved to {device}")
38
+
39
+ except Exception as e:
40
+ if verbose:
41
+ print(f"Error loading face detector model: {str(e)}")
42
+ raise
43
+
44
+ def detect_from_image(self, tensor_or_path):
45
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
46
+
47
+ bboxlist = detect(self.face_detector, image, device=self.device)
48
+ keep = nms(bboxlist, 0.3)
49
+ bboxlist = bboxlist[keep, :]
50
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
51
+
52
+ return bboxlist
53
+
54
+ def detect_from_batch(self, images):
55
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
56
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
57
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
58
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
59
+
60
+ return bboxlists
61
+
62
+ @property
63
+ def reference_scale(self):
64
+ return 195
65
+
66
+ @property
67
+ def reference_x_shift(self):
68
+ return 0
69
+
70
+ @property
71
+ def reference_y_shift(self):
72
+ return 0
wav2lip/face_detection/models.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
wav2lip/face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a targer resolution, the
60
+ function generates and affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define wherever the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales is provided the function will return the points also in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
wav2lip/hparams.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from glob import glob
2
+ import os
3
+
4
+ def get_image_list(data_root, split):
5
+ filelist = []
6
+
7
+ with open('filelists/{}.txt'.format(split)) as f:
8
+ for line in f:
9
+ line = line.strip()
10
+ if ' ' in line: line = line.split()[0]
11
+ filelist.append(os.path.join(data_root, line))
12
+
13
+ return filelist
14
+
15
+ class HParams:
16
+ def __init__(self, **kwargs):
17
+ self.data = {}
18
+
19
+ for key, value in kwargs.items():
20
+ self.data[key] = value
21
+
22
+ def __getattr__(self, key):
23
+ if key not in self.data:
24
+ raise AttributeError("'HParams' object has no attribute %s" % key)
25
+ return self.data[key]
26
+
27
+ def set_hparam(self, key, value):
28
+ self.data[key] = value
29
+
30
+
31
+ # Default hyperparameters
32
+ hparams = HParams(
33
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
34
+ # network
35
+ rescale=True, # Whether to rescale audio prior to preprocessing
36
+ rescaling_max=0.9, # Rescaling value
37
+
38
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
39
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
40
+ # Does not work if n_ffit is not multiple of hop_size!!
41
+ use_lws=False,
42
+
43
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
44
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
45
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
46
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
47
+
48
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
49
+
50
+ # Mel and Linear spectrograms normalization/scaling and clipping
51
+ signal_normalization=True,
52
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
53
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
54
+ symmetric_mels=True,
55
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
56
+ # faster and cleaner convergence)
57
+ max_abs_value=4.,
58
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
59
+ # be too big to avoid gradient explosion,
60
+ # not too small for fast convergence)
61
+ # Contribution by @begeekmyfriend
62
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
63
+ # levels. Also allows for better G&L phase reconstruction)
64
+ preemphasize=True, # whether to apply filter
65
+ preemphasis=0.97, # filter coefficient.
66
+
67
+ # Limits
68
+ min_level_db=-100,
69
+ ref_level_db=20,
70
+ fmin=55,
71
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
72
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
73
+ fmax=7600, # To be increased/reduced depending on data.
74
+
75
+ ###################### Our training parameters #################################
76
+ img_size=96,
77
+ fps=25,
78
+
79
+ batch_size=16,
80
+ initial_learning_rate=1e-4,
81
+ nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
82
+ num_workers=16,
83
+ checkpoint_interval=3000,
84
+ eval_interval=3000,
85
+ save_optimizer_state=True,
86
+
87
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
88
+ syncnet_batch_size=64,
89
+ syncnet_lr=1e-4,
90
+ syncnet_eval_interval=10000,
91
+ syncnet_checkpoint_interval=10000,
92
+
93
+ disc_wt=0.07,
94
+ disc_initial_learning_rate=1e-4,
95
+ )
96
+
97
+
98
+ def hparams_debug_string():
99
+ values = hparams.values()
100
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
101
+ return "Hyperparameters:\n" + "\n".join(hp)
wav2lip/inference.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+ import argparse
5
+ import subprocess
6
+ from tqdm import tqdm
7
+ import sys
8
+ import traceback
9
+ from .audio import load_wav, melspectrogram
10
+ from .face_detection import FaceAlignment, LandmarksType
11
+ import torch
12
+ import platform
13
+
14
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
15
+
16
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
17
+ default='wav2lip/results/result_voice.mp4')
18
+
19
+ parser.add_argument('--static', type=bool,
20
+ help='If True, then use only first video frame for inference', default=False)
21
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
22
+ default=25., required=False)
23
+
24
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
25
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
26
+
27
+ parser.add_argument('--face_det_batch_size', type=int,
28
+ help='Batch size for face detection', default=32)
29
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=512)
30
+
31
+ parser.add_argument('--resize_factor', default=1, type=int,
32
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
33
+
34
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
35
+ help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
36
+ 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
37
+
38
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
39
+ help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
40
+ 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
41
+
42
+ parser.add_argument('--rotate', default=False, action='store_true',
43
+ help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
44
+ 'Use if you get a flipped result, despite feeding a normal looking video')
45
+
46
+ parser.add_argument('--nosmooth', default=False, action='store_true',
47
+ help='Prevent smoothing face detections over a short temporal window')
48
+
49
+ args = parser.parse_args()
50
+ args.img_size = 96
51
+
52
+ # Check for available devices
53
+ if torch.backends.mps.is_available():
54
+ device = 'mps' # Use Apple Silicon GPU
55
+ elif torch.cuda.is_available():
56
+ device = 'cuda'
57
+ else:
58
+ device = 'cpu'
59
+
60
+ print('Using {} for inference.'.format(device))
61
+
62
+ def get_smoothened_boxes(boxes, idx):
63
+ """Get smoothened box for a specific index"""
64
+ if idx >= len(boxes) or boxes[idx] is None:
65
+ return None, None
66
+
67
+ # Return the face region and coordinates
68
+ if isinstance(boxes[idx], list) and len(boxes[idx]) == 2: # Format from the specified bounding box
69
+ return boxes[idx][0], boxes[idx][1]
70
+ else: # Format from face detection - [x1, y1, x2, y2]
71
+ if isinstance(boxes[idx], list) or isinstance(boxes[idx], tuple):
72
+ if len(boxes[idx]) >= 4: # Make sure we have all 4 coordinates
73
+ x1, y1, x2, y2 = boxes[idx][:4]
74
+ # Return coordinates in the expected format (y1, y2, x1, x2)
75
+ coords = (y1, y2, x1, x2)
76
+ return None, coords
77
+
78
+ print(f"WARNING: Unexpected box format at idx {idx}: {boxes[idx]}")
79
+ return None, None
80
+
81
+ def face_detect(images):
82
+ print(f"Starting face detection using {device} device...")
83
+ try:
84
+ detector = FaceAlignment(LandmarksType._2D,
85
+ flip_input=False, device=device, verbose=True)
86
+ except Exception as e:
87
+ print(f"Error initializing face detector: {str(e)}")
88
+ print("Attempting to fall back to CPU for face detection...")
89
+ detector = FaceAlignment(LandmarksType._2D,
90
+ flip_input=False, device='cpu', verbose=True)
91
+
92
+ batch_size = args.face_det_batch_size
93
+
94
+ while 1:
95
+ predictions = []
96
+ try:
97
+ for i in range(0, len(images), batch_size):
98
+ batch = np.array(images[i:i + batch_size])
99
+ print(f"Processing detection batch {i//batch_size + 1}, shape: {batch.shape}")
100
+ batch_predictions = detector.get_detections_for_batch(batch)
101
+ predictions.extend(batch_predictions)
102
+ except RuntimeError as e:
103
+ print(f"Runtime error in face detection: {str(e)}")
104
+ if batch_size == 1:
105
+ # Error when batch_size is already 1
106
+ print('Face detection failed at minimum batch size! Using fallback method...')
107
+ # Create empty predictions for all frames to allow processing to continue
108
+ predictions = [None] * len(images)
109
+ break
110
+ batch_size //= 2
111
+ print('Reducing face detection batch size to', batch_size)
112
+ continue
113
+ except Exception as e:
114
+ print(f"Unexpected error in face detection: {str(e)}")
115
+ # Create empty predictions and continue with fallback
116
+ predictions = [None] * len(images)
117
+ break
118
+ break
119
+
120
+ # Check if we have at least one valid face detection
121
+ faces_detected = sum(1 for p in predictions if p is not None)
122
+ print(f"Detected faces in {faces_detected} out of {len(images)} frames ({faces_detected/len(images)*100:.1f}%)")
123
+
124
+ results = []
125
+ pady1, pady2, padx1, padx2 = args.pads
126
+
127
+ for i, (rect, image) in enumerate(zip(predictions, images)):
128
+ if rect is None:
129
+ # Create default coordinates for face detection
130
+ h, w = image.shape[:2]
131
+
132
+ # Simple and consistent face region estimation based on center of the frame
133
+ center_x = w // 2
134
+ center_y = h // 2
135
+
136
+ # Use about 1/3 of the frame height for face
137
+ face_h = h // 3
138
+ face_w = min(w // 2, face_h)
139
+
140
+ # Create a centered box
141
+ x1 = max(0, center_x - face_w // 2 - padx1)
142
+ y1 = max(0, center_y - face_h // 2 - pady1)
143
+ x2 = min(w, center_x + face_w // 2 + padx2)
144
+ y2 = min(h, center_y + face_h // 2 + pady2)
145
+
146
+ if i == 0 or i % 100 == 0: # Log only occasionally to avoid flooding
147
+ print(f"Frame {i}: Using fallback face region at ({x1},{y1},{x2},{y2})")
148
+
149
+ results.append([x1, y1, x2, y2])
150
+ continue
151
+
152
+ # If face is detected, use its coordinates with padding
153
+ y1 = max(0, rect[1] - pady1)
154
+ y2 = min(image.shape[0], rect[3] + pady2)
155
+ x1 = max(0, rect[0] - padx1)
156
+ x2 = min(image.shape[1], rect[2] + padx2)
157
+
158
+ results.append([x1, y1, x2, y2])
159
+
160
+ return results
161
+
162
+ def datagen(frames, mels):
163
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
164
+
165
+ if args.box[0] == -1:
166
+ if not args.static:
167
+ try:
168
+ print(f"Starting face detection for {len(frames)} frames...")
169
+ face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
170
+ print("Face detection completed successfully")
171
+ except Exception as e:
172
+ print(f"Face detection error: {str(e)}")
173
+ print(f"Error type: {type(e).__name__}")
174
+ traceback.print_exc()
175
+ print("Using fallback method with default face regions...")
176
+ # Create default face regions for all frames
177
+ h, w = frames[0].shape[:2]
178
+
179
+ # Simple face region estimation in the center of the frame
180
+ center_x = w // 2
181
+ center_y = h // 2
182
+
183
+ # Use about 1/3 of the frame height for face
184
+ face_h = h // 3
185
+ face_w = min(w // 2, face_h)
186
+
187
+ pady1, pady2, padx1, padx2 = args.pads
188
+ x1 = max(0, center_x - face_w // 2 - padx1)
189
+ y1 = max(0, center_y - face_h // 2 - pady1)
190
+ x2 = min(w, center_x + face_w // 2 + padx2)
191
+ y2 = min(h, center_y + face_h // 2 + pady2)
192
+
193
+ print(f"Estimated face region: x1={x1}, y1={y1}, x2={x2}, y2={y2}")
194
+
195
+ # Use the same format as the face_detect function returns
196
+ face_det_results = [[x1, y1, x2, y2] for _ in range(len(frames))]
197
+ else:
198
+ try:
199
+ print("Starting face detection for static image...")
200
+ face_det_results = face_detect([frames[0]])
201
+ print("Face detection completed successfully")
202
+ except Exception as e:
203
+ print(f"Face detection error: {str(e)}")
204
+ print(f"Error type: {type(e).__name__}")
205
+ traceback.print_exc()
206
+ print("Using fallback method with default face region...")
207
+ # Create default face region for static image
208
+ h, w = frames[0].shape[:2]
209
+
210
+ # Simple face region estimation in the center of the frame
211
+ center_x = w // 2
212
+ center_y = h // 2
213
+
214
+ # Use about 1/3 of the frame height for face
215
+ face_h = h // 3
216
+ face_w = min(w // 2, face_h)
217
+
218
+ pady1, pady2, padx1, padx2 = args.pads
219
+ x1 = max(0, center_x - face_w // 2 - padx1)
220
+ y1 = max(0, center_y - face_h // 2 - pady1)
221
+ x2 = min(w, center_x + face_w // 2 + padx2)
222
+ y2 = min(h, center_y + face_h // 2 + pady2)
223
+
224
+ print(f"Estimated face region for static image: x1={x1}, y1={y1}, x2={x2}, y2={y2}")
225
+
226
+ # Use the same format as the face_detect function returns
227
+ face_det_results = [[x1, y1, x2, y2]]
228
+ else:
229
+ print('Using the specified bounding box instead of face detection...')
230
+ y1, y2, x1, x2 = args.box
231
+ face_det_results = [[x1, y1, x2, y2] for _ in range(len(frames))]
232
+
233
+ for i, m in enumerate(mels):
234
+ idx = 0 if args.static else i%len(frames)
235
+ frame_to_save = frames[idx].copy()
236
+
237
+ if args.box[0] == -1:
238
+ face, coords = get_smoothened_boxes(face_det_results, idx)
239
+
240
+ if coords is None:
241
+ print(f'Face coordinates not detected! Skipping frame {i}')
242
+ continue
243
+
244
+ # If face is None, extract it from the frame using coordinates
245
+ if face is None:
246
+ y1, y2, x1, x2 = coords
247
+ try:
248
+ if y1 >= y2 or x1 >= x2:
249
+ print(f"Invalid coordinates at frame {i}: y1={y1}, y2={y2}, x1={x1}, x2={x2}")
250
+ continue
251
+ if y1 < 0 or x1 < 0 or y2 > frame_to_save.shape[0] or x2 > frame_to_save.shape[1]:
252
+ print(f"Out of bounds coordinates at frame {i}. Adjusting...")
253
+ y1 = max(0, y1)
254
+ x1 = max(0, x1)
255
+ y2 = min(frame_to_save.shape[0], y2)
256
+ x2 = min(frame_to_save.shape[1], x2)
257
+
258
+ # Check if the region is too small
259
+ if (y2 - y1) < 10 or (x2 - x1) < 10:
260
+ print(f"Region too small at frame {i}. Skipping.")
261
+ continue
262
+
263
+ face = frames[idx][y1:y2, x1:x2]
264
+ except Exception as e:
265
+ print(f"Error extracting face at frame {i}: {str(e)}")
266
+ continue
267
+ else:
268
+ face = frames[idx][y1:y2, x1:x2]
269
+ coords = (y1, y2, x1, x2)
270
+
271
+ try:
272
+ face = cv2.resize(face, (args.img_size, args.img_size))
273
+ img_batch.append(face)
274
+ mel_batch.append(m)
275
+ frame_batch.append(frame_to_save)
276
+ coords_batch.append(coords)
277
+ except Exception as e:
278
+ print(f"Error processing frame {i}: {str(e)}")
279
+ continue
280
+
281
+ if len(img_batch) >= args.wav2lip_batch_size:
282
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
283
+
284
+ img_masked = img_batch.copy()
285
+ img_masked[:, args.img_size//2:] = 0
286
+
287
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
288
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
289
+
290
+ yield img_batch, mel_batch, frame_batch, coords_batch
291
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
292
+
293
+ if len(img_batch) > 0:
294
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
295
+
296
+ img_masked = img_batch.copy()
297
+ img_masked[:, args.img_size//2:] = 0
298
+
299
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
300
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
301
+
302
+ yield img_batch, mel_batch, frame_batch, coords_batch
303
+
304
+ mel_step_size = 16
305
+
306
+ def _load(checkpoint_path):
307
+ # Handle loading for different devices
308
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
309
+ return checkpoint
310
+
311
+
312
+ def main(face, audio, model, slow_mode=False):
313
+ if slow_mode:
314
+ print("Using SLOW animation mode (full face animation)")
315
+ else:
316
+ print("Using FAST animation mode (lips only)")
317
+
318
+ if not os.path.isfile(face):
319
+ raise ValueError('--face argument must be a valid path to video/image file')
320
+
321
+ elif face.split('.')[1] in ['jpg', 'png', 'jpeg'] and not slow_mode:
322
+ full_frames = [cv2.imread(face)]
323
+ fps = args.fps
324
+
325
+ else:
326
+ video_stream = cv2.VideoCapture(face)
327
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
328
+
329
+ # Get video dimensions for potential downscaling of large videos
330
+ frame_width = int(video_stream.get(cv2.CAP_PROP_FRAME_WIDTH))
331
+ frame_height = int(video_stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
332
+ total_frames = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))
333
+
334
+ # Auto-adjust resize factor for very large videos
335
+ original_resize_factor = args.resize_factor
336
+ if frame_width > 1920 or frame_height > 1080:
337
+ # For 4K or larger videos, use a higher resize factor
338
+ if frame_width >= 3840 or frame_height >= 2160:
339
+ args.resize_factor = max(4, args.resize_factor)
340
+ print(f"Auto-adjusting resize factor to {args.resize_factor} for high-resolution video")
341
+ # For 1080p-4K videos
342
+ elif frame_width > 1920 or frame_height > 1080:
343
+ args.resize_factor = max(2, args.resize_factor)
344
+ print(f"Auto-adjusting resize factor to {args.resize_factor} for high-resolution video")
345
+
346
+ print('Reading video frames...')
347
+
348
+ full_frames = []
349
+
350
+ # For large videos, report progress and limit memory usage
351
+ frame_limit = 5000 # Maximum number of frames to process at once
352
+ if total_frames > frame_limit:
353
+ print(f"Large video detected ({total_frames} frames). Will process in chunks.")
354
+
355
+ # Use tqdm for progress reporting
356
+ pbar = tqdm(total=min(total_frames, frame_limit))
357
+ frame_count = 0
358
+
359
+ while frame_count < frame_limit:
360
+ still_reading, frame = video_stream.read()
361
+ if not still_reading:
362
+ video_stream.release()
363
+ break
364
+
365
+ if args.resize_factor > 1:
366
+ frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
367
+
368
+ if args.rotate:
369
+ frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
370
+
371
+ y1, y2, x1, x2 = args.crop
372
+ if x2 == -1: x2 = frame.shape[1]
373
+ if y2 == -1: y2 = frame.shape[0]
374
+
375
+ frame = frame[y1:y2, x1:x2]
376
+
377
+ full_frames.append(frame)
378
+ frame_count += 1
379
+ pbar.update(1)
380
+
381
+ # For very large videos, limit frames to avoid memory issues
382
+ if frame_count >= frame_limit:
383
+ print(f"Reached frame limit of {frame_limit}. Processing this chunk.")
384
+ break
385
+
386
+ pbar.close()
387
+
388
+ # Reset resize factor to original value after processing
389
+ args.resize_factor = original_resize_factor
390
+
391
+ print ("Number of frames available for inference: "+str(len(full_frames)))
392
+
393
+ if not audio.endswith('.wav'):
394
+ print('Extracting raw audio...')
395
+ command = 'ffmpeg -y -i {} -strict -2 {}'.format(audio, 'temp/temp.wav')
396
+
397
+ subprocess.call(command, shell=True)
398
+ audio = 'temp/temp.wav'
399
+
400
+ wav = load_wav(audio, 16000)
401
+ mel = melspectrogram(wav)
402
+ print(mel.shape)
403
+
404
+ if np.isnan(mel.reshape(-1)).sum() > 0:
405
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
406
+
407
+ mel_chunks = []
408
+ mel_idx_multiplier = 80./fps
409
+ i = 0
410
+ while 1:
411
+ start_idx = int(i * mel_idx_multiplier)
412
+ if start_idx + mel_step_size > len(mel[0]):
413
+ mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
414
+ break
415
+ mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
416
+ i += 1
417
+
418
+ print("Length of mel chunks: {}".format(len(mel_chunks)))
419
+
420
+ full_frames = full_frames[:len(mel_chunks)]
421
+
422
+ batch_size = args.wav2lip_batch_size
423
+ gen = datagen(full_frames.copy(), mel_chunks)
424
+
425
+ # Initialize video writer outside the try block
426
+ out = None
427
+ try:
428
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
429
+ total=int(np.ceil(float(len(mel_chunks))/args.wav2lip_batch_size)))):
430
+ if i == 0:
431
+ #model = load_model(checkpoint_path)
432
+ print ("Model loaded")
433
+
434
+ frame_h, frame_w = full_frames[0].shape[:-1]
435
+ out = cv2.VideoWriter('wav2lip/temp/result.avi',
436
+ cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
437
+
438
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
439
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
440
+
441
+ with torch.no_grad():
442
+ pred = model(mel_batch, img_batch)
443
+
444
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
445
+
446
+ for p, f, c in zip(pred, frames, coords):
447
+ y1, y2, x1, x2 = c
448
+ p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
449
+
450
+ f[y1:y2, x1:x2] = p
451
+ out.write(f)
452
+ except Exception as e:
453
+ print(f"Error during processing: {str(e)}")
454
+ print("Attempting to save any completed frames...")
455
+
456
+ # Save the results - only if out was initialized
457
+ if out is not None:
458
+ out.release()
459
+
460
+ # Convert the output video to MP4 if needed - only if the AVI exists
461
+ result_path = 'wav2lip/results/result_voice.mp4'
462
+ if os.path.exists('wav2lip/temp/result.avi'):
463
+ # Check if the result file is valid (has frames)
464
+ avi_info = os.stat('wav2lip/temp/result.avi')
465
+ if avi_info.st_size > 1000: # If file is too small, it's likely empty
466
+ # Modified command to include the audio file
467
+ command = 'ffmpeg -y -i {} -i {} -c:v libx264 -preset ultrafast -c:a aac -map 0:v:0 -map 1:a:0 {}'.format(
468
+ 'wav2lip/temp/result.avi', audio, result_path)
469
+ try:
470
+ subprocess.call(command, shell=True)
471
+ if os.path.exists(result_path):
472
+ print(f"Successfully created output video with audio at {result_path}")
473
+ else:
474
+ print(f"Error: Output video file was not created.")
475
+ except Exception as e:
476
+ print(f"Error during video conversion: {str(e)}")
477
+ else:
478
+ print(f"Warning: Output AVI file is too small ({avi_info.st_size} bytes). Face detection may have failed.")
479
+ else:
480
+ print("No output video was created. Face detection likely failed completely.")
481
+ # Return a default path even if no output was created
482
+
483
+ # Return even if there were errors
484
+ return result_path
wav2lip/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
2
+ from .syncnet import SyncNet_color
wav2lip/models/conv.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class nonorm_Conv2d(nn.Module):
22
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self.conv_block = nn.Sequential(
25
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
26
+ )
27
+ self.act = nn.LeakyReLU(0.01, inplace=True)
28
+
29
+ def forward(self, x):
30
+ out = self.conv_block(x)
31
+ return self.act(out)
32
+
33
+ class Conv2dTranspose(nn.Module):
34
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
35
+ super().__init__(*args, **kwargs)
36
+ self.conv_block = nn.Sequential(
37
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
38
+ nn.BatchNorm2d(cout)
39
+ )
40
+ self.act = nn.ReLU()
41
+
42
+ def forward(self, x):
43
+ out = self.conv_block(x)
44
+ return self.act(out)
wav2lip/models/syncnet.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from .conv import Conv2d
6
+
7
+ class SyncNet_color(nn.Module):
8
+ def __init__(self):
9
+ super(SyncNet_color, self).__init__()
10
+
11
+ self.face_encoder = nn.Sequential(
12
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
13
+
14
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
15
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
16
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
17
+
18
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
19
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
20
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
22
+
23
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
24
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
25
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
26
+
27
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
28
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
29
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
30
+
31
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
32
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
33
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
34
+
35
+ self.audio_encoder = nn.Sequential(
36
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
37
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
38
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
39
+
40
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
41
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
42
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
43
+
44
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
45
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
46
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
47
+
48
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
49
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
50
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
51
+
52
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
53
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
54
+
55
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
56
+ face_embedding = self.face_encoder(face_sequences)
57
+ audio_embedding = self.audio_encoder(audio_sequences)
58
+
59
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
60
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
61
+
62
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
63
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
64
+
65
+
66
+ return audio_embedding, face_embedding
wav2lip/models/wav2lip.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ import math
5
+
6
+ from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
7
+
8
+ class Wav2Lip(nn.Module):
9
+ def __init__(self):
10
+ super(Wav2Lip, self).__init__()
11
+
12
+ self.face_encoder_blocks = nn.ModuleList([
13
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
14
+
15
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
16
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
17
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
18
+
19
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
20
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
22
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
23
+
24
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
25
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
26
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
27
+
28
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
29
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
30
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
31
+
32
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
33
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
34
+
35
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
36
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
37
+
38
+ self.audio_encoder = nn.Sequential(
39
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
40
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
41
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
42
+
43
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+
51
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
52
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
53
+
54
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
55
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
56
+
57
+ self.face_decoder_blocks = nn.ModuleList([
58
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
59
+
60
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
62
+
63
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
64
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
65
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
66
+
67
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
68
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
69
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
70
+
71
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
73
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
74
+
75
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
76
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
77
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
78
+
79
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
80
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
81
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
82
+
83
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
84
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
85
+ nn.Sigmoid())
86
+
87
+ def forward(self, audio_sequences, face_sequences):
88
+ # audio_sequences = (B, T, 1, 80, 16)
89
+ B = audio_sequences.size(0)
90
+
91
+ input_dim_size = len(face_sequences.size())
92
+ if input_dim_size > 4:
93
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
94
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
95
+
96
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
97
+
98
+ feats = []
99
+ x = face_sequences
100
+ for f in self.face_encoder_blocks:
101
+ x = f(x)
102
+ feats.append(x)
103
+
104
+ x = audio_embedding
105
+ for f in self.face_decoder_blocks:
106
+ x = f(x)
107
+ try:
108
+ x = torch.cat((x, feats[-1]), dim=1)
109
+ except Exception as e:
110
+ print(x.size())
111
+ print(feats[-1].size())
112
+ raise e
113
+
114
+ feats.pop()
115
+
116
+ x = self.output_block(x)
117
+
118
+ if input_dim_size > 4:
119
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
120
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
121
+
122
+ else:
123
+ outputs = x
124
+
125
+ return outputs
126
+
127
+ class Wav2Lip_disc_qual(nn.Module):
128
+ def __init__(self):
129
+ super(Wav2Lip_disc_qual, self).__init__()
130
+
131
+ self.face_encoder_blocks = nn.ModuleList([
132
+ nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
133
+
134
+ nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
135
+ nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
136
+
137
+ nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
138
+ nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
139
+
140
+ nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
141
+ nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
142
+
143
+ nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
144
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
145
+
146
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
147
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
148
+
149
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
150
+ nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
151
+
152
+ self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
153
+ self.label_noise = .0
154
+
155
+ def get_lower_half(self, face_sequences):
156
+ return face_sequences[:, :, face_sequences.size(2)//2:]
157
+
158
+ def to_2d(self, face_sequences):
159
+ B = face_sequences.size(0)
160
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
161
+ return face_sequences
162
+
163
+ def perceptual_forward(self, false_face_sequences):
164
+ false_face_sequences = self.to_2d(false_face_sequences)
165
+ false_face_sequences = self.get_lower_half(false_face_sequences)
166
+
167
+ false_feats = false_face_sequences
168
+ for f in self.face_encoder_blocks:
169
+ false_feats = f(false_feats)
170
+
171
+ false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
172
+ torch.ones((len(false_feats), 1)).cuda())
173
+
174
+ return false_pred_loss
175
+
176
+ def forward(self, face_sequences):
177
+ face_sequences = self.to_2d(face_sequences)
178
+ face_sequences = self.get_lower_half(face_sequences)
179
+
180
+ x = face_sequences
181
+ for f in self.face_encoder_blocks:
182
+ x = f(x)
183
+
184
+ return self.binary_pred(x).view(len(x), -1)