Commit 14d3449 · Initial commit with essential files

Files changed:

- .streamlit/config.toml +4 -0
- README.md +246 -0
- app.py +1000 -0
- avatars_images/avatar1.jpg +0 -0
- avatars_images/avatar2.jpg +0 -0
- avatars_images/avatar3.png +0 -0
- packages.txt +3 -0
- requirements.txt +12 -0
- wav2lip/audio.py +136 -0
- wav2lip/face_detection/README.md +1 -0
- wav2lip/face_detection/__init__.py +7 -0
- wav2lip/face_detection/api.py +119 -0
- wav2lip/face_detection/detection/__init__.py +1 -0
- wav2lip/face_detection/detection/core.py +134 -0
- wav2lip/face_detection/detection/sfd/__init__.py +1 -0
- wav2lip/face_detection/detection/sfd/bbox.py +129 -0
- wav2lip/face_detection/detection/sfd/detect.py +112 -0
- wav2lip/face_detection/detection/sfd/net_s3fd.py +129 -0
- wav2lip/face_detection/detection/sfd/sfd_detector.py +72 -0
- wav2lip/face_detection/models.py +261 -0
- wav2lip/face_detection/utils.py +313 -0
- wav2lip/hparams.py +101 -0
- wav2lip/inference.py +484 -0
- wav2lip/models/__init__.py +2 -0
- wav2lip/models/conv.py +44 -0
- wav2lip/models/syncnet.py +66 -0
- wav2lip/models/wav2lip.py +184 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,4 @@
[theme]
base="dark"
primaryColor="#865bf1"
font="monospace"
README.md
ADDED
@@ -0,0 +1,246 @@
# AI Lip Sync



An AI-powered application that synchronizes lip movements with audio input, built with Wav2Lip and Streamlit.

## Features

- **Multiple Avatar Options**: Choose from built-in avatars or upload your own image/video
- **Audio Input Flexibility**: Record audio directly or upload WAV/MP3 files
- **Quality Assessment**: Automatic analysis of video and audio quality with recommendations
- **GPU Acceleration**: Optimized for Apple Silicon (M1/M2) GPUs
- **Two Animation Modes**: Fast (lips only) or Slow (full face animation)
- **Video Trimming**: Trim the output video to remove unwanted portions

## Quick Setup Guide

### Prerequisites

- Python 3.9+
- ffmpeg (for audio processing)
- Git LFS (optional, for handling large model files)

### Installation

1. Clone the repository:
```bash
git clone https://github.com/yourusername/ai-lip-sync-app.git
cd ai-lip-sync-app
```

2. Create and activate a virtual environment:
```bash
python -m venv .venv
# On macOS/Linux
source .venv/bin/activate
# On Windows
.venv\Scripts\activate
```

3. Install Python dependencies:
```bash
pip install -r requirements.txt
```

4. Install system dependencies:
```bash
# On Ubuntu/Debian
sudo apt-get update
sudo apt-get install $(cat packages.txt)

# On macOS with Homebrew
brew install ffmpeg
```

5. Run the application:
```bash
python -m streamlit run app.py
```

> **Note**: If you encounter a "streamlit: command not found" error, always use `python -m streamlit run app.py` instead of `streamlit run app.py`

The application will automatically download the required model files on first run.

## Usage Guide

1. **Choose Avatar Source**:
   - Select from built-in avatars or upload your own image/video
   - For best results, use clear frontal face images/videos

2. **Provide Audio**:
   - Record directly using your microphone
   - Upload WAV or MP3 files

3. **Quality Assessment**:
   - The app will automatically analyze your uploaded video and audio
   - Review the quality analysis and recommendations
   - Make adjustments if needed for better results

4. **Generate Animation**:
   - Choose "Fast animate" for quicker processing (lips only)
   - Choose "Slower animate" for more realistic results (full face)

5. **View and Edit Results**:
   - The generated video will appear in the app
   - Use the trim feature to remove unwanted portions from the start or end
   - Download the original or trimmed version to your computer

## Video Trimming Feature

The app now includes a video trimming capability:

- After generating a lip-sync video, you'll see trimming options below the result
- Use the sliders to select the start and end times for your trimmed video
- Click "Trim Video" to create a shortened version
- Both original and trimmed videos can be downloaded directly from the app
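Under the hood, the "Trim Video" button is a thin wrapper around an ffmpeg stream-copy call (see `trim_video()` in `app.py`). A minimal sketch of the same idea, using an illustrative `trim_clip` helper:

```python
import subprocess

def trim_clip(input_path: str, output_path: str, start: float, end: float) -> bool:
    """Trim without re-encoding; returns True when ffmpeg exits cleanly."""
    command = [
        "ffmpeg", "-y",
        "-ss", str(start),   # seeking before -i keeps the trim fast
        "-i", input_path,
        "-to", str(end),
        "-c:v", "copy",      # copy the video and audio streams as-is
        "-c:a", "copy",
        output_path,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    return result.returncode == 0
```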
## Quality Assessment Feature

The app now includes automatic quality assessment for uploaded videos and audio:

### Video Analysis:
- Resolution check (higher resolution = better results)
- Face detection (confirms a face is present and properly sized)
- Frame rate analysis
- Overall quality score with specific recommendations

### Audio Analysis:
- Speech detection (confirms speech is present)
- Volume level assessment
- Silence detection
- Overall quality score with specific recommendations
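Both analyses are implemented with plain OpenCV and librosa calls in `app.py`. A condensed sketch of the core checks (the `quick_quality_check` helper is illustrative; the thresholds are the ones the app uses):

```python
import cv2
import librosa
import numpy as np

def quick_quality_check(video_path: str, audio_path: str) -> dict:
    # Face check: read one frame and run OpenCV's Haar cascade face detector.
    video = cv2.VideoCapture(video_path)
    ok, frame = video.read()
    video.release()
    face_found = False
    if ok:
        cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        face_found = len(cascade.detectMultiScale(gray, 1.1, 4)) > 0

    # Audio check: RMS energy gives a rough speech/volume/silence estimate.
    y, sr = librosa.load(audio_path, sr=None)
    rms = librosa.feature.rms(y=y)[0]
    return {
        "face_detected": face_found,
        "speech_detected": bool(np.max(rms) > 0.05),
        "silence_pct": float(np.mean(rms < 0.01) * 100),
    }
```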
## Troubleshooting

- **"No face detected" error**: Ensure your video has a clear, well-lit frontal face
- **Poor lip sync results**: Try using higher quality audio with clear speech
- **Performance issues**: For large videos, try the "Fast animate" option or use a smaller video clip
- **Memory errors**: Close other applications to free up memory, or use a machine with more RAM

## Technical Details

The project is built on the Wav2Lip model with several optimizations:
- Apple Silicon (M1/M2) GPU acceleration using MPS backend
- Automatic video resolution scaling for large videos
- Memory optimizations for processing longer videos
- Quality assessment using OpenCV and librosa
+
## Original Project Background
|
130 |
+
|
131 |
+
The project started as a part of an interview process with some company, I received an email with the following task:
|
132 |
+
|
133 |
+
Assignment Object:<br>
|
134 |
+
    Your task is to develop a lip-syncing model using machine learning
|
135 |
+
techniques. It takes an input image and audio and then generates a video
|
136 |
+
where the image appears to lip sync with the provided audio. You have to
|
137 |
+
develop this task using python3.
|
138 |
+
|
139 |
+
Requirements:<br>
|
140 |
+
    ● Avatar / Image : Get one AI-generated avatar, the avatar may be for a<br>
|
141 |
+
    man, woman, old man, old lady or a child. Ensure that the avatar is<br>
|
142 |
+
    created by artificial intelligence and does not represent real<br>
|
143 |
+
    individuals.<br>
|
144 |
+
    ● Audio : Provide two distinct and clear audio recordings—one in Arabic<br>
|
145 |
+
    and the other in English. The duration of each audio clip should be<br>
|
146 |
+
    no less than 30 seconds and no more than 1 minute.<br>
|
147 |
+
    ● Lip-sync model: Develop a lip-syncing model to synchronise the lip<br>
|
148 |
+
    movements of the chosen avatar with the provided audio. Ensure the<br>
|
149 |
+
    model demonstrates proficiency in accurately aligning lip motions<br>
|
150 |
+
    with the spoken words in both Arabic and English.<br>
|
151 |
+
    Hint : You can refer to state of the art models in lip-syncing.<br>
|
152 |
+
|
153 |
+
I was given about 96 hours to accomplish this task, I spent the first 12 hours sick with a very bad flu and no proper internet connection so I had 84 hours!<br>
|
154 |
+
After submitting the task on time, I took more time to deploy the project on Streamlight, as I thought it was a fun project and would be a nice addition to my CV:)
|
155 |
+
|
156 |
+
Given the provided hint from the company, "You can refer to state-of-the-art models in lip-syncing.", I started looking into the available open-source pre-trained model that can accomplish this task and most available resources pointed towards **Wav2Lip**. I found a couple of interesting tutorials for that model that I will share below.
|
157 |
+
|
158 |
+
### How to run the application locally:<br>
|
159 |
+
|
160 |
+
1- clone the repo to your local machine.<br>
|
161 |
+
2- open your terminal inside the project folder and run the following command: `pip install -r requirements.txt` and then run this command `sudo xargs -a packages.txt apt-get install` to install the needed modules and packages.<br>
|
162 |
+
3- open your terminal inside the project folder and run the following command: `streamlit run app.py` to run the streamlit application.<br>
|
163 |
+
|
164 |
+
### Things I changed in the wav2lip and why:<br>
|
165 |
+
|
166 |
+
In order to work with and deploy the wav2lip model I had to make the following changes:<br>
|
167 |
+
1- Changed the `_build_mel_basis()` function in `audio.py`, I had to do that to be able to work with `librosa>=0.10.0` package, check this [issue](https://github.com/Rudrabha/Wav2Lip/issues/550) for more details.<br>
|
168 |
+
2- Changed the `main()` function at the `inferance.py` to directly take an output from the `app.py` instead of using the command line arguments.<br>
|
169 |
+
3- I took the `load_model(path)` function and added it to `app.py` and added `@st.cache_data` in order to only load the model once, instead of using it multiple times, I also modified it<br>
|
170 |
+
4- Deleted the unnecessary files like the checkpoints to make the Streamlit website deployment easier.<br>
|
171 |
+
5- Since I'm using Streamlit for deployment and Streamlit Cloud doesn't support GPU, I had to change the device to work with `cpu` instead of `cuda`.<br>
|
172 |
+
6- I made other minor changes like changing the path to a file or modifying import statements.
|
173 |
+
|
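For reference, the `librosa` fix in point 1 mostly amounts to passing keyword arguments, since `librosa.filters.mel` no longer accepts them positionally as of 0.10. A sketch of the updated function, assuming the usual `hparams` object from `wav2lip/hparams.py`:

```python
import librosa
from wav2lip.hparams import hparams as hp  # assumed import path for this repo's hparams

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    # librosa >= 0.10 rejects the old positional call
    # librosa.filters.mel(hp.sample_rate, hp.n_fft, ...), so use keywords instead.
    return librosa.filters.mel(
        sr=hp.sample_rate,
        n_fft=hp.n_fft,
        n_mels=hp.num_mels,
        fmin=hp.fmin,
        fmax=hp.fmax,
    )
```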
### Issues I had with Streamlit during deployment:

This part is documentation for my future self, in case I run into these issues again, and it may also be helpful for any poor soul who has to work with Streamlit:

1-
```
Error downloading object: wav2lip/checkpoints/wav2lip_gan.pth (ca9ab7b): Smudge error: Error downloading wav2lip/checkpoints/wav2lip_gan.pth (ca9ab7b7b812c0e80a6e70a5977c545a1e8a365a6c49d5e533023c034d7ac3d8): batch request: [email protected]: Permission denied (publickey).: exit status 255

Errors logged to /mount/src/ai-lip-sync/.git/lfs/logs/20240121T212252.496674
```
This is essentially Streamlit telling you that it can't handle a file that big: upload it to Google Drive instead and load it from Python code later. And no, `git lfs` won't solve the problem :)<br>
A ground rule I learned here: the lighter you make your app, the easier and faster it is to deploy.<br>
I opened a topic about that issue on the Streamlit forum, right [here](https://discuss.streamlit.io/t/file-upload-fails-with-error-downloading-object-wav2lip-checkpoints-wav2lip-gan-pth-ca9ab7b/60261).<br>

2- I also faced a lot of dependency issues, mostly because I relied on `pipreqs` to generate my `requirements.txt`. `pipreqs` messed up my modules: it added unneeded ones and missed others. Unfortunately, it took me some time to discover that, which really slowed me down.

3-
```
ImportError: libGL.so.1: cannot open shared object file: No such file or directory
```
I faced that problem when importing `cv2` (OpenCV). The solution was to install `libgl1-mesa-dev` and some other packages using `apt`; you can't add such packages to `requirements.txt`, you need to list them in a file named `packages.txt` instead (see the example below).
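For reference, `packages.txt` is just a plain list of apt package names, one per line. The exact packages in this repo may differ, but something along these lines resolves the `libGL` error:

```
libgl1-mesa-dev
libglib2.0-0
ffmpeg
```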
4- Streamlit can't handle heavy processing. I discovered that when I tried to deploy the `slow animation` button, which processes a video input alongside the recording to get more accurate lip-syncing: the application failed as soon as I used that button (and I tried it twice :)). That makes sense, since Streamlit Cloud doesn't have a GPU or much RAM (I don't have a good GPU either, but I do have about 64GB of RAM, which was enough to run that function locally). To solve the issue, I created another branch containing the deployment version without the `slow animation` button and used that branch for deployment, while keeping the button on the main branch.

**Pushing the checkpoint files:**<br>

Given the size of those kinds of files, there are two ways to handle them.

At the start, I used git lfs; here's how to do it:<br>

1- Follow the installation instructions suitable for your system from [here](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).<br>
2- Use the command `git lfs track "*.pth"` to let git lfs know which big files it should manage.<br>
3- When pushing from the command line (I usually use VS Code, but it tends to fail with big files like `.pth` files), you need to generate a personal access token; follow the instructions from [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token) and copy the token.<br>
4- When pushing the file from the terminal you will be asked for a password; don't enter your GitHub account password, enter the personal access token from step 3 instead.

But then Streamlit wasn't even capable of pulling the repo! So I uploaded the model checkpoints and some other files to a public Google Drive folder and used a module called [gdown](https://github.com/wkentaro/gdown) to download those folders when needed; it's straightforward to install and use.
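The download itself is a single `gdown` call guarded by an existence check, the same pattern `load_model()` in `app.py` uses (the folder URL below is a placeholder, not the real ID):

```python
import os
import gdown

CHECKPOINT_DIR = "wav2lip_checkpoints"  # gdown recreates the Drive folder locally under its own name
FOLDER_URL = "https://drive.google.com/drive/folders/<public-folder-id>"  # placeholder

# Only download the public Drive folder when the checkpoint is missing.
if not os.path.exists(os.path.join(CHECKPOINT_DIR, "wav2lip_gan.pth")):
    gdown.download_folder(FOLDER_URL, quiet=True, use_cookies=False)
```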
**Video preview of the application:**<br>

**Fast animation version**<br>
Notice how only the lips are moving.

English version:

https://github.com/Aml-Hassan-Abd-El-hamid/AI-Lip-Sync/assets/66205928/36577ccb-5ec6-4bb4-b7ff-44bb52a4f984

Arabic version:

https://github.com/Aml-Hassan-Abd-El-hamid/ai-lip-sync-app/assets/66205928/4346aa6d-ea4e-400e-9124-1cce06b049df

**Slower animation version**<br>
Notice how the eyes and the whole face move instead of only the lips.<br>

Unfortunately, Streamlit can't provide the computational power that the slower animation version requires, so I made it available only in the offline version; you need to run the application locally to try it.

English version:

https://github.com/Aml-Hassan-Abd-El-hamid/AI-Lip-Sync/assets/66205928/26740856-52e5-4fe7-868d-3b9341e97064

Arabic version:

https://github.com/Aml-Hassan-Abd-El-hamid/ai-lip-sync-app/assets/66205928/ba97daca-b30d-4179-9387-a382abbca3ba

The only difference between the fast and slow animation versions here is that the fast version passes only a photo to the model, while the slow version passes a video instead.
app.py
ADDED
@@ -0,0 +1,1000 @@
import os
import streamlit as st
from streamlit_image_select import image_select
import torch
from streamlit_mic_recorder import mic_recorder
from wav2lip import inference
from wav2lip.models import Wav2Lip
import gdown
import warnings
import cv2
import numpy as np
import librosa
from pathlib import Path
import subprocess
import time
from PIL import Image
import matplotlib.pyplot as plt
import sys
import threading
import concurrent.futures

# Suppress warnings
warnings.filterwarnings('ignore')

# More comprehensive fix for Streamlit file watcher issues with PyTorch
os.environ['STREAMLIT_WATCH_IGNORE'] = 'torch'
if 'torch' in sys.modules:
    sys.modules['torch'].__path__ = type('', (), {'_path': []})()

# Check if MPS (Apple Silicon GPU) is available, otherwise use CPU
if torch.backends.mps.is_available():
    device = 'mps'
    # Enable memory optimization for Apple Silicon
    torch.mps.empty_cache()
    # Set the memory format to optimize for M2 Max
    torch._C._set_cudnn_benchmark(True)
    st.success("Using Apple M2 Max GPU for acceleration with optimized settings!")
else:
    device = 'cpu'
    st.warning("Using CPU for inference (slower). GPU acceleration not available.")

print(f"Using {device} for inference.")
# Add functions to analyze video and audio quality
def analyze_video_quality(file_path):
    """Analyze video quality and detect faces for better user guidance"""
    try:
        # Open the video file
        video = cv2.VideoCapture(file_path)

        # Get video properties
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = video.get(cv2.CAP_PROP_FPS)
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = frame_count / fps if fps > 0 else 0

        # Read a frame for face detection
        success, frame = video.read()
        if not success:
            return {
                "resolution": f"{width}x{height}",
                "fps": fps,
                "duration": f"{duration:.1f} seconds",
                "quality": "Unknown",
                "face_detected": False,
                "message": "Could not analyze video content."
            }

        # Detect faces using OpenCV's face detector
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        # Determine quality score based on resolution and face detection
        quality_score = 0

        # Resolution assessment
        if width >= 1920 or height >= 1080:  # 1080p or higher
            resolution_quality = "Excellent"
            quality_score += 3
        elif width >= 1280 or height >= 720:  # 720p
            resolution_quality = "Good"
            quality_score += 2
        elif width >= 640 or height >= 480:  # 480p
            resolution_quality = "Fair"
            quality_score += 1
        else:
            resolution_quality = "Low"

        # Overall quality assessment
        face_detected = len(faces) > 0

        if face_detected:
            quality_score += 2
            face_message = "Face detected! ✅"

            # Check face size relative to frame
            for (x, y, w, h) in faces:
                face_area_ratio = (w * h) / (width * height)
                if face_area_ratio > 0.1:  # Face takes up at least 10% of frame
                    quality_score += 1
                    face_size = "Good face size"
                else:
                    face_size = "Face may be too small"
        else:
            face_message = "No face detected! ⚠️ Lip sync results may be poor."
            face_size = "N/A"

        # Determine overall quality
        if quality_score >= 5:
            quality = "Excellent"
        elif quality_score >= 3:
            quality = "Good"
        elif quality_score >= 1:
            quality = "Fair"
        else:
            quality = "Poor"

        # Release video resource
        video.release()

        return {
            "resolution": f"{width}x{height}",
            "fps": f"{fps:.1f}",
            "duration": f"{duration:.1f} seconds",
            "quality": quality,
            "resolution_quality": resolution_quality,
            "face_detected": face_detected,
            "face_message": face_message,
            "face_size": face_size,
            "message": get_video_recommendation(quality, face_detected, width, height)
        }

    except Exception as e:
        return {
            "quality": "Error",
            "message": f"Could not analyze video: {str(e)}"
        }
def analyze_audio_quality(file_path):
    """Analyze audio quality for better user guidance"""
    try:
        # Load audio file using librosa
        y, sr = librosa.load(file_path, sr=None)

        # Get duration
        duration = librosa.get_duration(y=y, sr=sr)

        # Calculate audio features
        rms = librosa.feature.rms(y=y)[0]
        mean_volume = np.mean(rms)

        # Simple speech detection (using energy levels)
        has_speech = np.max(rms) > 0.05

        # Check for silence periods
        silence_threshold = 0.01
        silence_percentage = np.mean(rms < silence_threshold) * 100

        # Calculate quality score
        quality_score = 0

        # Volume assessment
        if 0.05 <= mean_volume <= 0.2:
            volume_quality = "Good volume levels"
            quality_score += 2
        elif mean_volume > 0.2:
            volume_quality = "Audio might be too loud"
            quality_score += 1
        else:
            volume_quality = "Audio might be too quiet"

        # Speech detection
        if has_speech:
            speech_quality = "Speech detected ✅"
            quality_score += 2
        else:
            speech_quality = "Speech may not be clear ⚠️"

        # Silence assessment (some silence is normal)
        if silence_percentage < 40:
            silence_quality = "Good speech-to-silence ratio"
            quality_score += 1
        else:
            silence_quality = "Too much silence detected"

        # Determine overall quality
        if quality_score >= 4:
            quality = "Excellent"
        elif quality_score >= 2:
            quality = "Good"
        elif quality_score >= 1:
            quality = "Fair"
        else:
            quality = "Poor"

        return {
            "duration": f"{duration:.1f} seconds",
            "quality": quality,
            "volume_quality": volume_quality,
            "speech_quality": speech_quality,
            "silence_quality": silence_quality,
            "message": get_audio_recommendation(quality, has_speech, mean_volume, silence_percentage)
        }

    except Exception as e:
        return {
            "quality": "Error",
            "message": f"Could not analyze audio: {str(e)}"
        }
def get_video_recommendation(quality, face_detected, width, height):
    """Get recommendations based on video quality"""
    if not face_detected:
        return "⚠️ No face detected. For best results, use a video with a clear, well-lit face looking toward the camera."

    if quality == "Poor":
        return "⚠️ Low quality video. Consider using a higher resolution video with better lighting and a clearly visible face."

    if width < 640 or height < 480:
        return "⚠️ Video resolution is low. For better results, use a video with at least 480p resolution."

    if quality == "Excellent":
        return "✅ Great video quality! This should work well for lip syncing."

    return "✅ Video quality is acceptable for lip syncing."

def get_audio_recommendation(quality, has_speech, volume, silence_percentage):
    """Get recommendations based on audio quality"""
    if not has_speech:
        return "⚠️ Speech may not be clearly detected. For best results, use audio with clear speech."

    if quality == "Poor":
        return "⚠️ Low quality audio. Consider using clearer audio with consistent volume levels."

    if volume < 0.01:
        return "⚠️ Audio volume is very low. This may result in poor lip sync."

    if volume > 0.3:
        return "⚠️ Audio volume is very high. This may cause distortion in lip sync."

    if silence_percentage > 50:
        return "⚠️ Audio contains a lot of silence. Lip sync will only work during speech sections."

    if quality == "Excellent":
        return "✅ Great audio quality! This should work well for lip syncing."

    return "✅ Audio quality is acceptable for lip syncing."
#@st.cache_data is used to only load the model once
#@st.cache_data
@st.cache_resource
def load_model(path):
    st.write("Please wait for the model to be loaded or it will cause an error")
    wav2lip_checkpoints_url = "https://drive.google.com/drive/folders/1Sy5SHRmI3zgg2RJaOttNsN3iJS9VVkbg?usp=sharing"
    if not os.path.exists(path):
        gdown.download_folder(wav2lip_checkpoints_url, quiet=True, use_cookies=False)
        st.write("Please wait")
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))

    # Optimize model loading for M2 Max
    if device == 'mps':
        # Clear cache before loading model
        torch.mps.empty_cache()

    # Load model with device mapping
    checkpoint = torch.load(path, map_location=torch.device(device))
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)

    # Set model to evaluation mode and optimize for inference
    model.eval()
    if device == 'mps':
        # Attempt to optimize the model for inference
        try:
            # Use torch's inference mode for optimized inference
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_set_profiling_mode(False)
            print("Applied M2 Max optimizations")
        except:
            print("Could not apply all M2 Max optimizations")

    st.write(f"Model loaded successfully on {device} with optimized settings for M2 Max!")
    return model
@st.cache_resource
def load_avatar_videos_for_slow_animation(path):
    if not os.path.exists(path):
        try:
            os.makedirs(path, exist_ok=True)
            print(f"Created directory: {path}")

            avatar_videos_url = "https://drive.google.com/drive/folders/1h9pkU5wenrS2vmKqXBfFmrg-1hYw5s4q?usp=sharing"
            print(f"Downloading avatar videos from: {avatar_videos_url}")
            gdown.download_folder(avatar_videos_url, quiet=False, use_cookies=False)
            print(f"Avatar videos downloaded successfully to: {path}")
        except Exception as e:
            print(f"Error downloading avatar videos: {str(e)}")
            # Create default empty videos if download fails
            for avatar_file in ["avatar1.mp4", "avatar2.mp4", "avatar3.mp4"]:
                video_path = os.path.join(path, avatar_file)
                if not os.path.exists(video_path):
                    print(f"Creating empty video file: {video_path}")
                    # Get the matching image
                    img_key = f"avatars_images/{os.path.splitext(avatar_file)[0]}" + (".jpg" if avatar_file != "avatar3.mp4" else ".png")
                    try:
                        # Create a video from the image
                        img = cv2.imread(img_key)
                        if img is not None:
                            # Create a short 5-second video from the image
                            print(f"Creating video from image: {img_key}")
                            height, width = img.shape[:2]
                            output_video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))
                            for _ in range(150):  # 5 seconds at 30 fps
                                output_video.write(img)
                            output_video.release()
                        else:
                            print(f"Could not read image: {img_key}")
                    except Exception as e:
                        print(f"Error creating video from image: {str(e)}")
    else:
        print(f"Avatar videos directory already exists: {path}")
        # Check if files exist in the directory
        files = os.listdir(path)
        if not files:
            print(f"No files found in {path}, directory exists but is empty")
        else:
            print(f"Found {len(files)} files in {path}: {', '.join(files)}")
image_video_map = {
    "avatars_images/avatar1.jpg": "avatars_videos/avatar1.mp4",
    "avatars_images/avatar2.jpg": "avatars_videos/avatar2.mp4",
    "avatars_images/avatar3.png": "avatars_videos/avatar3.mp4"
}
def streamlit_look():
    """
    Modest front-end code:)
    """
    data = {}
    st.title("Welcome to AI Lip Sync :)")

    # Add a brief app description
    st.markdown("""
    This app uses AI to synchronize a person's lip movements with any audio file.
    You can choose from built-in avatars or upload your own image/video, then provide audio
    to create realistic lip-synced videos. Powered by Wav2Lip and optimized for Apple Silicon.
    """)

    # Add a guidelines section with an expander for best practices
    with st.expander("📋 Guidelines & Best Practices (Click to expand)", expanded=False):
        st.markdown("""
        ### Guidelines for Best Results

        #### Audio and Video Length
        - Audio and video don't need to be exactly the same length
        - If audio is shorter than video: Only the matching portion will be lip-synced
        - If audio is longer than video: Audio will be trimmed to match video length

        #### Face Quality
        - Clear, well-lit frontal views of faces work best
        - Faces should take up a reasonable portion of the frame
        - Avoid extreme angles, heavy shadows, or partial face views

        #### Audio Quality
        - Clear speech with minimal background noise works best
        - Consistent audio volume improves synchronization
        - Supported formats: WAV, MP3

        #### Video Quality
        - Stable videos with minimal camera movement
        - The person's mouth should be clearly visible
        - Videos at 480p or higher resolution work best
        - Very high-resolution videos will be automatically downscaled

        #### Processing Tips
        - Shorter videos process faster and often give better results
        - "Fast animation" only moves the lips (quicker processing)
        - "Slow animation" animates the full face (better quality, slower)
        - Your M2 Max GPU will significantly speed up processing
        """)

    # Option to choose between built-in avatars or upload a custom one
    avatar_source = st.radio("Choose avatar source:", ["Upload my own image/video", "Use built-in avatars"])

    if avatar_source == "Use built-in avatars":
        st.write("Please choose your avatar from the following options:")
        avatar_img = image_select("",
                                  ["avatars_images/avatar1.jpg",
                                   "avatars_images/avatar2.jpg",
                                   "avatars_images/avatar3.png",
                                   ])
        data["imge_path"] = avatar_img
    else:
        st.write("Upload an image or video file for your avatar:")
        uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"], key="avatar_uploader")

        if uploaded_file is not None:
            # Save the uploaded file
            file_path = os.path.join("uploads", uploaded_file.name)
            os.makedirs("uploads", exist_ok=True)

            with open(file_path, "wb") as f:
                f.write(uploaded_file.getvalue())

            # Set the file path as image path
            data["imge_path"] = file_path
            st.success(f"File uploaded successfully: {uploaded_file.name}")

            # Preview the uploaded image/video
            if uploaded_file.name.endswith(('.jpg', '.jpeg', '.png')):
                st.image(file_path, caption="Uploaded Image")
            elif uploaded_file.name.endswith('.mp4'):
                st.video(file_path)

                # Analyze video quality for MP4 files
                with st.spinner("Analyzing video quality..."):
                    video_analysis = analyze_video_quality(file_path)

                # Display video quality analysis in a nice box
                with st.expander("📊 Video Quality Analysis", expanded=True):
                    col1, col2 = st.columns(2)

                    with col1:
                        st.markdown(f"**Resolution:** {video_analysis['resolution']}")
                        st.markdown(f"**FPS:** {video_analysis['fps']}")
                        st.markdown(f"**Duration:** {video_analysis['duration']}")

                    with col2:
                        quality_color = {
                            "Excellent": "green",
                            "Good": "lightgreen",
                            "Fair": "orange",
                            "Poor": "red",
                            "Error": "red"
                        }.get(video_analysis['quality'], "gray")

                        st.markdown(f"**Quality:** <span style='color:{quality_color};font-weight:bold'>{video_analysis['quality']}</span>", unsafe_allow_html=True)
                        st.markdown(f"**Face Detection:** {'✅ Detected' if video_analysis.get('face_detected', False) else '❌ Not detected'}")

                    # Display the recommendation
                    st.info(video_analysis['message'])
    # Option to choose between mic recording or upload audio file
    audio_source = st.radio("Choose audio source:", ["Upload audio file", "Record with microphone"])

    if audio_source == "Record with microphone":
        audio = mic_recorder(
            start_prompt="Start recording",
            stop_prompt="Stop recording",
            just_once=False,
            use_container_width=False,
            callback=None,
            args=(),
            kwargs={},
            key=None)

        if audio:
            st.audio(audio["bytes"])
            data["audio"] = audio["bytes"]
    else:
        st.write("Upload an audio file:")
        uploaded_audio = st.file_uploader("Choose an audio file", type=["wav", "mp3"], key="audio_uploader")

        if uploaded_audio is not None:
            # Save the uploaded audio file
            audio_path = os.path.join("uploads", uploaded_audio.name)
            os.makedirs("uploads", exist_ok=True)

            with open(audio_path, "wb") as f:
                f.write(uploaded_audio.getvalue())

            # Preview the uploaded audio
            st.audio(audio_path)

            # Read the file into bytes for consistency with microphone recording
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()

            data["audio"] = audio_bytes
            st.success(f"Audio file uploaded successfully: {uploaded_audio.name}")

            # Analyze audio quality
            with st.spinner("Analyzing audio quality..."):
                audio_analysis = analyze_audio_quality(audio_path)

            # Display audio quality analysis in a nice box
            with st.expander("🎵 Audio Quality Analysis", expanded=True):
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown(f"**Duration:** {audio_analysis['duration']}")
                    st.markdown(f"**Volume:** {audio_analysis['volume_quality']}")

                with col2:
                    quality_color = {
                        "Excellent": "green",
                        "Good": "lightgreen",
                        "Fair": "orange",
                        "Poor": "red",
                        "Error": "red"
                    }.get(audio_analysis['quality'], "gray")

                    st.markdown(f"**Quality:** <span style='color:{quality_color};font-weight:bold'>{audio_analysis['quality']}</span>", unsafe_allow_html=True)
                    st.markdown(f"**Speech:** {audio_analysis['speech_quality']}")

                # Display the recommendation
                st.info(audio_analysis['message'])

    return data
def main():
    # Initialize session state to track processing status
    if 'processed' not in st.session_state:
        st.session_state.processed = False

    data = streamlit_look()

    # Add debug information
    st.write("Debug info:")
    if "imge_path" in data:
        st.write(f"Image/Video path: {data['imge_path']}")
    else:
        st.write("No image/video selected yet")

    if "audio" in data:
        st.write("Audio file selected ✓")
    else:
        st.write("No audio selected yet")

    # Only proceed if we have both image/video and audio data
    if "imge_path" in data and "audio" in data:
        st.write("This app will automatically save your audio when you click animate.")
        save_record = st.button("save record manually")
        st.write("With fast animation only the lips of the avatar will move, and it will take probably less than a minute for a record of about 30 seconds, but with slow animation choice, the full face of the avatar will move and it will take about 30 minutes for a record of about 30 seconds to get ready.")
        model = load_model("wav2lip_checkpoints/wav2lip_gan.pth")

        # Check for duration mismatches between video and audio
        if data["imge_path"].endswith('.mp4'):
            # Save audio to temp file for analysis
            if not os.path.exists('record.wav'):
                with open('record.wav', mode='wb') as f:
                    f.write(data["audio"])

            # Get durations
            video_duration = get_video_duration(data["imge_path"])
            audio_duration = get_audio_duration('record.wav')

            # Check for significant duration mismatch (more than 2 seconds difference)
            if abs(video_duration - audio_duration) > 2:
                st.warning(f"⚠️ Duration mismatch detected: Video is {video_duration:.1f}s and Audio is {audio_duration:.1f}s")

                # Create a tab for handling duration mismatches
                with st.expander("Duration Mismatch Options (Click to expand)", expanded=True):
                    st.info("The video and audio have different durations. Choose an option below:")

                    if video_duration > audio_duration:
                        if st.button("Trim Video to Match Audio Duration"):
                            # Update duration values to match
                            output_path = 'uploads/trimmed_input_video.mp4'
                            with st.spinner(f"Trimming video from {video_duration:.1f}s to {audio_duration:.1f}s..."):
                                success = trim_video(data["imge_path"], output_path, 0, audio_duration)

                            if success:
                                st.success("Video trimmed to match audio duration!")
                                # Update the image path to use the trimmed video
                                data["imge_path"] = output_path
                                st.video(output_path)
                    else:  # audio_duration > video_duration
                        if st.button("Trim Audio to Match Video Duration"):
                            # Update duration values to match
                            output_path = 'uploads/trimmed_input_audio.wav'
                            with st.spinner(f"Trimming audio from {audio_duration:.1f}s to {video_duration:.1f}s..."):
                                success = trim_audio('record.wav', output_path, 0, video_duration)

                            if success:
                                st.success("Audio trimmed to match video duration!")
                                # Update the audio data with the trimmed audio
                                with open(output_path, "rb") as f:
                                    data["audio"] = f.read()
                                # Save the trimmed audio as record.wav
                                with open('record.wav', mode='wb') as f:
                                    f.write(data["audio"])
                                st.audio(output_path)
        # Animation buttons
        fast_animate = st.button("fast animate")
        slower_animate = st.button("slower animate")

        # Function to save the audio record
        def save_audio_record():
            if os.path.exists('record.wav'):
                os.remove('record.wav')
            with open('record.wav', mode='wb') as f:
                f.write(data["audio"])
            st.write("Audio record saved!")

        if save_record:
            save_audio_record()

        # Show previously generated results if they exist and we're not generating new ones
        if os.path.exists('wav2lip/results/result_voice.mp4') and st.session_state.processed and not (fast_animate or slower_animate):
            st.video('wav2lip/results/result_voice.mp4')
            display_trim_options('wav2lip/results/result_voice.mp4')

        if fast_animate:
            # Automatically save the record before animation
            save_audio_record()

            progress_placeholder = st.empty()
            status_placeholder = st.empty()

            progress_bar = progress_placeholder.progress(0, text="Processing: 0% complete")
            status_placeholder.info("Preparing to process...")

            # Call the inference function inside a try block with progress updates at key points
            try:
                # Initialize a progress tracker
                progress_steps = [
                    (0, "Starting processing..."),
                    (15, "Step 1/4: Loading and analyzing video frames"),
                    (30, "Step 2/4: Performing face detection (this may take a while for long videos)"),
                    (60, "Step 3/4: Generating lip-synced frames"),
                    (80, "Step 4/4: Creating final video with audio"),
                    (100, "Processing complete!")
                ]
                current_step = 0

                # Redirect stdout to capture progress information
                import io
                sys.stdout = io.StringIO()

                # Update progress for the initial step
                progress, message = progress_steps[current_step]
                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                status_placeholder.info(message)
                current_step += 1

                # Run the inference in a background thread
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    # Start the inference process
                    future = executor.submit(inference.main, data["imge_path"], "record.wav", model)

                    # Monitor the output for progress indicators
                    while not future.done():
                        captured_output = sys.stdout.getvalue()

                        # Check for progress indicators and update UI
                        if current_step < len(progress_steps):
                            # Check for stage 1 completion: frames read
                            if current_step == 1 and "Number of frames available for inference" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1
                            # Check for stage 2 completion: face detection
                            elif current_step == 2 and "Face detection completed successfully" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1
                            # Check for stage 3 completion: ffmpeg started
                            elif current_step == 3 and "ffmpeg" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1

                        # Sleep to avoid excessive CPU usage
                        time.sleep(0.5)

                    try:
                        # Get the result or propagate exceptions
                        future.result()

                        # Show completion
                        progress, message = progress_steps[-1]
                        progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                        status_placeholder.success("Lip sync complete! Your video is ready.")
                    except Exception as e:
                        raise e

                # Restore stdout
                sys.stdout = sys.__stdout__

                if os.path.exists('wav2lip/results/result_voice.mp4'):
                    st.video('wav2lip/results/result_voice.mp4')
                    display_trim_options('wav2lip/results/result_voice.mp4')
                    # Set processed flag to True after successful processing
                    st.session_state.processed = True

            except Exception as e:
                # Restore stdout in case of error
                sys.stdout = sys.__stdout__

                progress_placeholder.empty()
                status_placeholder.error(f"Error during processing: {str(e)}")
                st.error("Failed to generate video. Please try again or use a different image/audio.")
        if slower_animate:
            # Automatically save the record before animation
            save_audio_record()

            progress_placeholder = st.empty()
            status_placeholder = st.empty()

            progress_bar = progress_placeholder.progress(0, text="Processing: 0% complete")
            status_placeholder.info("Preparing to process...")

            # Derive the video path from the selected avatar
            if data["imge_path"].endswith('.mp4'):
                video_path = data["imge_path"]
            else:
                # Get the avatar video path for the selected avatar
                avatar_list = load_avatar_videos_for_slow_animation("./data/avatars/samples")
                video_path = avatar_list[available_avatars_for_slow.index(avatar_choice)]

            try:
                # Initialize a progress tracker
                progress_steps = [
                    (0, "Starting processing..."),
                    (15, "Step 1/4: Loading and analyzing video frames"),
                    (30, "Step 2/4: Performing face detection (this may take a while for long videos)"),
                    (60, "Step 3/4: Generating lip-synced frames with full-face animation"),
                    (80, "Step 4/4: Creating final video with audio"),
                    (100, "Processing complete!")
                ]
                current_step = 0

                # Redirect stdout to capture progress information
                import io
                sys.stdout = io.StringIO()

                # Update progress for the initial step
                progress, message = progress_steps[current_step]
                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                status_placeholder.info(message)
                current_step += 1

                # Run the inference in a background thread
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    # Start the inference process
                    future = executor.submit(inference.main, video_path, "record.wav", model, slow_mode=True)

                    # Monitor the output for progress indicators
                    while not future.done():
                        captured_output = sys.stdout.getvalue()

                        # Check for progress indicators and update UI
                        if current_step < len(progress_steps):
                            # Check for stage 1 completion: frames read
                            if current_step == 1 and "Number of frames available for inference" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1
                            # Check for stage 2 completion: face detection
                            elif current_step == 2 and "Face detection completed successfully" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1
                            # Check for stage 3 completion: ffmpeg started
                            elif current_step == 3 and "ffmpeg" in captured_output:
                                progress, message = progress_steps[current_step]
                                progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                                status_placeholder.info(message)
                                current_step += 1

                        # Sleep to avoid excessive CPU usage
                        time.sleep(0.5)

                    try:
                        # Get the result or propagate exceptions
                        future.result()

                        # Show completion
                        progress, message = progress_steps[-1]
                        progress_bar.progress(progress, text=f"Processing: {progress}% complete")
                        status_placeholder.success("Lip sync complete! Your video is ready.")
                    except Exception as e:
                        raise e

                # Restore stdout
                sys.stdout = sys.__stdout__

                if os.path.exists('wav2lip/results/result_voice.mp4'):
                    st.video('wav2lip/results/result_voice.mp4')
                    display_trim_options('wav2lip/results/result_voice.mp4')
                    # Set processed flag to True after successful processing
                    st.session_state.processed = True
            except Exception as e:
                # Restore stdout in case of error
                sys.stdout = sys.__stdout__

                progress_placeholder.empty()
                status_placeholder.error(f"Error during processing: {str(e)}")
                st.error("Failed to generate video. Please try again or use a different video/audio.")
    else:
        if "imge_path" not in data and "audio" not in data:
            st.warning("Please upload both an image/video AND provide audio to continue.")
        elif "imge_path" not in data:
            st.warning("Please select or upload an image/video to continue.")
        else:
            st.warning("Please provide audio to continue.")
# Function to display trim options and handle video trimming
def display_trim_options(video_path):
    """Display options to trim the video and handle the trimming process"""
    st.subheader("Video Processing Options")

    # Check if the video exists first
    if not os.path.exists(video_path):
        st.error(f"Video file not found at {video_path}. Try running the animation again.")
        return

    # Add tabs for different operations
    download_tab, trim_tab = st.tabs(["Download Original", "Trim Video"])

    with download_tab:
        st.write("Download the original generated video:")
        try:
            st.video(video_path)
            st.download_button(
                label="Download Original Video",
                data=open(video_path, 'rb').read(),
                file_name="original_lip_sync_video.mp4",
                mime="video/mp4"
            )
        except Exception as e:
            st.error(f"Error loading video: {str(e)}")

    with trim_tab:
        st.write("You can trim the generated video to remove unwanted parts from the beginning or end.")

        duration = get_video_duration(video_path)
        if duration <= 0:
            st.error("Could not determine video duration")
            return

        # Display video duration
        st.write(f"Video duration: {duration:.2f} seconds")

        # Create a slider for selecting start and end times
        col1, col2 = st.columns(2)

        with col1:
            start_time = st.slider("Start time (seconds)",
                                   min_value=0.0,
                                   max_value=float(duration),
                                   value=0.0,
                                   step=0.1)
            st.write(f"Start at: {start_time:.1f}s")

        with col2:
            end_time = st.slider("End time (seconds)",
                                 min_value=0.0,
                                 max_value=float(duration),
                                 value=float(duration),
                                 step=0.1)
            st.write(f"End at: {end_time:.1f}s")

        # Display trim duration
        trim_duration = end_time - start_time
        st.info(f"Trimmed video duration will be: {trim_duration:.1f} seconds")

        # Validate the selected range
        if start_time >= end_time:
            st.error("Start time must be less than end time")
            return

        # Button to perform trimming
        if st.button("Trim Video"):
            # Generate output path
            output_path = 'wav2lip/results/trimmed_video.mp4'

            # Show progress
            with st.spinner("Trimming video..."):
                success = trim_video(video_path, output_path, start_time, end_time)

            if success:
                st.success("Video trimmed successfully!")
                try:
                    st.video(output_path)

                    # Add download button for trimmed video
                    st.download_button(
                        label="Download Trimmed Video",
                        data=open(output_path, 'rb').read(),
                        file_name="trimmed_lip_sync_video.mp4",
                        mime="video/mp4"
                    )
                except Exception as e:
                    st.error(f"Error displaying trimmed video: {str(e)}")
            else:
                st.error("Failed to trim video. Try again with different timing parameters.")
904 |
+
# Function to trim video using ffmpeg
|
905 |
+
def trim_video(input_path, output_path, start_time, end_time):
|
906 |
+
"""
|
907 |
+
Trim a video using ffmpeg from start_time to end_time.
|
908 |
+
|
909 |
+
Args:
|
910 |
+
input_path: Path to the input video
|
911 |
+
output_path: Path to save the trimmed video
|
912 |
+
start_time: Start time in seconds
|
913 |
+
end_time: End time in seconds
|
914 |
+
|
915 |
+
Returns:
|
916 |
+
bool: True if successful, False otherwise
|
917 |
+
"""
|
918 |
+
try:
|
919 |
+
# Check if input file exists
|
920 |
+
if not os.path.exists(input_path):
|
921 |
+
st.error(f"Input video not found: {input_path}")
|
922 |
+
return False
|
923 |
+
|
924 |
+
# Format the command - use -ss before -i for faster seeking
|
925 |
+
# Add quotes around file paths to handle spaces and special characters
|
926 |
+
command = f'ffmpeg -y -ss {start_time} -i "{input_path}" -to {end_time} -c:v copy -c:a copy "{output_path}"'
|
927 |
+
|
928 |
+
# Use subprocess.run for better error handling
|
929 |
+
result = subprocess.run(
|
930 |
+
command,
|
931 |
+
shell=True,
|
932 |
+
stdout=subprocess.PIPE,
|
933 |
+
stderr=subprocess.PIPE,
|
934 |
+
text=True
|
935 |
+
)
|
936 |
+
|
937 |
+
if result.returncode != 0:
|
938 |
+
st.error(f"FFMPEG error: {result.stderr}")
|
939 |
+
return False
|
940 |
+
|
941 |
+
# Verify the output file exists and has a size greater than 0
|
942 |
+
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
943 |
+
return True
|
944 |
+
else:
|
945 |
+
st.error("Output file was not created correctly")
|
946 |
+
return False
|
947 |
+
|
948 |
+
except Exception as e:
|
949 |
+
st.error(f"Error trimming video: {str(e)}")
|
950 |
+
return False
|
951 |
+
|
952 |
+
# Function to get video duration
|
953 |
+
def get_video_duration(video_path):
|
954 |
+
"""Get the duration of a video file in seconds"""
|
955 |
+
try:
|
956 |
+
video = cv2.VideoCapture(video_path)
|
957 |
+
fps = video.get(cv2.CAP_PROP_FPS)
|
958 |
+
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
959 |
+
video.release()
|
960 |
+
|
961 |
+
duration = frame_count / fps if fps > 0 else 0
|
962 |
+
return duration
|
963 |
+
except Exception as e:
|
964 |
+
st.error(f"Error getting video duration: {str(e)}")
|
965 |
+
return 0
|
966 |
+
|
967 |
+
# Function to get audio duration
|
968 |
+
def get_audio_duration(audio_path):
|
969 |
+
"""Get the duration of an audio file in seconds"""
|
970 |
+
try:
|
971 |
+
y, sr = librosa.load(audio_path, sr=None)
|
972 |
+
duration = librosa.get_duration(y=y, sr=sr)
|
973 |
+
return duration
|
974 |
+
except Exception as e:
|
975 |
+
st.error(f"Error getting audio duration: {str(e)}")
|
976 |
+
return 0
|
977 |
+
|
978 |
+
# Function to trim audio file
|
979 |
+
def trim_audio(input_path, output_path, start_time, end_time):
|
980 |
+
"""Trim an audio file to the specified start and end times"""
|
981 |
+
try:
|
982 |
+
# Command to trim audio using ffmpeg
|
983 |
+
command = f'ffmpeg -y -i "{input_path}" -ss {start_time} -to {end_time} -c copy "{output_path}"'
|
984 |
+
|
985 |
+
# Execute the command
|
986 |
+
subprocess.call(command, shell=True)
|
987 |
+
|
988 |
+
# Check if output file exists
|
989 |
+
if os.path.exists(output_path):
|
990 |
+
return True
|
991 |
+
else:
|
992 |
+
st.error("Output audio file was not created correctly")
|
993 |
+
return False
|
994 |
+
|
995 |
+
except Exception as e:
|
996 |
+
st.error(f"Error trimming audio: {str(e)}")
|
997 |
+
return False
|
998 |
+
|
999 |
+
if __name__ == "__main__":
|
1000 |
+
main()
|
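One caveat about `trim_video` and `trim_audio` above: the ffmpeg command is assembled as one string and executed with `shell=True`, so the quoting has to be exactly right, and with stream copy (`-c copy`) the cut can only start on a keyframe, so the trim point may not be frame-accurate. Below is a minimal alternative sketch, not part of app.py; the `trim_video_args` name is hypothetical. It passes an argument list (no shell involved) and an explicit clip duration via `-t`:

```python
import subprocess

def trim_video_args(input_path: str, output_path: str, start_time: float, end_time: float) -> bool:
    """Hypothetical variant of trim_video: argument list instead of a shell string,
    and an explicit duration, which sidesteps any ambiguity about how -to is
    interpreted when -ss is placed before -i."""
    cmd = [
        "ffmpeg", "-y",
        "-ss", str(start_time),            # seek on the input side for fast seeking
        "-i", input_path,
        "-t", str(end_time - start_time),  # clip duration instead of an absolute end time
        "-c:v", "copy", "-c:a", "copy",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode == 0
```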
avatars_images/avatar1.jpg
ADDED
avatars_images/avatar2.jpg
ADDED
avatars_images/avatar3.png
ADDED
packages.txt
ADDED
@@ -0,0 +1,3 @@
python3-opencv
libgl1-mesa-dev
ffmpeg
requirements.txt
ADDED
@@ -0,0 +1,12 @@
numpy==1.26.3
scipy==1.12.0
iou==0.1.0
librosa==0.10.1
opencv_contrib_python==4.9.0.80
streamlit==1.31.0
streamlit_image_select==0.6.0
streamlit_mic_recorder==0.0.4
torch==2.2.1
tqdm==4.64.1
gdown
matplotlib==3.10.1
wav2lip/audio.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
# import tensorflow as tf
|
5 |
+
from scipy import signal
|
6 |
+
from scipy.io import wavfile
|
7 |
+
from .hparams import hparams as hp
|
8 |
+
|
9 |
+
def load_wav(path, sr):
|
10 |
+
return librosa.core.load(path, sr=sr)[0]
|
11 |
+
|
12 |
+
def save_wav(wav, path, sr):
|
13 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
14 |
+
#proposed by @dsmiller
|
15 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
16 |
+
|
17 |
+
def save_wavenet_wav(wav, path, sr):
|
18 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
19 |
+
|
20 |
+
def preemphasis(wav, k, preemphasize=True):
|
21 |
+
if preemphasize:
|
22 |
+
return signal.lfilter([1, -k], [1], wav)
|
23 |
+
return wav
|
24 |
+
|
25 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
26 |
+
if inv_preemphasize:
|
27 |
+
return signal.lfilter([1], [1, -k], wav)
|
28 |
+
return wav
|
29 |
+
|
30 |
+
def get_hop_size():
|
31 |
+
hop_size = hp.hop_size
|
32 |
+
if hop_size is None:
|
33 |
+
assert hp.frame_shift_ms is not None
|
34 |
+
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
35 |
+
return hop_size
|
36 |
+
|
37 |
+
def linearspectrogram(wav):
|
38 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
39 |
+
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
40 |
+
|
41 |
+
if hp.signal_normalization:
|
42 |
+
return _normalize(S)
|
43 |
+
return S
|
44 |
+
|
45 |
+
def melspectrogram(wav):
|
46 |
+
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
47 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
48 |
+
|
49 |
+
if hp.signal_normalization:
|
50 |
+
return _normalize(S)
|
51 |
+
return S
|
52 |
+
|
53 |
+
def _lws_processor():
|
54 |
+
import lws
|
55 |
+
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
56 |
+
|
57 |
+
def _stft(y):
|
58 |
+
if hp.use_lws:
|
59 |
+
return _lws_processor(hp).stft(y).T
|
60 |
+
else:
|
61 |
+
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
62 |
+
|
63 |
+
##########################################################
|
64 |
+
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
65 |
+
def num_frames(length, fsize, fshift):
|
66 |
+
"""Compute number of time frames of spectrogram
|
67 |
+
"""
|
68 |
+
pad = (fsize - fshift)
|
69 |
+
if length % fshift == 0:
|
70 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
71 |
+
else:
|
72 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
73 |
+
return M
|
74 |
+
|
75 |
+
|
76 |
+
def pad_lr(x, fsize, fshift):
|
77 |
+
"""Compute left and right padding
|
78 |
+
"""
|
79 |
+
M = num_frames(len(x), fsize, fshift)
|
80 |
+
pad = (fsize - fshift)
|
81 |
+
T = len(x) + 2 * pad
|
82 |
+
r = (M - 1) * fshift + fsize - T
|
83 |
+
return pad, pad + r
|
84 |
+
##########################################################
|
85 |
+
#Librosa correct padding
|
86 |
+
def librosa_pad_lr(x, fsize, fshift):
|
87 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
88 |
+
|
89 |
+
# Conversions
|
90 |
+
_mel_basis = None
|
91 |
+
|
92 |
+
def _linear_to_mel(spectogram):
|
93 |
+
global _mel_basis
|
94 |
+
if _mel_basis is None:
|
95 |
+
_mel_basis = _build_mel_basis()
|
96 |
+
return np.dot(_mel_basis, spectogram)
|
97 |
+
|
98 |
+
|
99 |
+
def _build_mel_basis():
|
100 |
+
assert hp.fmax <= hp.sample_rate // 2
|
101 |
+
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
102 |
+
fmin=hp.fmin, fmax=hp.fmax)
|
103 |
+
def _amp_to_db(x):
|
104 |
+
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
105 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
106 |
+
|
107 |
+
def _db_to_amp(x):
|
108 |
+
return np.power(10.0, (x) * 0.05)
|
109 |
+
|
110 |
+
def _normalize(S):
|
111 |
+
if hp.allow_clipping_in_normalization:
|
112 |
+
if hp.symmetric_mels:
|
113 |
+
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
114 |
+
-hp.max_abs_value, hp.max_abs_value)
|
115 |
+
else:
|
116 |
+
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
117 |
+
|
118 |
+
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
119 |
+
if hp.symmetric_mels:
|
120 |
+
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
121 |
+
else:
|
122 |
+
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
123 |
+
|
124 |
+
def _denormalize(D):
|
125 |
+
if hp.allow_clipping_in_normalization:
|
126 |
+
if hp.symmetric_mels:
|
127 |
+
return (((np.clip(D, -hp.max_abs_value,
|
128 |
+
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
129 |
+
+ hp.min_level_db)
|
130 |
+
else:
|
131 |
+
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
132 |
+
|
133 |
+
if hp.symmetric_mels:
|
134 |
+
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
135 |
+
else:
|
136 |
+
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
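The `melspectrogram` helper above produces the audio features the rest of the pipeline consumes. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` so `wav2lip` imports as a package, and with `sample.wav` as a placeholder path (the hyperparameters come from `wav2lip/hparams.py`):

```python
from wav2lip import audio
from wav2lip.hparams import hparams as hp

# Load the waveform at the sample rate the model expects, then compute its mel spectrogram.
wav = audio.load_wav("sample.wav", hp.sample_rate)
mel = audio.melspectrogram(wav)
print(mel.shape)  # (num_mels, num_frames)
```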
wav2lip/face_detection/README.md
ADDED
@@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
wav2lip/face_detection/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

__author__ = """Adrian Bulat"""
__email__ = '[email protected]'
__version__ = '1.0.1'

from .api import FaceAlignment, LandmarksType, NetworkSize
wav2lip/face_detection/api.py
ADDED
@@ -0,0 +1,119 @@
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch.utils.model_zoo import load_url
|
5 |
+
from enum import Enum
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
from .detection import sfd
|
9 |
+
try:
|
10 |
+
import urllib.request as request_file
|
11 |
+
except BaseException:
|
12 |
+
import urllib as request_file
|
13 |
+
|
14 |
+
from .models import FAN, ResNetDepth
|
15 |
+
from .utils import *
|
16 |
+
|
17 |
+
|
18 |
+
class LandmarksType(Enum):
|
19 |
+
"""Enum class defining the type of landmarks to detect.
|
20 |
+
|
21 |
+
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
22 |
+
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
23 |
+
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
24 |
+
|
25 |
+
"""
|
26 |
+
_2D = 1
|
27 |
+
_2halfD = 2
|
28 |
+
_3D = 3
|
29 |
+
|
30 |
+
|
31 |
+
class NetworkSize(Enum):
|
32 |
+
# TINY = 1
|
33 |
+
# SMALL = 2
|
34 |
+
# MEDIUM = 3
|
35 |
+
LARGE = 4
|
36 |
+
|
37 |
+
def __new__(cls, value):
|
38 |
+
member = object.__new__(cls)
|
39 |
+
member._value_ = value
|
40 |
+
return member
|
41 |
+
|
42 |
+
def __int__(self):
|
43 |
+
return self.value
|
44 |
+
|
45 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
46 |
+
|
47 |
+
class FaceAlignment:
|
48 |
+
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
49 |
+
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
50 |
+
self.device = device
|
51 |
+
self.flip_input = flip_input
|
52 |
+
self.landmarks_type = landmarks_type
|
53 |
+
self.verbose = verbose
|
54 |
+
|
55 |
+
network_size = int(network_size)
|
56 |
+
|
57 |
+
if 'cuda' in device or 'mps' in device:
|
58 |
+
torch.backends.cudnn.benchmark = True
|
59 |
+
if 'mps' in device and verbose:
|
60 |
+
print("Using Apple Silicon GPU (MPS) for face detection.")
|
61 |
+
|
62 |
+
# Get the face detector
|
63 |
+
#face_detector_module = __import__('from .detection. import' + face_detector,
|
64 |
+
# globals(), locals(), [face_detector], 0)
|
65 |
+
#self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
66 |
+
try:
|
67 |
+
self.face_detector = sfd.FaceDetector(device=device, verbose=verbose)
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error initializing face detector: {e}")
|
70 |
+
print("Falling back to CPU for face detection.")
|
71 |
+
# If detection fails on GPU (MPS/CUDA), fall back to CPU
|
72 |
+
self.device = 'cpu'
|
73 |
+
self.face_detector = sfd.FaceDetector(device='cpu', verbose=verbose)
|
74 |
+
|
75 |
+
def get_detections_for_batch(self, images):
|
76 |
+
"""
|
77 |
+
Returns a list of bounding boxes for each image in the batch.
|
78 |
+
If no face is detected, returns None for that image.
|
79 |
+
"""
|
80 |
+
try:
|
81 |
+
# Convert to RGB for face detection
|
82 |
+
images = images.copy()
|
83 |
+
if images.shape[-1] == 3:
|
84 |
+
images = images[..., ::-1] # BGR to RGB
|
85 |
+
|
86 |
+
# Get face detections
|
87 |
+
detected_faces = self.face_detector.detect_from_batch(images)
|
88 |
+
|
89 |
+
results = []
|
90 |
+
for i, d in enumerate(detected_faces):
|
91 |
+
if len(d) == 0:
|
92 |
+
# No face detected
|
93 |
+
results.append(None)
|
94 |
+
continue
|
95 |
+
|
96 |
+
# Use the first (highest confidence) face
|
97 |
+
d = d[0]
|
98 |
+
# Ensure values are valid
|
99 |
+
d = np.clip(d, 0, None)
|
100 |
+
|
101 |
+
# Extract coordinates
|
102 |
+
try:
|
103 |
+
x1, y1, x2, y2 = map(int, d[:-1])
|
104 |
+
# Sanity check on coordinates
|
105 |
+
if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
|
106 |
+
print(f"Invalid face coordinates: {(x1, y1, x2, y2)}")
|
107 |
+
results.append(None)
|
108 |
+
else:
|
109 |
+
results.append((x1, y1, x2, y2))
|
110 |
+
except Exception as e:
|
111 |
+
print(f"Error processing detection: {str(e)}")
|
112 |
+
results.append(None)
|
113 |
+
|
114 |
+
return results
|
115 |
+
|
116 |
+
except Exception as e:
|
117 |
+
print(f"Error in batch face detection: {str(e)}")
|
118 |
+
# Return None for all images
|
119 |
+
return [None] * len(images)
|
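A minimal sketch of driving the batched API above, assuming the repository root is on `PYTHONPATH`; the frame paths are placeholders. `get_detections_for_batch` expects a stacked array of same-sized BGR frames and returns one `(x1, y1, x2, y2)` tuple, or `None`, per frame:

```python
import cv2
import numpy as np
from wav2lip import face_detection

# CPU is used here for portability; 'cuda' or 'mps' can be passed when available.
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, device='cpu')

frames = np.stack([cv2.imread('frame0.jpg'), cv2.imread('frame1.jpg')])  # BGR, identical shapes
boxes = detector.get_detections_for_batch(frames)
for i, box in enumerate(boxes):
    print(i, box)  # (x1, y1, x2, y2) or None when no face was found
```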
wav2lip/face_detection/detection/__init__.py
ADDED
@@ -0,0 +1 @@
from .core import FaceDetector
wav2lip/face_detection/detection/core.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
import logging
|
2 |
+
import glob
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
|
9 |
+
class FaceDetector(object):
|
10 |
+
"""An abstract class representing a face detector.
|
11 |
+
|
12 |
+
Any other face detection implementation must subclass it. All subclasses
|
13 |
+
must implement ``detect_from_image``, that return a list of detected
|
14 |
+
bounding boxes. Optionally, for speed considerations detect from path is
|
15 |
+
recommended.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, device, verbose):
|
19 |
+
self.device = device
|
20 |
+
self.verbose = verbose
|
21 |
+
|
22 |
+
if verbose:
|
23 |
+
if 'cpu' in device:
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
logger.warning("Detection running on CPU, this may be potentially slow.")
|
26 |
+
elif 'mps' in device:
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
logger.info("Detection running on Apple Silicon GPU (MPS).")
|
29 |
+
|
30 |
+
if 'cpu' not in device and 'cuda' not in device and 'mps' not in device:
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
if verbose:
|
33 |
+
logger.error("Expected values for device are: {cpu, cuda, mps} but got: %s", device)
|
34 |
+
raise ValueError(f"Invalid device type: {device}. Expected 'cpu', 'cuda', or 'mps'.")
|
35 |
+
|
36 |
+
def detect_from_image(self, tensor_or_path):
|
37 |
+
"""Detects faces in a given image.
|
38 |
+
|
39 |
+
This function detects the faces present in a provided BGR(usually)
|
40 |
+
image. The input can be either the image itself or the path to it.
|
41 |
+
|
42 |
+
Arguments:
|
43 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
44 |
+
to an image or the image itself.
|
45 |
+
|
46 |
+
Example::
|
47 |
+
|
48 |
+
>>> path_to_image = 'data/image_01.jpg'
|
49 |
+
... detected_faces = detect_from_image(path_to_image)
|
50 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
51 |
+
>>> image = cv2.imread(path_to_image)
|
52 |
+
... detected_faces = detect_from_image(image)
|
53 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
54 |
+
|
55 |
+
"""
|
56 |
+
raise NotImplementedError
|
57 |
+
|
58 |
+
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
59 |
+
"""Detects faces from all the images present in a given directory.
|
60 |
+
|
61 |
+
Arguments:
|
62 |
+
path {string} -- a string containing a path that points to the folder containing the images
|
63 |
+
|
64 |
+
Keyword Arguments:
|
65 |
+
extensions {list} -- list of string containing the extensions to be
|
66 |
+
consider in the following format: ``.extension_name`` (default:
|
67 |
+
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
68 |
+
folder recursively (default: {False}) show_progress_bar {bool} --
|
69 |
+
display a progressbar (default: {True})
|
70 |
+
|
71 |
+
Example:
|
72 |
+
>>> directory = 'data'
|
73 |
+
... detected_faces = detect_from_directory(directory)
|
74 |
+
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
75 |
+
|
76 |
+
"""
|
77 |
+
if self.verbose:
|
78 |
+
logger = logging.getLogger(__name__)
|
79 |
+
|
80 |
+
if len(extensions) == 0:
|
81 |
+
if self.verbose:
|
82 |
+
logger.error("Expected at list one extension, but none was received.")
|
83 |
+
raise ValueError
|
84 |
+
|
85 |
+
if self.verbose:
|
86 |
+
logger.info("Constructing the list of images.")
|
87 |
+
additional_pattern = '/**/*' if recursive else '/*'
|
88 |
+
files = []
|
89 |
+
for extension in extensions:
|
90 |
+
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
91 |
+
|
92 |
+
if self.verbose:
|
93 |
+
logger.info("Finished searching for images. %s images found", len(files))
|
94 |
+
logger.info("Preparing to run the detection.")
|
95 |
+
|
96 |
+
predictions = {}
|
97 |
+
for image_path in tqdm(files, disable=not show_progress_bar):
|
98 |
+
if self.verbose:
|
99 |
+
logger.info("Running the face detector on image: %s", image_path)
|
100 |
+
predictions[image_path] = self.detect_from_image(image_path)
|
101 |
+
|
102 |
+
if self.verbose:
|
103 |
+
logger.info("The detector was successfully run on all %s images", len(files))
|
104 |
+
|
105 |
+
return predictions
|
106 |
+
|
107 |
+
@property
|
108 |
+
def reference_scale(self):
|
109 |
+
raise NotImplementedError
|
110 |
+
|
111 |
+
@property
|
112 |
+
def reference_x_shift(self):
|
113 |
+
raise NotImplementedError
|
114 |
+
|
115 |
+
@property
|
116 |
+
def reference_y_shift(self):
|
117 |
+
raise NotImplementedError
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
121 |
+
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
122 |
+
|
123 |
+
Arguments:
|
124 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
125 |
+
"""
|
126 |
+
if isinstance(tensor_or_path, str):
|
127 |
+
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
128 |
+
elif torch.is_tensor(tensor_or_path):
|
129 |
+
# Call cpu in case its coming from cuda
|
130 |
+
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
131 |
+
elif isinstance(tensor_or_path, np.ndarray):
|
132 |
+
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
133 |
+
else:
|
134 |
+
raise TypeError
|
wav2lip/face_detection/detection/sfd/__init__.py
ADDED
@@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector
wav2lip/face_detection/detection/sfd/bbox.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import cv2
|
5 |
+
import random
|
6 |
+
import datetime
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import argparse
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
try:
|
14 |
+
from iou import IOU
|
15 |
+
except BaseException:
|
16 |
+
# IOU cython speedup 10x
|
17 |
+
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
18 |
+
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
19 |
+
sb = abs((bx2 - bx1) * (by2 - by1))
|
20 |
+
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
21 |
+
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
22 |
+
w = x2 - x1
|
23 |
+
h = y2 - y1
|
24 |
+
if w < 0 or h < 0:
|
25 |
+
return 0.0
|
26 |
+
else:
|
27 |
+
return 1.0 * w * h / (sa + sb - w * h)
|
28 |
+
|
29 |
+
|
30 |
+
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
31 |
+
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
32 |
+
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
33 |
+
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
34 |
+
return dx, dy, dw, dh
|
35 |
+
|
36 |
+
|
37 |
+
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
38 |
+
xc, yc = dx * aww + axc, dy * ahh + ayc
|
39 |
+
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
40 |
+
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
41 |
+
return x1, y1, x2, y2
|
42 |
+
|
43 |
+
|
44 |
+
def nms(dets, thresh):
|
45 |
+
if 0 == len(dets):
|
46 |
+
return []
|
47 |
+
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
48 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
49 |
+
order = scores.argsort()[::-1]
|
50 |
+
|
51 |
+
keep = []
|
52 |
+
while order.size > 0:
|
53 |
+
i = order[0]
|
54 |
+
keep.append(i)
|
55 |
+
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
56 |
+
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
57 |
+
|
58 |
+
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
59 |
+
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
60 |
+
|
61 |
+
inds = np.where(ovr <= thresh)[0]
|
62 |
+
order = order[inds + 1]
|
63 |
+
|
64 |
+
return keep
|
65 |
+
|
66 |
+
|
67 |
+
def encode(matched, priors, variances):
|
68 |
+
"""Encode the variances from the priorbox layers into the ground truth boxes
|
69 |
+
we have matched (based on jaccard overlap) with the prior boxes.
|
70 |
+
Args:
|
71 |
+
matched: (tensor) Coords of ground truth for each prior in point-form
|
72 |
+
Shape: [num_priors, 4].
|
73 |
+
priors: (tensor) Prior boxes in center-offset form
|
74 |
+
Shape: [num_priors,4].
|
75 |
+
variances: (list[float]) Variances of priorboxes
|
76 |
+
Return:
|
77 |
+
encoded boxes (tensor), Shape: [num_priors, 4]
|
78 |
+
"""
|
79 |
+
|
80 |
+
# dist b/t match center and prior's center
|
81 |
+
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
82 |
+
# encode variance
|
83 |
+
g_cxcy /= (variances[0] * priors[:, 2:])
|
84 |
+
# match wh / prior wh
|
85 |
+
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
86 |
+
g_wh = torch.log(g_wh) / variances[1]
|
87 |
+
# return target for smooth_l1_loss
|
88 |
+
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
89 |
+
|
90 |
+
|
91 |
+
def decode(loc, priors, variances):
|
92 |
+
"""Decode locations from predictions using priors to undo
|
93 |
+
the encoding we did for offset regression at train time.
|
94 |
+
Args:
|
95 |
+
loc (tensor): location predictions for loc layers,
|
96 |
+
Shape: [num_priors,4]
|
97 |
+
priors (tensor): Prior boxes in center-offset form.
|
98 |
+
Shape: [num_priors,4].
|
99 |
+
variances: (list[float]) Variances of priorboxes
|
100 |
+
Return:
|
101 |
+
decoded bounding box predictions
|
102 |
+
"""
|
103 |
+
|
104 |
+
boxes = torch.cat((
|
105 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
106 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
107 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
108 |
+
boxes[:, 2:] += boxes[:, :2]
|
109 |
+
return boxes
|
110 |
+
|
111 |
+
def batch_decode(loc, priors, variances):
|
112 |
+
"""Decode locations from predictions using priors to undo
|
113 |
+
the encoding we did for offset regression at train time.
|
114 |
+
Args:
|
115 |
+
loc (tensor): location predictions for loc layers,
|
116 |
+
Shape: [num_priors,4]
|
117 |
+
priors (tensor): Prior boxes in center-offset form.
|
118 |
+
Shape: [num_priors,4].
|
119 |
+
variances: (list[float]) Variances of priorboxes
|
120 |
+
Return:
|
121 |
+
decoded bounding box predictions
|
122 |
+
"""
|
123 |
+
|
124 |
+
boxes = torch.cat((
|
125 |
+
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
126 |
+
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
127 |
+
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
128 |
+
boxes[:, :, 2:] += boxes[:, :, :2]
|
129 |
+
return boxes
|
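The fallback `IOU` and the `nms` helper above are plain NumPy. A tiny self-contained sketch of `nms` on made-up detections, where each row is `[x1, y1, x2, y2, score]`:

```python
import numpy as np
from wav2lip.face_detection.detection.sfd.bbox import nms

dets = np.array([
    [10.0, 10.0, 110.0, 110.0, 0.9],
    [12.0, 12.0, 108.0, 108.0, 0.8],    # heavily overlaps the first box
    [200.0, 200.0, 260.0, 260.0, 0.7],  # a separate detection
])
keep = nms(dets, thresh=0.3)
print(keep)  # [0, 2]: the overlapping lower-score box is suppressed
```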
wav2lip/face_detection/detection/sfd/detect.py
ADDED
@@ -0,0 +1,112 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import cv2
|
7 |
+
import random
|
8 |
+
import datetime
|
9 |
+
import math
|
10 |
+
import argparse
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import scipy.io as sio
|
14 |
+
import zipfile
|
15 |
+
from .net_s3fd import s3fd
|
16 |
+
from .bbox import *
|
17 |
+
|
18 |
+
|
19 |
+
def detect(net, img, device):
|
20 |
+
img = img - np.array([104, 117, 123])
|
21 |
+
img = img.transpose(2, 0, 1)
|
22 |
+
img = img.reshape((1,) + img.shape)
|
23 |
+
|
24 |
+
if 'cuda' in device or 'mps' in device:
|
25 |
+
torch.backends.cudnn.benchmark = True
|
26 |
+
|
27 |
+
img = torch.from_numpy(img).float().to(device)
|
28 |
+
BB, CC, HH, WW = img.size()
|
29 |
+
with torch.no_grad():
|
30 |
+
olist = net(img)
|
31 |
+
|
32 |
+
bboxlist = []
|
33 |
+
for i in range(len(olist) // 2):
|
34 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
35 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
36 |
+
for i in range(len(olist) // 2):
|
37 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
38 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
39 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
40 |
+
anchor = stride * 4
|
41 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
42 |
+
for Iindex, hindex, windex in poss:
|
43 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
44 |
+
score = ocls[0, 1, hindex, windex]
|
45 |
+
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
46 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
47 |
+
variances = [0.1, 0.2]
|
48 |
+
box = decode(loc, priors, variances)
|
49 |
+
x1, y1, x2, y2 = box[0] * 1.0
|
50 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
51 |
+
bboxlist.append([x1, y1, x2, y2, score])
|
52 |
+
bboxlist = np.array(bboxlist)
|
53 |
+
if 0 == len(bboxlist):
|
54 |
+
bboxlist = np.zeros((1, 5))
|
55 |
+
|
56 |
+
return bboxlist
|
57 |
+
|
58 |
+
def batch_detect(net, imgs, device):
|
59 |
+
imgs = imgs - np.array([104, 117, 123])
|
60 |
+
imgs = imgs.transpose(0, 3, 1, 2)
|
61 |
+
|
62 |
+
if 'cuda' in device or 'mps' in device:
|
63 |
+
torch.backends.cudnn.benchmark = True
|
64 |
+
|
65 |
+
imgs = torch.from_numpy(imgs).float().to(device)
|
66 |
+
BB, CC, HH, WW = imgs.size()
|
67 |
+
with torch.no_grad():
|
68 |
+
olist = net(imgs)
|
69 |
+
|
70 |
+
bboxlist = []
|
71 |
+
for i in range(len(olist) // 2):
|
72 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
73 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
74 |
+
for i in range(len(olist) // 2):
|
75 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
76 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
77 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
78 |
+
anchor = stride * 4
|
79 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
80 |
+
for Iindex, hindex, windex in poss:
|
81 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
82 |
+
score = ocls[:, 1, hindex, windex]
|
83 |
+
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
84 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
85 |
+
variances = [0.1, 0.2]
|
86 |
+
box = batch_decode(loc, priors, variances)
|
87 |
+
box = box[:, 0] * 1.0
|
88 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
89 |
+
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
90 |
+
bboxlist = np.array(bboxlist)
|
91 |
+
if 0 == len(bboxlist):
|
92 |
+
bboxlist = np.zeros((1, BB, 5))
|
93 |
+
|
94 |
+
return bboxlist
|
95 |
+
|
96 |
+
def flip_detect(net, img, device):
|
97 |
+
img = cv2.flip(img, 1)
|
98 |
+
b = detect(net, img, device)
|
99 |
+
|
100 |
+
bboxlist = np.zeros(b.shape)
|
101 |
+
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
102 |
+
bboxlist[:, 1] = b[:, 1]
|
103 |
+
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
104 |
+
bboxlist[:, 3] = b[:, 3]
|
105 |
+
bboxlist[:, 4] = b[:, 4]
|
106 |
+
return bboxlist
|
107 |
+
|
108 |
+
|
109 |
+
def pts_to_bb(pts):
|
110 |
+
min_x, min_y = np.min(pts, axis=0)
|
111 |
+
max_x, max_y = np.max(pts, axis=0)
|
112 |
+
return np.array([min_x, min_y, max_x, max_y])
|
wav2lip/face_detection/detection/sfd/net_s3fd.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class L2Norm(nn.Module):
|
7 |
+
def __init__(self, n_channels, scale=1.0):
|
8 |
+
super(L2Norm, self).__init__()
|
9 |
+
self.n_channels = n_channels
|
10 |
+
self.scale = scale
|
11 |
+
self.eps = 1e-10
|
12 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
13 |
+
self.weight.data *= 0.0
|
14 |
+
self.weight.data += self.scale
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
18 |
+
x = x / norm * self.weight.view(1, -1, 1, 1)
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class s3fd(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super(s3fd, self).__init__()
|
25 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
26 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
27 |
+
|
28 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
29 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
30 |
+
|
31 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
32 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
33 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
34 |
+
|
35 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
36 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
37 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
38 |
+
|
39 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
40 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
42 |
+
|
43 |
+
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
44 |
+
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
45 |
+
|
46 |
+
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
47 |
+
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
48 |
+
|
49 |
+
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
50 |
+
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
51 |
+
|
52 |
+
self.conv3_3_norm = L2Norm(256, scale=10)
|
53 |
+
self.conv4_3_norm = L2Norm(512, scale=8)
|
54 |
+
self.conv5_3_norm = L2Norm(512, scale=5)
|
55 |
+
|
56 |
+
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
57 |
+
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
58 |
+
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
59 |
+
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
60 |
+
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
61 |
+
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
62 |
+
|
63 |
+
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
65 |
+
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
66 |
+
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
67 |
+
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
68 |
+
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
h = F.relu(self.conv1_1(x))
|
72 |
+
h = F.relu(self.conv1_2(h))
|
73 |
+
h = F.max_pool2d(h, 2, 2)
|
74 |
+
|
75 |
+
h = F.relu(self.conv2_1(h))
|
76 |
+
h = F.relu(self.conv2_2(h))
|
77 |
+
h = F.max_pool2d(h, 2, 2)
|
78 |
+
|
79 |
+
h = F.relu(self.conv3_1(h))
|
80 |
+
h = F.relu(self.conv3_2(h))
|
81 |
+
h = F.relu(self.conv3_3(h))
|
82 |
+
f3_3 = h
|
83 |
+
h = F.max_pool2d(h, 2, 2)
|
84 |
+
|
85 |
+
h = F.relu(self.conv4_1(h))
|
86 |
+
h = F.relu(self.conv4_2(h))
|
87 |
+
h = F.relu(self.conv4_3(h))
|
88 |
+
f4_3 = h
|
89 |
+
h = F.max_pool2d(h, 2, 2)
|
90 |
+
|
91 |
+
h = F.relu(self.conv5_1(h))
|
92 |
+
h = F.relu(self.conv5_2(h))
|
93 |
+
h = F.relu(self.conv5_3(h))
|
94 |
+
f5_3 = h
|
95 |
+
h = F.max_pool2d(h, 2, 2)
|
96 |
+
|
97 |
+
h = F.relu(self.fc6(h))
|
98 |
+
h = F.relu(self.fc7(h))
|
99 |
+
ffc7 = h
|
100 |
+
h = F.relu(self.conv6_1(h))
|
101 |
+
h = F.relu(self.conv6_2(h))
|
102 |
+
f6_2 = h
|
103 |
+
h = F.relu(self.conv7_1(h))
|
104 |
+
h = F.relu(self.conv7_2(h))
|
105 |
+
f7_2 = h
|
106 |
+
|
107 |
+
f3_3 = self.conv3_3_norm(f3_3)
|
108 |
+
f4_3 = self.conv4_3_norm(f4_3)
|
109 |
+
f5_3 = self.conv5_3_norm(f5_3)
|
110 |
+
|
111 |
+
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
112 |
+
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
113 |
+
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
114 |
+
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
115 |
+
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
116 |
+
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
117 |
+
cls4 = self.fc7_mbox_conf(ffc7)
|
118 |
+
reg4 = self.fc7_mbox_loc(ffc7)
|
119 |
+
cls5 = self.conv6_2_mbox_conf(f6_2)
|
120 |
+
reg5 = self.conv6_2_mbox_loc(f6_2)
|
121 |
+
cls6 = self.conv7_2_mbox_conf(f7_2)
|
122 |
+
reg6 = self.conv7_2_mbox_loc(f7_2)
|
123 |
+
|
124 |
+
# max-out background label
|
125 |
+
chunk = torch.chunk(cls1, 4, 1)
|
126 |
+
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
127 |
+
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
128 |
+
|
129 |
+
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
wav2lip/face_detection/detection/sfd/sfd_detector.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
from torch.utils.model_zoo import load_url
|
4 |
+
|
5 |
+
from ..core import FaceDetector
|
6 |
+
|
7 |
+
from .net_s3fd import s3fd
|
8 |
+
from .bbox import *
|
9 |
+
from .detect import *
|
10 |
+
|
11 |
+
models_urls = {
|
12 |
+
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class SFDDetector(FaceDetector):
|
17 |
+
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
18 |
+
super(SFDDetector, self).__init__(device, verbose)
|
19 |
+
|
20 |
+
# Initialise the face detector
|
21 |
+
try:
|
22 |
+
if not os.path.isfile(path_to_detector):
|
23 |
+
model_weights = load_url(models_urls['s3fd'])
|
24 |
+
else:
|
25 |
+
# For MPS (Apple Silicon), we need to load to CPU first
|
26 |
+
if 'mps' in device:
|
27 |
+
model_weights = torch.load(path_to_detector, map_location='cpu')
|
28 |
+
else:
|
29 |
+
model_weights = torch.load(path_to_detector, map_location=device)
|
30 |
+
|
31 |
+
self.face_detector = s3fd()
|
32 |
+
self.face_detector.load_state_dict(model_weights)
|
33 |
+
self.face_detector.to(device)
|
34 |
+
self.face_detector.eval()
|
35 |
+
|
36 |
+
if verbose:
|
37 |
+
print(f"Face detector loaded successfully and moved to {device}")
|
38 |
+
|
39 |
+
except Exception as e:
|
40 |
+
if verbose:
|
41 |
+
print(f"Error loading face detector model: {str(e)}")
|
42 |
+
raise
|
43 |
+
|
44 |
+
def detect_from_image(self, tensor_or_path):
|
45 |
+
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
46 |
+
|
47 |
+
bboxlist = detect(self.face_detector, image, device=self.device)
|
48 |
+
keep = nms(bboxlist, 0.3)
|
49 |
+
bboxlist = bboxlist[keep, :]
|
50 |
+
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
51 |
+
|
52 |
+
return bboxlist
|
53 |
+
|
54 |
+
def detect_from_batch(self, images):
|
55 |
+
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
56 |
+
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
57 |
+
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
58 |
+
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
59 |
+
|
60 |
+
return bboxlists
|
61 |
+
|
62 |
+
@property
|
63 |
+
def reference_scale(self):
|
64 |
+
return 195
|
65 |
+
|
66 |
+
@property
|
67 |
+
def reference_x_shift(self):
|
68 |
+
return 0
|
69 |
+
|
70 |
+
@property
|
71 |
+
def reference_y_shift(self):
|
72 |
+
return 0
|
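For completeness, a small sketch of using `SFDDetector` on its own (the image path is a placeholder; if `s3fd.pth` is not present next to the module, the weights are fetched from the URL listed above):

```python
from wav2lip.face_detection.detection.sfd import FaceDetector  # SFDDetector re-exported by the package

detector = FaceDetector(device='cpu', verbose=True)
boxes = detector.detect_from_image('face.jpg')  # rows of [x1, y1, x2, y2, confidence], confidence > 0.5
print(boxes)
```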
wav2lip/face_detection/models.py
ADDED
@@ -0,0 +1,261 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
8 |
+
"3x3 convolution with padding"
|
9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
10 |
+
stride=strd, padding=padding, bias=bias)
|
11 |
+
|
12 |
+
|
13 |
+
class ConvBlock(nn.Module):
|
14 |
+
def __init__(self, in_planes, out_planes):
|
15 |
+
super(ConvBlock, self).__init__()
|
16 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
17 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
18 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
19 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
20 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
21 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
22 |
+
|
23 |
+
if in_planes != out_planes:
|
24 |
+
self.downsample = nn.Sequential(
|
25 |
+
nn.BatchNorm2d(in_planes),
|
26 |
+
nn.ReLU(True),
|
27 |
+
nn.Conv2d(in_planes, out_planes,
|
28 |
+
kernel_size=1, stride=1, bias=False),
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
self.downsample = None
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
residual = x
|
35 |
+
|
36 |
+
out1 = self.bn1(x)
|
37 |
+
out1 = F.relu(out1, True)
|
38 |
+
out1 = self.conv1(out1)
|
39 |
+
|
40 |
+
out2 = self.bn2(out1)
|
41 |
+
out2 = F.relu(out2, True)
|
42 |
+
out2 = self.conv2(out2)
|
43 |
+
|
44 |
+
out3 = self.bn3(out2)
|
45 |
+
out3 = F.relu(out3, True)
|
46 |
+
out3 = self.conv3(out3)
|
47 |
+
|
48 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
residual = self.downsample(residual)
|
52 |
+
|
53 |
+
out3 += residual
|
54 |
+
|
55 |
+
return out3
|
56 |
+
|
57 |
+
|
58 |
+
class Bottleneck(nn.Module):
|
59 |
+
|
60 |
+
expansion = 4
|
61 |
+
|
62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
63 |
+
super(Bottleneck, self).__init__()
|
64 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
65 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
66 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
67 |
+
padding=1, bias=False)
|
68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
69 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
70 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
71 |
+
self.relu = nn.ReLU(inplace=True)
|
72 |
+
self.downsample = downsample
|
73 |
+
self.stride = stride
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
residual = x
|
77 |
+
|
78 |
+
out = self.conv1(x)
|
79 |
+
out = self.bn1(out)
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
out = self.conv2(out)
|
83 |
+
out = self.bn2(out)
|
84 |
+
out = self.relu(out)
|
85 |
+
|
86 |
+
out = self.conv3(out)
|
87 |
+
out = self.bn3(out)
|
88 |
+
|
89 |
+
if self.downsample is not None:
|
90 |
+
residual = self.downsample(x)
|
91 |
+
|
92 |
+
out += residual
|
93 |
+
out = self.relu(out)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
class HourGlass(nn.Module):
|
99 |
+
def __init__(self, num_modules, depth, num_features):
|
100 |
+
super(HourGlass, self).__init__()
|
101 |
+
self.num_modules = num_modules
|
102 |
+
self.depth = depth
|
103 |
+
self.features = num_features
|
104 |
+
|
105 |
+
self._generate_network(self.depth)
|
106 |
+
|
107 |
+
def _generate_network(self, level):
|
108 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
109 |
+
|
110 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
111 |
+
|
112 |
+
if level > 1:
|
113 |
+
self._generate_network(level - 1)
|
114 |
+
else:
|
115 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
116 |
+
|
117 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
118 |
+
|
119 |
+
def _forward(self, level, inp):
|
120 |
+
# Upper branch
|
121 |
+
up1 = inp
|
122 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
123 |
+
|
124 |
+
# Lower branch
|
125 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
126 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
127 |
+
|
128 |
+
if level > 1:
|
129 |
+
low2 = self._forward(level - 1, low1)
|
130 |
+
else:
|
131 |
+
low2 = low1
|
132 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
133 |
+
|
134 |
+
low3 = low2
|
135 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
136 |
+
|
137 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
138 |
+
|
139 |
+
return up1 + up2
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
return self._forward(self.depth, x)
|
143 |
+
|
144 |
+
|
145 |
+
class FAN(nn.Module):
|
146 |
+
|
147 |
+
def __init__(self, num_modules=1):
|
148 |
+
super(FAN, self).__init__()
|
149 |
+
self.num_modules = num_modules
|
150 |
+
|
151 |
+
# Base part
|
152 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
153 |
+
self.bn1 = nn.BatchNorm2d(64)
|
154 |
+
self.conv2 = ConvBlock(64, 128)
|
155 |
+
self.conv3 = ConvBlock(128, 128)
|
156 |
+
self.conv4 = ConvBlock(128, 256)
|
157 |
+
|
158 |
+
# Stacking part
|
159 |
+
for hg_module in range(self.num_modules):
|
160 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
161 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
162 |
+
self.add_module('conv_last' + str(hg_module),
|
163 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
164 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
165 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
166 |
+
68, kernel_size=1, stride=1, padding=0))
|
167 |
+
|
168 |
+
if hg_module < self.num_modules - 1:
|
169 |
+
self.add_module(
|
170 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
171 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
172 |
+
256, kernel_size=1, stride=1, padding=0))
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
176 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
177 |
+
x = self.conv3(x)
|
178 |
+
x = self.conv4(x)
|
179 |
+
|
180 |
+
previous = x
|
181 |
+
|
182 |
+
outputs = []
|
183 |
+
for i in range(self.num_modules):
|
184 |
+
hg = self._modules['m' + str(i)](previous)
|
185 |
+
|
186 |
+
ll = hg
|
187 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
188 |
+
|
189 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
190 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
191 |
+
|
192 |
+
# Predict heatmaps
|
193 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
194 |
+
outputs.append(tmp_out)
|
195 |
+
|
196 |
+
if i < self.num_modules - 1:
|
197 |
+
ll = self._modules['bl' + str(i)](ll)
|
198 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
199 |
+
previous = previous + ll + tmp_out_
|
200 |
+
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class ResNetDepth(nn.Module):
|
205 |
+
|
206 |
+
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
207 |
+
self.inplanes = 64
|
208 |
+
super(ResNetDepth, self).__init__()
|
209 |
+
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
210 |
+
bias=False)
|
211 |
+
self.bn1 = nn.BatchNorm2d(64)
|
212 |
+
self.relu = nn.ReLU(inplace=True)
|
213 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
214 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
215 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
216 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
217 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
218 |
+
self.avgpool = nn.AvgPool2d(7)
|
219 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
220 |
+
|
221 |
+
for m in self.modules():
|
222 |
+
if isinstance(m, nn.Conv2d):
|
223 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
224 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
225 |
+
elif isinstance(m, nn.BatchNorm2d):
|
226 |
+
m.weight.data.fill_(1)
|
227 |
+
m.bias.data.zero_()
|
228 |
+
|
229 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
230 |
+
downsample = None
|
231 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
232 |
+
downsample = nn.Sequential(
|
233 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
234 |
+
kernel_size=1, stride=stride, bias=False),
|
235 |
+
nn.BatchNorm2d(planes * block.expansion),
|
236 |
+
)
|
237 |
+
|
238 |
+
layers = []
|
239 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
240 |
+
self.inplanes = planes * block.expansion
|
241 |
+
for i in range(1, blocks):
|
242 |
+
layers.append(block(self.inplanes, planes))
|
243 |
+
|
244 |
+
return nn.Sequential(*layers)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
x = self.conv1(x)
|
248 |
+
x = self.bn1(x)
|
249 |
+
x = self.relu(x)
|
250 |
+
x = self.maxpool(x)
|
251 |
+
|
252 |
+
x = self.layer1(x)
|
253 |
+
x = self.layer2(x)
|
254 |
+
x = self.layer3(x)
|
255 |
+
x = self.layer4(x)
|
256 |
+
|
257 |
+
x = self.avgpool(x)
|
258 |
+
x = x.view(x.size(0), -1)
|
259 |
+
x = self.fc(x)
|
260 |
+
|
261 |
+
return x
|
wav2lip/face_detection/utils.py
ADDED
@@ -0,0 +1,313 @@
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
def _gaussian(
|
12 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
13 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
14 |
+
mean_vert=0.5):
|
15 |
+
# handle some defaults
|
16 |
+
if width is None:
|
17 |
+
width = size
|
18 |
+
if height is None:
|
19 |
+
height = size
|
20 |
+
if sigma_horz is None:
|
21 |
+
sigma_horz = sigma
|
22 |
+
if sigma_vert is None:
|
23 |
+
sigma_vert = sigma
|
24 |
+
center_x = mean_horz * width + 0.5
|
25 |
+
center_y = mean_vert * height + 0.5
|
26 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
27 |
+
# generate kernel
|
28 |
+
for i in range(height):
|
29 |
+
for j in range(width):
|
30 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
31 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
32 |
+
if normalize:
|
33 |
+
gauss = gauss / np.sum(gauss)
|
34 |
+
return gauss
|
35 |
+
|
36 |
+
|
37 |
+
def draw_gaussian(image, point, sigma):
|
38 |
+
# Check if the gaussian is inside
|
39 |
+
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
40 |
+
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
41 |
+
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
42 |
+
return image
|
43 |
+
size = 6 * sigma + 1
|
44 |
+
g = _gaussian(size)
|
45 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
46 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
47 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
48 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
49 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
50 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
51 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
52 |
+
image[image > 1] = 1
|
53 |
+
return image
|
54 |
+
|
55 |
+
|
56 |
+
def transform(point, center, scale, resolution, invert=False):
|
57 |
+
"""Generate and affine transformation matrix.
|
58 |
+
|
59 |
+
Given a set of points, a center, a scale and a targer resolution, the
|
60 |
+
function generates and affine transformation matrix. If invert is ``True``
|
61 |
+
it will produce the inverse transformation.
|
62 |
+
|
63 |
+
Arguments:
|
64 |
+
point {torch.tensor} -- the input 2D point
|
65 |
+
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
66 |
+
scale {float} -- the scale of the face/object
|
67 |
+
resolution {float} -- the output resolution
|
68 |
+
|
69 |
+
Keyword Arguments:
|
70 |
+
invert {bool} -- define wherever the function should produce the direct or the
|
71 |
+
inverse transformation matrix (default: {False})
|
72 |
+
"""
|
73 |
+
_pt = torch.ones(3)
|
74 |
+
_pt[0] = point[0]
|
75 |
+
_pt[1] = point[1]
|
76 |
+
|
77 |
+
h = 200.0 * scale
|
78 |
+
t = torch.eye(3)
|
79 |
+
t[0, 0] = resolution / h
|
80 |
+
t[1, 1] = resolution / h
|
81 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
82 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
83 |
+
|
84 |
+
if invert:
|
85 |
+
t = torch.inverse(t)
|
86 |
+
|
87 |
+
new_point = (torch.matmul(t, _pt))[0:2]
|
88 |
+
|
89 |
+
return new_point.int()
|
90 |
+
|
91 |
+
|
92 |
+
def crop(image, center, scale, resolution=256.0):
|
93 |
+
"""Center crops an image or set of heatmaps
|
94 |
+
|
95 |
+
Arguments:
|
96 |
+
image {numpy.array} -- an rgb image
|
97 |
+
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
98 |
+
scale {float} -- scale of the face
|
99 |
+
|
100 |
+
Keyword Arguments:
|
101 |
+
resolution {float} -- the size of the output cropped image (default: {256.0})
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
[type] -- [description]
|
105 |
+
""" # Crop around the center point
|
106 |
+
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
107 |
+
ul = transform([1, 1], center, scale, resolution, True)
|
108 |
+
br = transform([resolution, resolution], center, scale, resolution, True)
|
109 |
+
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
110 |
+
if image.ndim > 2:
|
111 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
112 |
+
image.shape[2]], dtype=np.int32)
|
113 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
114 |
+
else:
|
115 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
116 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
117 |
+
ht = image.shape[0]
|
118 |
+
wd = image.shape[1]
|
119 |
+
newX = np.array(
|
120 |
+
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
121 |
+
newY = np.array(
|
122 |
+
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
123 |
+
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
124 |
+
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
125 |
+
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
126 |
+
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
127 |
+
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
128 |
+
interpolation=cv2.INTER_LINEAR)
|
129 |
+
return newImg
|
130 |
+
|
131 |
+
|
132 |
+
def get_preds_fromhm(hm, center=None, scale=None):
|
133 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
134 |
+
and the scale is provided the function will return the points also in
|
135 |
+
the original coordinate frame.
|
136 |
+
|
137 |
+
Arguments:
|
138 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
139 |
+
|
140 |
+
Keyword Arguments:
|
141 |
+
center {torch.tensor} -- the center of the bounding box (default: {None})
|
142 |
+
scale {float} -- face scale (default: {None})
|
143 |
+
"""
|
144 |
+
max, idx = torch.max(
|
145 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
146 |
+
idx += 1
|
147 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
148 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
149 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
150 |
+
|
151 |
+
for i in range(preds.size(0)):
|
152 |
+
for j in range(preds.size(1)):
|
153 |
+
hm_ = hm[i, j, :]
|
154 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
155 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
156 |
+
diff = torch.FloatTensor(
|
157 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
158 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
159 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
160 |
+
|
161 |
+
preds.add_(-.5)
|
162 |
+
|
163 |
+
preds_orig = torch.zeros(preds.size())
|
164 |
+
if center is not None and scale is not None:
|
165 |
+
for i in range(hm.size(0)):
|
166 |
+
for j in range(hm.size(1)):
|
167 |
+
preds_orig[i, j] = transform(
|
168 |
+
preds[i, j], center, scale, hm.size(2), True)
|
169 |
+
|
170 |
+
return preds, preds_orig
|
171 |
+
|
172 |
+
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
173 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
174 |
+
and the scales is provided the function will return the points also in
|
175 |
+
the original coordinate frame.
|
176 |
+
|
177 |
+
Arguments:
|
178 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
179 |
+
|
180 |
+
Keyword Arguments:
|
181 |
+
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
182 |
+
scales {float} -- face scales (default: {None})
|
183 |
+
"""
|
184 |
+
max, idx = torch.max(
|
185 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
186 |
+
idx += 1
|
187 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
188 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
189 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
190 |
+
|
191 |
+
for i in range(preds.size(0)):
|
192 |
+
for j in range(preds.size(1)):
|
193 |
+
hm_ = hm[i, j, :]
|
194 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
195 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
196 |
+
diff = torch.FloatTensor(
|
197 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
198 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
199 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
200 |
+
|
201 |
+
preds.add_(-.5)
|
202 |
+
|
203 |
+
preds_orig = torch.zeros(preds.size())
|
204 |
+
if centers is not None and scales is not None:
|
205 |
+
for i in range(hm.size(0)):
|
206 |
+
for j in range(hm.size(1)):
|
207 |
+
preds_orig[i, j] = transform(
|
208 |
+
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
209 |
+
|
210 |
+
return preds, preds_orig
|
211 |
+
|
212 |
+
def shuffle_lr(parts, pairs=None):
|
213 |
+
"""Shuffle the points left-right according to the axis of symmetry
|
214 |
+
of the object.
|
215 |
+
|
216 |
+
Arguments:
|
217 |
+
parts {torch.tensor} -- a 3D or 4D object containing the
|
218 |
+
heatmaps.
|
219 |
+
|
220 |
+
Keyword Arguments:
|
221 |
+
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
222 |
+
"""
|
223 |
+
if pairs is None:
|
224 |
+
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
225 |
+
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
226 |
+
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
227 |
+
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
228 |
+
62, 61, 60, 67, 66, 65]
|
229 |
+
if parts.ndimension() == 3:
|
230 |
+
parts = parts[pairs, ...]
|
231 |
+
else:
|
232 |
+
parts = parts[:, pairs, ...]
|
233 |
+
|
234 |
+
return parts
|
235 |
+
|
236 |
+
|
237 |
+
def flip(tensor, is_label=False):
|
238 |
+
"""Flip an image or a set of heatmaps left-right
|
239 |
+
|
240 |
+
Arguments:
|
241 |
+
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
242 |
+
|
243 |
+
Keyword Arguments:
|
244 |
+
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
245 |
+
"""
|
246 |
+
if not torch.is_tensor(tensor):
|
247 |
+
tensor = torch.from_numpy(tensor)
|
248 |
+
|
249 |
+
if is_label:
|
250 |
+
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
251 |
+
else:
|
252 |
+
tensor = tensor.flip(tensor.ndimension() - 1)
|
253 |
+
|
254 |
+
return tensor
|
255 |
+
|
256 |
+
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
257 |
+
|
258 |
+
|
259 |
+
def appdata_dir(appname=None, roaming=False):
|
260 |
+
""" appdata_dir(appname=None, roaming=False)
|
261 |
+
|
262 |
+
Get the path to the application directory, where applications are allowed
|
263 |
+
to write user specific files (e.g. configurations). For non-user specific
|
264 |
+
data, consider using common_appdata_dir().
|
265 |
+
If appname is given, a subdir is appended (and created if necessary).
|
266 |
+
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
267 |
+
"""
|
268 |
+
|
269 |
+
# Define default user directory
|
270 |
+
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
271 |
+
if userDir is None:
|
272 |
+
userDir = os.path.expanduser('~')
|
273 |
+
if not os.path.isdir(userDir): # pragma: no cover
|
274 |
+
userDir = '/var/tmp' # issue #54
|
275 |
+
|
276 |
+
# Get system app data dir
|
277 |
+
path = None
|
278 |
+
if sys.platform.startswith('win'):
|
279 |
+
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
280 |
+
path = (path2 or path1) if roaming else (path1 or path2)
|
281 |
+
elif sys.platform.startswith('darwin'):
|
282 |
+
path = os.path.join(userDir, 'Library', 'Application Support')
|
283 |
+
# On Linux and as fallback
|
284 |
+
if not (path and os.path.isdir(path)):
|
285 |
+
path = userDir
|
286 |
+
|
287 |
+
# Maybe we should store things local to the executable (in case of a
|
288 |
+
# portable distro or a frozen application that wants to be portable)
|
289 |
+
prefix = sys.prefix
|
290 |
+
if getattr(sys, 'frozen', None):
|
291 |
+
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
292 |
+
for reldir in ('settings', '../settings'):
|
293 |
+
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
294 |
+
if os.path.isdir(localpath): # pragma: no cover
|
295 |
+
try:
|
296 |
+
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
297 |
+
os.remove(os.path.join(localpath, 'test.write'))
|
298 |
+
except IOError:
|
299 |
+
pass # We cannot write in this directory
|
300 |
+
else:
|
301 |
+
path = localpath
|
302 |
+
break
|
303 |
+
|
304 |
+
# Get path specific for this app
|
305 |
+
if appname:
|
306 |
+
if path == userDir:
|
307 |
+
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
308 |
+
path = os.path.join(path, appname)
|
309 |
+
if not os.path.isdir(path): # pragma: no cover
|
310 |
+
os.mkdir(path)
|
311 |
+
|
312 |
+
# Done
|
313 |
+
return path
|
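The two workhorses above are `crop` (cut a fixed-resolution face patch out of a frame, given a center and scale) and `transform` (map points between the original frame and the cropped/heatmap coordinate frame). The snippet below is an illustrative sketch only, not part of the commit; the frame, the box coordinates, and the `/195` scale heuristic are assumptions borrowed from typical face-alignment usage.

```python
import numpy as np
import torch
from wav2lip.face_detection.utils import crop, transform

frame = np.zeros((480, 640, 3), dtype=np.uint8)            # stand-in for a video frame
x1, y1, x2, y2 = 200, 120, 360, 320                        # assumed face box
center = torch.FloatTensor([(x1 + x2) / 2, (y1 + y2) / 2])
scale = ((x2 - x1) + (y2 - y1)) / 195.0                    # rough face-scale heuristic

patch = crop(frame, center, scale, resolution=256.0)       # (256, 256, 3) uint8 face crop
# Map the center of the 256x256 crop back into original-frame pixel coordinates.
point_in_frame = transform([128, 128], center, scale, 256.0, invert=True)
print(patch.shape, point_in_frame)
```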
wav2lip/hparams.py
ADDED
@@ -0,0 +1,101 @@
from glob import glob
import os

def get_image_list(data_root, split):
    filelist = []

    with open('filelists/{}.txt'.format(split)) as f:
        for line in f:
            line = line.strip()
            if ' ' in line: line = line.split()[0]
            filelist.append(os.path.join(data_root, line))

    return filelist

class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)

    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.,
    # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
    # be too big to avoid gradient explosion,
    # not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply filter
    preemphasis=0.97,  # filter coefficient.

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! If female, 95 should help take off noise. (To
    # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmax=7600,  # To be increased/reduced depending on data.

    ###################### Our training parameters #################################
    img_size=96,
    fps=25,

    batch_size=16,
    initial_learning_rate=1e-4,
    nepochs=200000000000000000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
    num_workers=16,
    checkpoint_interval=3000,
    eval_interval=3000,
    save_optimizer_state=True,

    syncnet_wt=0.0,  # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
    syncnet_batch_size=64,
    syncnet_lr=1e-4,
    syncnet_eval_interval=10000,
    syncnet_checkpoint_interval=10000,

    disc_wt=0.07,
    disc_initial_learning_rate=1e-4,
)


def hparams_debug_string():
    values = hparams.data  # HParams has no values() method; read the backing dict directly
    hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
    return "Hyperparameters:\n" + "\n".join(hp)
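Since `hparams` is a module-level singleton, other modules (for example the audio pipeline) read it by attribute access and can override values with `set_hparam`. A minimal illustration, not part of the commit:

```python
from wav2lip.hparams import hparams

print(hparams.sample_rate)   # 16000
print(hparams.fps)           # 25
print(hparams.num_mels)      # 80

# set_hparam mutates the shared singleton, e.g. to trade memory for speed:
hparams.set_hparam('batch_size', 8)
print(hparams.batch_size)    # 8
```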
wav2lip/inference.py
ADDED
@@ -0,0 +1,484 @@
import numpy as np
import cv2
import os
import argparse
import subprocess
from tqdm import tqdm
import sys
import traceback
from .audio import load_wav, melspectrogram
from .face_detection import FaceAlignment, LandmarksType
import torch
import platform

parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
                    default='wav2lip/results/result_voice.mp4')

parser.add_argument('--static', type=bool,
                    help='If True, then use only first video frame for inference', default=False)
parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
                    default=25., required=False)

parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')

parser.add_argument('--face_det_batch_size', type=int,
                    help='Batch size for face detection', default=32)
parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=512)

parser.add_argument('--resize_factor', default=1, type=int,
                    help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
                    help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                         'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')

parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
                    help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
                         'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

parser.add_argument('--rotate', default=False, action='store_true',
                    help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. '
                         'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument('--nosmooth', default=False, action='store_true',
                    help='Prevent smoothing face detections over a short temporal window')

args = parser.parse_args()
args.img_size = 96

# Check for available devices
if torch.backends.mps.is_available():
    device = 'mps'  # Use Apple Silicon GPU
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('Using {} for inference.'.format(device))

def get_smoothened_boxes(boxes, idx):
    """Get smoothened box for a specific index"""
    if idx >= len(boxes) or boxes[idx] is None:
        return None, None

    # Return the face region and coordinates
    if isinstance(boxes[idx], list) and len(boxes[idx]) == 2:  # Format from the specified bounding box
        return boxes[idx][0], boxes[idx][1]
    else:  # Format from face detection - [x1, y1, x2, y2]
        if isinstance(boxes[idx], list) or isinstance(boxes[idx], tuple):
            if len(boxes[idx]) >= 4:  # Make sure we have all 4 coordinates
                x1, y1, x2, y2 = boxes[idx][:4]
                # Return coordinates in the expected format (y1, y2, x1, x2)
                coords = (y1, y2, x1, x2)
                return None, coords

    print(f"WARNING: Unexpected box format at idx {idx}: {boxes[idx]}")
    return None, None

def face_detect(images):
    print(f"Starting face detection using {device} device...")
    try:
        detector = FaceAlignment(LandmarksType._2D,
                                 flip_input=False, device=device, verbose=True)
    except Exception as e:
        print(f"Error initializing face detector: {str(e)}")
        print("Attempting to fall back to CPU for face detection...")
        detector = FaceAlignment(LandmarksType._2D,
                                 flip_input=False, device='cpu', verbose=True)

    batch_size = args.face_det_batch_size

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                batch = np.array(images[i:i + batch_size])
                print(f"Processing detection batch {i//batch_size + 1}, shape: {batch.shape}")
                batch_predictions = detector.get_detections_for_batch(batch)
                predictions.extend(batch_predictions)
        except RuntimeError as e:
            print(f"Runtime error in face detection: {str(e)}")
            if batch_size == 1:
                # Error when batch_size is already 1
                print('Face detection failed at minimum batch size! Using fallback method...')
                # Create empty predictions for all frames to allow processing to continue
                predictions = [None] * len(images)
                break
            batch_size //= 2
            print('Reducing face detection batch size to', batch_size)
            continue
        except Exception as e:
            print(f"Unexpected error in face detection: {str(e)}")
            # Create empty predictions and continue with fallback
            predictions = [None] * len(images)
            break
        break

    # Check if we have at least one valid face detection
    faces_detected = sum(1 for p in predictions if p is not None)
    print(f"Detected faces in {faces_detected} out of {len(images)} frames ({faces_detected/len(images)*100:.1f}%)")

    results = []
    pady1, pady2, padx1, padx2 = args.pads

    for i, (rect, image) in enumerate(zip(predictions, images)):
        if rect is None:
            # Create default coordinates for face detection
            h, w = image.shape[:2]

            # Simple and consistent face region estimation based on center of the frame
            center_x = w // 2
            center_y = h // 2

            # Use about 1/3 of the frame height for face
            face_h = h // 3
            face_w = min(w // 2, face_h)

            # Create a centered box
            x1 = max(0, center_x - face_w // 2 - padx1)
            y1 = max(0, center_y - face_h // 2 - pady1)
            x2 = min(w, center_x + face_w // 2 + padx2)
            y2 = min(h, center_y + face_h // 2 + pady2)

            if i == 0 or i % 100 == 0:  # Log only occasionally to avoid flooding
                print(f"Frame {i}: Using fallback face region at ({x1},{y1},{x2},{y2})")

            results.append([x1, y1, x2, y2])
            continue

        # If face is detected, use its coordinates with padding
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    return results

def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if args.box[0] == -1:
        if not args.static:
            try:
                print(f"Starting face detection for {len(frames)} frames...")
                face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
                print("Face detection completed successfully")
            except Exception as e:
                print(f"Face detection error: {str(e)}")
                print(f"Error type: {type(e).__name__}")
                traceback.print_exc()
                print("Using fallback method with default face regions...")
                # Create default face regions for all frames
                h, w = frames[0].shape[:2]

                # Simple face region estimation in the center of the frame
                center_x = w // 2
                center_y = h // 2

                # Use about 1/3 of the frame height for face
                face_h = h // 3
                face_w = min(w // 2, face_h)

                pady1, pady2, padx1, padx2 = args.pads
                x1 = max(0, center_x - face_w // 2 - padx1)
                y1 = max(0, center_y - face_h // 2 - pady1)
                x2 = min(w, center_x + face_w // 2 + padx2)
                y2 = min(h, center_y + face_h // 2 + pady2)

                print(f"Estimated face region: x1={x1}, y1={y1}, x2={x2}, y2={y2}")

                # Use the same format as the face_detect function returns
                face_det_results = [[x1, y1, x2, y2] for _ in range(len(frames))]
        else:
            try:
                print("Starting face detection for static image...")
                face_det_results = face_detect([frames[0]])
                print("Face detection completed successfully")
            except Exception as e:
                print(f"Face detection error: {str(e)}")
                print(f"Error type: {type(e).__name__}")
                traceback.print_exc()
                print("Using fallback method with default face region...")
                # Create default face region for static image
                h, w = frames[0].shape[:2]

                # Simple face region estimation in the center of the frame
                center_x = w // 2
                center_y = h // 2

                # Use about 1/3 of the frame height for face
                face_h = h // 3
                face_w = min(w // 2, face_h)

                pady1, pady2, padx1, padx2 = args.pads
                x1 = max(0, center_x - face_w // 2 - padx1)
                y1 = max(0, center_y - face_h // 2 - pady1)
                x2 = min(w, center_x + face_w // 2 + padx2)
                y2 = min(h, center_y + face_h // 2 + pady2)

                print(f"Estimated face region for static image: x1={x1}, y1={y1}, x2={x2}, y2={y2}")

                # Use the same format as the face_detect function returns
                face_det_results = [[x1, y1, x2, y2]]
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[x1, y1, x2, y2] for _ in range(len(frames))]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()

        if args.box[0] == -1:
            face, coords = get_smoothened_boxes(face_det_results, idx)

            if coords is None:
                print(f'Face coordinates not detected! Skipping frame {i}')
                continue

            # If face is None, extract it from the frame using coordinates
            if face is None:
                y1, y2, x1, x2 = coords
                try:
                    if y1 >= y2 or x1 >= x2:
                        print(f"Invalid coordinates at frame {i}: y1={y1}, y2={y2}, x1={x1}, x2={x2}")
                        continue
                    if y1 < 0 or x1 < 0 or y2 > frame_to_save.shape[0] or x2 > frame_to_save.shape[1]:
                        print(f"Out of bounds coordinates at frame {i}. Adjusting...")
                        y1 = max(0, y1)
                        x1 = max(0, x1)
                        y2 = min(frame_to_save.shape[0], y2)
                        x2 = min(frame_to_save.shape[1], x2)

                    # Check if the region is too small
                    if (y2 - y1) < 10 or (x2 - x1) < 10:
                        print(f"Region too small at frame {i}. Skipping.")
                        continue

                    face = frames[idx][y1:y2, x1:x2]
                except Exception as e:
                    print(f"Error extracting face at frame {i}: {str(e)}")
                    continue
        else:
            face = frames[idx][y1:y2, x1:x2]
            coords = (y1, y2, x1, x2)

        try:
            face = cv2.resize(face, (args.img_size, args.img_size))
            img_batch.append(face)
            mel_batch.append(m)
            frame_batch.append(frame_to_save)
            coords_batch.append(coords)
        except Exception as e:
            print(f"Error processing frame {i}: {str(e)}")
            continue

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

mel_step_size = 16

def _load(checkpoint_path):
    # Handle loading for different devices
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    return checkpoint


def main(face, audio, model, slow_mode=False):
    if slow_mode:
        print("Using SLOW animation mode (full face animation)")
    else:
        print("Using FAST animation mode (lips only)")

    if not os.path.isfile(face):
        raise ValueError('--face argument must be a valid path to video/image file')

    elif face.split('.')[-1] in ['jpg', 'png', 'jpeg'] and not slow_mode:
        full_frames = [cv2.imread(face)]
        fps = args.fps

    else:
        video_stream = cv2.VideoCapture(face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        # Get video dimensions for potential downscaling of large videos
        frame_width = int(video_stream.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video_stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))

        # Auto-adjust resize factor for very large videos
        original_resize_factor = args.resize_factor
        if frame_width > 1920 or frame_height > 1080:
            # For 4K or larger videos, use a higher resize factor
            if frame_width >= 3840 or frame_height >= 2160:
                args.resize_factor = max(4, args.resize_factor)
                print(f"Auto-adjusting resize factor to {args.resize_factor} for high-resolution video")
            # For 1080p-4K videos
            elif frame_width > 1920 or frame_height > 1080:
                args.resize_factor = max(2, args.resize_factor)
                print(f"Auto-adjusting resize factor to {args.resize_factor} for high-resolution video")

        print('Reading video frames...')

        full_frames = []

        # For large videos, report progress and limit memory usage
        frame_limit = 5000  # Maximum number of frames to process at once
        if total_frames > frame_limit:
            print(f"Large video detected ({total_frames} frames). Will process in chunks.")

        # Use tqdm for progress reporting
        pbar = tqdm(total=min(total_frames, frame_limit))
        frame_count = 0

        while frame_count < frame_limit:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            if args.resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)
            frame_count += 1
            pbar.update(1)

            # For very large videos, limit frames to avoid memory issues
            if frame_count >= frame_limit:
                print(f"Reached frame limit of {frame_limit}. Processing this chunk.")
                break

        pbar.close()

        # Reset resize factor to original value after processing
        args.resize_factor = original_resize_factor

    print("Number of frames available for inference: " + str(len(full_frames)))

    if not audio.endswith('.wav'):
        print('Extracting raw audio...')
        command = 'ffmpeg -y -i {} -strict -2 {}'.format(audio, 'temp/temp.wav')

        subprocess.call(command, shell=True)
        audio = 'temp/temp.wav'

    wav = load_wav(audio, 16000)
    mel = melspectrogram(wav)
    print(mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_chunks = []
    mel_idx_multiplier = 80. / fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    full_frames = full_frames[:len(mel_chunks)]

    batch_size = args.wav2lip_batch_size
    gen = datagen(full_frames.copy(), mel_chunks)

    # Initialize video writer outside the try block
    out = None
    try:
        for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
                total=int(np.ceil(float(len(mel_chunks)) / args.wav2lip_batch_size)))):
            if i == 0:
                # model = load_model(checkpoint_path)
                print("Model loaded")

                frame_h, frame_w = full_frames[0].shape[:-1]
                out = cv2.VideoWriter('wav2lip/temp/result.avi',
                                      cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

                f[y1:y2, x1:x2] = p
                out.write(f)
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        print("Attempting to save any completed frames...")

    # Save the results - only if out was initialized
    if out is not None:
        out.release()

    # Convert the output video to MP4 if needed - only if the AVI exists
    result_path = 'wav2lip/results/result_voice.mp4'
    if os.path.exists('wav2lip/temp/result.avi'):
        # Check if the result file is valid (has frames)
        avi_info = os.stat('wav2lip/temp/result.avi')
        if avi_info.st_size > 1000:  # If file is too small, it's likely empty
            # Modified command to include the audio file
            command = 'ffmpeg -y -i {} -i {} -c:v libx264 -preset ultrafast -c:a aac -map 0:v:0 -map 1:a:0 {}'.format(
                'wav2lip/temp/result.avi', audio, result_path)
            try:
                subprocess.call(command, shell=True)
                if os.path.exists(result_path):
                    print(f"Successfully created output video with audio at {result_path}")
                else:
                    print("Error: Output video file was not created.")
            except Exception as e:
                print(f"Error during video conversion: {str(e)}")
        else:
            print(f"Warning: Output AVI file is too small ({avi_info.st_size} bytes). Face detection may have failed.")
    else:
        print("No output video was created. Face detection likely failed completely.")
        # Return a default path even if no output was created

    # Return even if there were errors
    return result_path
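`main()` expects an already-loaded generator, so the caller (here, the Streamlit app in this Space) is responsible for instantiating `Wav2Lip`, loading a checkpoint onto `inference.device`, and passing the model in. The driver below is a hedged sketch of that wiring; the checkpoint path and the `'state_dict'` / `'module.'` key layout are assumptions based on the official Wav2Lip release, not on files shown in this commit. Note that importing `wav2lip.inference` runs `parser.parse_args()` at import time, so it should be imported from a process whose command line only carries flags the parser knows about (or none).

```python
import torch
from wav2lip.models import Wav2Lip
from wav2lip import inference

def load_wav2lip(checkpoint_path):
    """Build a Wav2Lip generator and load weights onto the device picked by inference.py."""
    model = Wav2Lip()
    checkpoint = inference._load(checkpoint_path)
    # Assumption: official Wav2Lip checkpoints keep weights under 'state_dict'
    # with a DataParallel 'module.' prefix; adjust if your checkpoint differs.
    state_dict = checkpoint.get('state_dict', checkpoint)
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model.to(inference.device).eval()

# Hypothetical paths for illustration only:
model = load_wav2lip('wav2lip/checkpoints/wav2lip_gan.pth')
result = inference.main(face='avatars_images/avatar1.jpg',
                        audio='temp/input_audio.wav',
                        model=model,
                        slow_mode=False)
print('Output video written to', result)
```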
wav2lip/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
from .syncnet import SyncNet_color
wav2lip/models/conv.py
ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class nonorm_Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)

class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
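These wrappers add BatchNorm and ReLU (and an optional identity shortcut) around plain `nn.Conv2d` / `nn.ConvTranspose2d`. Note that `residual=True` only makes sense when `cin == cout` and the spatial size is unchanged, since the input is added to the block output elementwise. A quick, illustrative shape check (not part of the commit):

```python
import torch
from wav2lip.models.conv import Conv2d, Conv2dTranspose

x = torch.randn(1, 16, 96, 96)
down = Conv2d(16, 32, kernel_size=3, stride=2, padding=1)                 # downsample: 96 -> 48
res = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)   # residual: needs cin == cout
up = Conv2dTranspose(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)  # back to 96

y = up(res(down(x)))
print(y.shape)  # torch.Size([1, 16, 96, 96])
```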
wav2lip/models/syncnet.py
ADDED
@@ -0,0 +1,66 @@
import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv2d

class SyncNet_color(nn.Module):
    def __init__(self):
        super(SyncNet_color, self).__init__()

        self.face_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),

            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

    def forward(self, audio_sequences, face_sequences):  # audio_sequences := (B, dim, T)
        face_embedding = self.face_encoder(face_sequences)
        audio_embedding = self.audio_encoder(audio_sequences)

        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
        face_embedding = face_embedding.view(face_embedding.size(0), -1)

        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        face_embedding = F.normalize(face_embedding, p=2, dim=1)

        return audio_embedding, face_embedding
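For reference, the expert lip-sync discriminator consumes a 16-step mel window and five lower-half face frames stacked on the channel axis, and emits an L2-normalized 512-d embedding for each stream. A small shape check, illustrative only (tensor contents are random):

```python
import torch
from wav2lip.models import SyncNet_color

net = SyncNet_color().eval()
mel = torch.randn(2, 1, 80, 16)      # (B, 1, num_mels, mel_step_size)
faces = torch.randn(2, 15, 48, 96)   # 5 RGB lower-half face frames, stacked on channels

with torch.no_grad():
    audio_emb, face_emb = net(mel, faces)
print(audio_emb.shape, face_emb.shape)   # torch.Size([2, 512]) torch.Size([2, 512])
```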
wav2lip/models/wav2lip.py
ADDED
@@ -0,0 +1,184 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d

class Wav2Lip(nn.Module):
    def __init__(self):
        super(Wav2Lip, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),  # 96,96

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 48,48
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 24,24
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 12,12
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 6,6
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),  # 6, 6

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),),  # 12, 12

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),),  # 24, 24

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),),  # 48, 48

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),])  # 96,96

        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())

    def forward(self, audio_sequences, face_sequences):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e

            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)

        else:
            outputs = x

        return outputs

class Wav2Lip_disc_qual(nn.Module):
    def __init__(self):
        super(Wav2Lip_disc_qual, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),  # 48,96

            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),  # 48,48
                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 24,24
                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),  # 12,12
                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 6,6
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
        self.label_noise = .0

    def get_lower_half(self, face_sequences):
        return face_sequences[:, :, face_sequences.size(2)//2:]

    def to_2d(self, face_sequences):
        B = face_sequences.size(0)
        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        return face_sequences

    def perceptual_forward(self, false_face_sequences):
        false_face_sequences = self.to_2d(false_face_sequences)
        false_face_sequences = self.get_lower_half(false_face_sequences)

        false_feats = false_face_sequences
        for f in self.face_encoder_blocks:
            false_feats = f(false_feats)

        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
                                                 torch.ones((len(false_feats), 1)).cuda())

        return false_pred_loss

    def forward(self, face_sequences):
        face_sequences = self.to_2d(face_sequences)
        face_sequences = self.get_lower_half(face_sequences)

        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)

        return self.binary_pred(x).view(len(x), -1)
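In the single-frame case used by `inference.py`, the generator takes one mel chunk of shape (B, 1, 80, 16) and a 6-channel 96x96 face tensor (masked lower half concatenated with a reference frame) and returns an RGB face in [0, 1]. An illustrative shape check, not part of the commit:

```python
import torch
from wav2lip.models import Wav2Lip

model = Wav2Lip().eval()
mel = torch.randn(4, 1, 80, 16)     # (B, 1, num_mels, mel_step_size)
faces = torch.randn(4, 6, 96, 96)   # masked lower-half + reference frame, 3+3 channels

with torch.no_grad():
    out = model(mel, faces)
print(out.shape)  # torch.Size([4, 3, 96, 96]) -- RGB faces in [0, 1]
```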