# apptest / app.py
# Uploaded to Hugging Face via huggingface_hub (commit bed7409).
import os
import sys
import streamlit as st
from src.gradio_demo import SadTalker
import tempfile
from PIL import Image
# Page-level configuration — must be the first Streamlit call in the script.
_PAGE_CONFIG = {
    "page_title": "SadTalker - Talking Face Animation",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS styling injected once per page render.
# unsafe_allow_html is required for the raw <style> tag to take effect;
# the classes below are referenced by the HTML fragments rendered later
# (.header, .warning-box) and by Streamlit's own widget wrappers (.stVideo,
# .stImage).
st.markdown("""
<style>
.header {
text-align: center;
padding: 1.5rem 0;
margin-bottom: 2rem;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.header h1 {
margin-bottom: 0.5rem;
font-size: 2.5rem;
}
.header p {
margin-bottom: 0;
font-size: 1.1rem;
}
.tab-content {
padding: 1.5rem;
background: #f8f9fa;
border-radius: 10px;
margin-bottom: 1.5rem;
}
.stVideo {
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.stImage {
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.settings-section {
background: #ffffff;
padding: 1.5rem;
border-radius: 10px;
margin-bottom: 1.5rem;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.warning-box {
background-color: #fff3cd;
color: #856404;
padding: 0.75rem 1.25rem;
border-radius: 0.25rem;
margin-bottom: 1rem;
border: 1px solid #ffeeba;
}
.download-btn {
display: flex;
justify-content: center;
margin-top: 1rem;
}
</style>
""", unsafe_allow_html=True)
# Build the SadTalker pipeline once per server process; st.cache_resource
# memoizes the instance across Streamlit reruns and user sessions.
@st.cache_resource
def load_sadtalker():
    """Return a lazily-initialized SadTalker model (cached by Streamlit)."""
    checkpoint_dir = 'checkpoints'
    config_dir = 'src/config'
    return SadTalker(checkpoint_dir, config_dir, lazy_load=True)


sad_talker = load_sadtalker()
# Detect whether we are embedded inside stable-diffusion-webui: the optional
# `webui` module only exists in that environment. Catch ImportError
# specifically — a bare `except:` would also hide real errors raised while
# webui itself initializes.
try:
    import webui  # noqa: F401  (imported only to probe for its presence)
    in_webui = True
except ImportError:
    in_webui = False
# Header section: static HTML banner styled by the .header CSS class
# injected above. unsafe_allow_html is required for raw HTML rendering.
st.markdown("""
<div class="header">
<h1>😭 SadTalker</h1>
<p>Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)</p>
<div style="display: flex; justify-content: center; gap: 1.5rem; margin-top: 0.5rem;">
<a href="https://arxiv.org/abs/2211.12194" style="color: white; text-decoration: none; font-weight: 500;">πŸ“„ Arxiv</a>
<a href="https://sadtalker.github.io" style="color: white; text-decoration: none; font-weight: 500;">🌐 Homepage</a>
<a href="https://github.com/Winfredy/SadTalker" style="color: white; text-decoration: none; font-weight: 500;">πŸ’» GitHub</a>
</div>
</div>
""", unsafe_allow_html=True)
# Seed the session-state slots used across reruns so later reads never
# raise. (st.session_state supports both attribute and item access.)
_STATE_KEYS = ("generated_video", "tts_audio", "source_image", "driven_audio")
for _key in _STATE_KEYS:
    if _key not in st.session_state:
        st.session_state[_key] = None
# Main columns layout: inputs on the left (col1), generation settings on
# the right (col2).
col1, col2 = st.columns([1, 1], gap="large")
with col1:
    st.markdown("### Input Settings")
    # --- Source Image Upload -------------------------------------------
    # A fresh upload wins; otherwise fall back to the image kept in
    # session state from an earlier rerun.
    with st.expander("🎨 Source Image", expanded=True):
        uploaded_image = st.file_uploader(
            "Upload a clear frontal face image",
            type=["jpg", "jpeg", "png"],
            key="source_image_upload"
        )
        if uploaded_image:
            st.session_state.source_image = uploaded_image
            image = Image.open(uploaded_image)
            st.image(image, caption="Source Image", use_container_width=True)
        elif st.session_state.source_image:
            image = Image.open(st.session_state.source_image)
            st.image(image, caption="Source Image (from session)", use_container_width=True)
        else:
            st.warning("Please upload a source image")
    # --- Audio Input ----------------------------------------------------
    # Either a direct file upload or server-side text-to-speech.
    with st.expander("🎡 Audio Input", expanded=True):
        input_method = st.radio(
            "Select input method:",
            ["Upload audio file", "Text-to-speech"],
            index=0,
            key="audio_input_method",
            horizontal=True
        )
        if input_method == "Upload audio file":
            audio_file = st.file_uploader(
                "Upload an audio file (WAV, MP3)",
                type=["wav", "mp3"],
                key="audio_file_upload"
            )
            if audio_file:
                st.session_state.driven_audio = audio_file
                st.audio(audio_file)
            # The str check distinguishes a TTS-generated file path from an
            # UploadedFile object left over from a previous rerun.
            elif st.session_state.driven_audio and isinstance(st.session_state.driven_audio, str):
                st.audio(st.session_state.driven_audio)
        else:
            # TTS path: only offered off-Windows and outside webui mode.
            if sys.platform != 'win32' and not in_webui:
                from src.utils.text2speech import TTSTalker
                # NOTE(review): TTSTalker is re-instantiated on every rerun;
                # if construction is expensive, consider st.cache_resource.
                tts_talker = TTSTalker()
                input_text = st.text_area(
                    "Enter text for speech synthesis:",
                    height=150,
                    placeholder="Type what you want the face to say...",
                    key="tts_input_text"
                )
                if st.button("Generate Speech", key="tts_generate_button"):
                    if input_text.strip():
                        with st.spinner("Generating audio from text..."):
                            try:
                                # tts_talker.test presumably returns a path to
                                # the synthesized audio file — confirm in
                                # src/utils/text2speech.
                                audio_path = tts_talker.test(input_text)
                                st.session_state.driven_audio = audio_path
                                st.session_state.tts_audio = audio_path
                                st.audio(audio_path)
                                st.success("Audio generated successfully!")
                            except Exception as e:
                                st.error(f"Error generating audio: {str(e)}")
                    else:
                        st.warning("Please enter some text first")
            else:
                st.markdown("""
<div class="warning-box">
⚠️ Text-to-speech is not available on Windows or in webui mode.
Please use audio upload instead.
</div>
""", unsafe_allow_html=True)
with col2:
    st.markdown("### Generation Settings")
    with st.container():
        # Opening tag of the styled settings card; closed by the matching
        # st.markdown("</div>") below.
        st.markdown("""
<div class="settings-section">
<h4>βš™οΈ Animation Parameters</h4>
""", unsafe_allow_html=True)
        # First row of settings: preprocessing/resolution on the left,
        # toggles on the right. All widget values are consumed by the
        # generate handler further down via their module-level names.
        col_a, col_b = st.columns(2)
        with col_a:
            preprocess_type = st.radio(
                "Preprocessing Method",
                ['crop', 'resize', 'full', 'extcrop', 'extfull'],
                index=0,
                key="preprocess_type",
                help="How to handle the input image before processing"
            )
            size_of_image = st.radio(
                "Face Model Resolution",
                [256, 512],
                index=0,
                key="size_of_image",
                horizontal=True,
                help="Higher resolution (512) may produce better quality but requires more resources"
            )
        with col_b:
            is_still_mode = st.checkbox(
                "Still Mode",
                value=False,
                key="is_still_mode",
                help="Produces fewer head movements (works best with 'full' preprocessing)"
            )
            enhancer = st.checkbox(
                "Use GFPGAN Enhancer",
                value=False,
                key="enhancer",
                help="Improves face quality using GFPGAN (may slow down processing)"
            )
        # Second row of settings: full-width sliders.
        pose_style = st.slider(
            "Pose Style",
            min_value=0,
            max_value=46,
            value=0,
            step=1,
            key="pose_style",
            help="Different head poses and expressions"
        )
        batch_size = st.slider(
            "Batch Size",
            min_value=1,
            max_value=10,
            value=2,
            step=1,
            key="batch_size",
            help="Number of frames processed at once (higher may be faster but uses more memory)"
        )
        st.markdown("</div>", unsafe_allow_html=True)
# Generate button: validate inputs, stage them as real files on disk, run
# SadTalker, and surface either the resulting video or the failure.
#
# Fixes over the original flow:
#  * the post-except re-assignment of st.session_state.generated_video ran
#    even when generation failed, raising NameError on an undefined
#    video_path and masking the real error;
#  * temp files leaked whenever generation failed (cleanup now in finally);
#  * uploaded audio was always staged with a ".wav" suffix even for MP3;
#  * the upload buffer is rewound before reading (it may have been consumed
#    by an earlier rerun).
if st.button(
    "✨ Generate Talking Face Animation",
    type="primary",
    use_container_width=True,
    key="generate_button"
):
    if not st.session_state.source_image:
        st.error("Please upload a source image first")
    elif input_method == "Upload audio file" and not st.session_state.driven_audio:
        st.error("Please upload an audio file first")
    elif input_method == "Text-to-speech" and not st.session_state.driven_audio:
        st.error("Please generate audio from text first")
    else:
        with st.spinner("Generating talking face animation. This may take a few minutes..."):
            tmp_image_path = None
            tmp_audio_path = None  # set only for uploaded audio (we own its cleanup)
            try:
                # Stage the source image as a PNG on disk.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                    image = Image.open(st.session_state.source_image)
                    image.save(tmp_image.name)
                tmp_image_path = os.path.abspath(tmp_image.name)

                if input_method == "Upload audio file":
                    # Keep the uploaded extension so downstream decoders
                    # pick the right container format.
                    suffix = os.path.splitext(st.session_state.driven_audio.name)[1] or ".wav"
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_audio:
                        st.session_state.driven_audio.seek(0)
                        tmp_audio.write(st.session_state.driven_audio.read())
                    tmp_audio_path = os.path.abspath(tmp_audio.name)
                    audio_path = tmp_audio_path
                else:
                    # TTS branch already produced a file path.
                    audio_path = os.path.abspath(st.session_state.driven_audio)

                video_path = sad_talker.test(
                    source_image=tmp_image_path,
                    driven_audio=audio_path,
                    preprocess_type=str(preprocess_type),
                    is_still_mode=bool(is_still_mode),
                    enhancer=bool(enhancer),
                    batch_size=int(batch_size),
                    size_of_image=int(size_of_image),
                    pose_style=int(pose_style),
                )

                if not video_path or not os.path.exists(video_path):
                    raise FileNotFoundError(f"Output video not created at {video_path}")

                st.session_state.generated_video = video_path
                st.success("Generation complete! View your result below.")
            except Exception as e:
                st.error(f"An error occurred during generation: {str(e)}")
                st.error("Please check your inputs and try again")
                # Debug aid: show exactly what was passed to the model.
                st.text("Parameters used:")
                st.json({
                    "source_image": tmp_image_path,
                    "driven_audio": tmp_audio_path or st.session_state.driven_audio,
                    "preprocess_type": preprocess_type,
                    "is_still_mode": is_still_mode,
                    "enhancer": enhancer,
                    "batch_size": batch_size,
                    "size_of_image": size_of_image,
                    "pose_style": pose_style
                })
            finally:
                # Always remove staged temp files, success or failure.
                for _path in (tmp_image_path, tmp_audio_path):
                    if _path and os.path.exists(_path):
                        os.unlink(_path)
# Results section: shown once a video path is stored in session state.
if st.session_state.generated_video:
    # st.rerun() replaced st.experimental_rerun() (deprecated in 1.27,
    # later removed); resolve whichever this runtime provides so the app
    # works on both old and new Streamlit versions.
    _rerun = st.rerun if hasattr(st, "rerun") else st.experimental_rerun

    st.markdown("---")
    st.markdown("### Generated Animation")
    # Video preview on the left, action buttons on the right.
    col_video, col_download = st.columns([3, 1])
    with col_video:
        st.video(st.session_state.generated_video)
    with col_download:
        # Download button serves the raw MP4 bytes.
        with open(st.session_state.generated_video, "rb") as f:
            video_bytes = f.read()
        st.download_button(
            label="Download Video",
            data=video_bytes,
            file_name="sadtalker_animation.mp4",
            mime="video/mp4",
            use_container_width=True,
            key="download_button"
        )
        # NOTE(review): a rerun alone does not re-trigger generation (the
        # generate button resets each run) — confirm the intended UX.
        if st.button(
            "πŸ”„ Regenerate with Same Settings",
            use_container_width=True,
            key="regenerate_button"
        ):
            _rerun()
        # Clear all inputs/outputs and start over.
        if st.button(
            "πŸ†• Start New Generation",
            use_container_width=True,
            key="new_generation_button"
        ):
            st.session_state.generated_video = None
            st.session_state.tts_audio = None
            st.session_state.source_image = None
            st.session_state.driven_audio = None
            _rerun()
# Footer: attribution line plus links to the paper's repo and project page.
st.markdown("---")
_FOOTER_HTML = """
<div style="text-align: center; color: #666; padding: 1.5rem 0; font-size: 0.9rem;">
<p>SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation</p>
<p>CVPR 2023 | <a href="https://github.com/Winfredy/SadTalker" target="_blank">GitHub Repository</a> | <a href="https://sadtalker.github.io" target="_blank">Project Page</a></p>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)