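# Streamlit front-end for SadTalker: upload a source face image plus driving
# audio (or synthesize speech from text), tune the animation parameters, and
# render a downloadable talking-face video via src.gradio_demo.SadTalker.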
import os
import sys
import streamlit as st
from src.gradio_demo import SadTalker
import tempfile
from PIL import Image

# Set page configuration
st.set_page_config(
    page_title="SadTalker - Talking Face Animation",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS styling
st.markdown("""
<style>
    .header {
        text-align: center;
        padding: 1.5rem 0;
        margin-bottom: 2rem;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .header h1 {
        margin-bottom: 0.5rem;
        font-size: 2.5rem;
    }
    .header p {
        margin-bottom: 0;
        font-size: 1.1rem;
    }
    .tab-content {
        padding: 1.5rem;
        background: #f8f9fa;
        border-radius: 10px;
        margin-bottom: 1.5rem;
    }
    .stVideo {
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .stImage {
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .settings-section {
        background: #ffffff;
        padding: 1.5rem;
        border-radius: 10px;
        margin-bottom: 1.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.05);
    }
    .warning-box {
        background-color: #fff3cd;
        color: #856404;
        padding: 0.75rem 1.25rem;
        border-radius: 0.25rem;
        margin-bottom: 1rem;
        border: 1px solid #ffeeba;
    }
    .download-btn {
        display: flex;
        justify-content: center;
        margin-top: 1rem;
    }
</style>
""", unsafe_allow_html=True)

# Initialize SadTalker once and cache the instance across Streamlit reruns
@st.cache_resource
def load_sadtalker():
    return SadTalker('checkpoints', 'src/config', lazy_load=True)

sad_talker = load_sadtalker()

# Check if running in webui
try:
    import webui
    in_webui = True
except ImportError:
    in_webui = False

# Header section
st.markdown("""
<div class="header">
    <h1>😭 SadTalker</h1>
    <p>Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)</p>
    <div style="display: flex; justify-content: center; gap: 1.5rem; margin-top: 0.5rem;">
        <a href="https://arxiv.org/abs/2211.12194" style="color: white; text-decoration: none; font-weight: 500;">📄 Arxiv</a>
        <a href="https://sadtalker.github.io" style="color: white; text-decoration: none; font-weight: 500;">🏠 Homepage</a>
        <a href="https://github.com/Winfredy/SadTalker" style="color: white; text-decoration: none; font-weight: 500;">💻 GitHub</a>
    </div>
</div>
""", unsafe_allow_html=True)

# Initialize session state
if 'generated_video' not in st.session_state:
    st.session_state.generated_video = None
if 'tts_audio' not in st.session_state:
    st.session_state.tts_audio = None
if 'source_image' not in st.session_state:
    st.session_state.source_image = None
if 'driven_audio' not in st.session_state:
    st.session_state.driven_audio = None

# Main columns layout
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    st.markdown("### Input Settings")

    # Source Image Upload
    with st.expander("🎨 Source Image", expanded=True):
        uploaded_image = st.file_uploader(
            "Upload a clear frontal face image",
            type=["jpg", "jpeg", "png"],
            key="source_image_upload"
        )
        if uploaded_image:
            st.session_state.source_image = uploaded_image
            image = Image.open(uploaded_image)
            st.image(image, caption="Source Image", use_container_width=True)
        elif st.session_state.source_image:
            image = Image.open(st.session_state.source_image)
            st.image(image, caption="Source Image (from session)", use_container_width=True)
        else:
            st.warning("Please upload a source image")

    # Audio Input
    with st.expander("🎵 Audio Input", expanded=True):
        input_method = st.radio(
            "Select input method:",
            ["Upload audio file", "Text-to-speech"],
            index=0,
            key="audio_input_method",
            horizontal=True
        )
        if input_method == "Upload audio file":
            audio_file = st.file_uploader(
                "Upload an audio file (WAV, MP3)",
                type=["wav", "mp3"],
                key="audio_file_upload"
            )
            if audio_file:
                st.session_state.driven_audio = audio_file
                st.audio(audio_file)
            elif st.session_state.driven_audio and isinstance(st.session_state.driven_audio, str):
                st.audio(st.session_state.driven_audio)
        else:
            if sys.platform != 'win32' and not in_webui:
                from src.utils.text2speech import TTSTalker
                tts_talker = TTSTalker()
                input_text = st.text_area(
                    "Enter text for speech synthesis:",
                    height=150,
                    placeholder="Type what you want the face to say...",
                    key="tts_input_text"
                )
                if st.button("Generate Speech", key="tts_generate_button"):
                    if input_text.strip():
                        with st.spinner("Generating audio from text..."):
                            try:
                                audio_path = tts_talker.test(input_text)
                                st.session_state.driven_audio = audio_path
                                st.session_state.tts_audio = audio_path
                                st.audio(audio_path)
                                st.success("Audio generated successfully!")
                            except Exception as e:
                                st.error(f"Error generating audio: {str(e)}")
                    else:
                        st.warning("Please enter some text first")
            else:
                st.markdown("""
                <div class="warning-box">
                    ⚠️ Text-to-speech is not available on Windows or in webui mode.
                    Please use audio upload instead.
                </div>
                """, unsafe_allow_html=True)

with col2:
    st.markdown("### Generation Settings")

    with st.container():
        st.markdown("""
        <div class="settings-section">
            <h4>⚙️ Animation Parameters</h4>
        """, unsafe_allow_html=True)

        # First row of settings
        col_a, col_b = st.columns(2)
        with col_a:
            preprocess_type = st.radio(
                "Preprocessing Method",
                ['crop', 'resize', 'full', 'extcrop', 'extfull'],
                index=0,
                key="preprocess_type",
                help="How to handle the input image before processing"
            )
            size_of_image = st.radio(
                "Face Model Resolution",
                [256, 512],
                index=0,
                key="size_of_image",
                horizontal=True,
                help="Higher resolution (512) may produce better quality but requires more resources"
            )
        with col_b:
            is_still_mode = st.checkbox(
                "Still Mode",
                value=False,
                key="is_still_mode",
                help="Produces fewer head movements (works best with 'full' preprocessing)"
            )
            enhancer = st.checkbox(
                "Use GFPGAN Enhancer",
                value=False,
                key="enhancer",
                help="Improves face quality using GFPGAN (may slow down processing)"
            )

        # Second row of settings
        pose_style = st.slider(
            "Pose Style",
            min_value=0,
            max_value=46,
            value=0,
            step=1,
            key="pose_style",
            help="Different head poses and expressions"
        )
        batch_size = st.slider(
            "Batch Size",
            min_value=1,
            max_value=10,
            value=2,
            step=1,
            key="batch_size",
            help="Number of frames processed at once (higher may be faster but uses more memory)"
        )
        st.markdown("</div>", unsafe_allow_html=True)

# Generate button
if st.button(
    "✨ Generate Talking Face Animation",
    type="primary",
    use_container_width=True,
    key="generate_button"
):
    if not st.session_state.source_image:
        st.error("Please upload a source image first")
    elif input_method == "Upload audio file" and not st.session_state.driven_audio:
        st.error("Please upload an audio file first")
    elif input_method == "Text-to-speech" and not st.session_state.driven_audio:
        st.error("Please generate audio from text first")
    else:
        with st.spinner("Generating talking face animation. This may take a few minutes..."):
            try:
                # Save the uploaded image to a temp file
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                    image = Image.open(st.session_state.source_image)
                    image.save(tmp_image.name)

                audio_path = None
                if input_method == "Upload audio file":
                    # Use getvalue() so the full upload is written regardless of the
                    # current read position, and keep the original file extension
                    suffix = os.path.splitext(st.session_state.driven_audio.name)[1] or ".wav"
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_audio:
                        tmp_audio.write(st.session_state.driven_audio.getvalue())
                        audio_path = tmp_audio.name
                else:
                    # Text-to-speech already produced a file on disk
                    audio_path = st.session_state.driven_audio

                # Generate video
                try:
                    # Ensure all paths are absolute
                    tmp_image_path = os.path.abspath(tmp_image.name)
                    audio_path = os.path.abspath(audio_path) if audio_path else None

                    # Convert all parameters to the expected types
                    video_path = sad_talker.test(
                        source_image=tmp_image_path,
                        driven_audio=audio_path,
                        preprocess_type=str(preprocess_type),
                        is_still_mode=bool(is_still_mode),
                        enhancer=bool(enhancer),
                        batch_size=int(batch_size),
                        size_of_image=int(size_of_image),
                        pose_style=int(pose_style)
                    )

                    # Verify the output
                    if not os.path.exists(video_path):
                        raise FileNotFoundError(f"Output video not created at {video_path}")

                    # Store the generated video in session state
                    st.session_state.generated_video = video_path
                    st.success("Generation complete! View your result below.")
                except Exception as e:
                    st.error(f"Generation failed: {str(e)}")
                    # Debug information
                    st.text("Parameters used:")
                    st.json({
                        "source_image": tmp_image_path,
                        "driven_audio": audio_path,
                        "preprocess_type": preprocess_type,
                        "is_still_mode": is_still_mode,
                        "enhancer": enhancer,
                        "batch_size": batch_size,
                        "size_of_image": size_of_image,
                        "pose_style": pose_style
                    })
                finally:
                    # Clean up temp files
                    os.unlink(tmp_image.name)
                    if audio_path and os.path.exists(audio_path) and input_method == "Upload audio file":
                        os.unlink(audio_path)
            except Exception as e:
                st.error(f"An error occurred during generation: {str(e)}")
                st.error("Please check your inputs and try again")

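# The generated video path is kept in st.session_state, so the result section
# below persists across Streamlit reruns until "Start New Generation" clears it.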
# Results section
if st.session_state.generated_video:
    st.markdown("---")
    st.markdown("### Generated Animation")

    # Display video and download options
    col_video, col_download = st.columns([3, 1])
    with col_video:
        st.video(st.session_state.generated_video)
    with col_download:
        # Download button
        with open(st.session_state.generated_video, "rb") as f:
            video_bytes = f.read()
        st.download_button(
            label="Download Video",
            data=video_bytes,
            file_name="sadtalker_animation.mp4",
            mime="video/mp4",
            use_container_width=True,
            key="download_button"
        )

        # Regenerate button
        if st.button(
            "🔄 Regenerate with Same Settings",
            use_container_width=True,
            key="regenerate_button"
        ):
            st.rerun()
        # New generation button
        if st.button(
            "🆕 Start New Generation",
            use_container_width=True,
            key="new_generation_button"
        ):
            st.session_state.generated_video = None
            st.session_state.tts_audio = None
            st.session_state.source_image = None
            st.session_state.driven_audio = None
            st.rerun()

# Footer
st.markdown("---")
st.markdown("""
<div style="text-align: center; color: #666; padding: 1.5rem 0; font-size: 0.9rem;">
    <p>SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation</p>
    <p>CVPR 2023 | <a href="https://github.com/Winfredy/SadTalker" target="_blank">GitHub Repository</a> | <a href="https://sadtalker.github.io" target="_blank">Project Page</a></p>
</div>
""", unsafe_allow_html=True)