"""Streamlit front-end for SadTalker talking-face animation.

Collects a source image and a driving audio clip (uploaded, or synthesised
with text-to-speech), exposes SadTalker's generation settings, runs
``SadTalker.test`` and shows the resulting video with a download button.
"""
import contextlib
import os
import sys
import tempfile

import streamlit as st
from PIL import Image

from src.gradio_demo import SadTalker

# PIL modes the PNG writer accepts; anything else (e.g. CMYK JPEGs) is
# converted to RGB before the source image is written to a temporary PNG.
_PNG_WRITABLE_MODES = ("1", "L", "LA", "I", "P", "RGB", "RGBA")

# Session-state slots used by this page, all initialised to None.
_SESSION_KEYS = ("generated_video", "tts_audio", "source_image", "driven_audio")

# Set page configuration
st.set_page_config(
    page_title="SadTalker - Talking Face Animation",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS styling
# NOTE(review): the stylesheet is effectively empty (a single space);
# restore the intended CSS or drop this call.
st.markdown(""" """, unsafe_allow_html=True)


def _rerun():
    """Rerun the script via ``st.rerun``, falling back to the legacy API.

    ``st.experimental_rerun`` was deprecated in favour of ``st.rerun`` and
    has since been removed, so calling it directly breaks on current
    Streamlit releases.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


# Initialize SadTalker with caching
@st.cache_resource
def load_sadtalker():
    """Build the SadTalker pipeline once and reuse it across reruns."""
    return SadTalker('checkpoints', 'src/config', lazy_load=True)


@st.cache_resource
def load_tts_talker():
    """Build the text-to-speech engine once instead of on every rerun.

    Imported lazily because TTS is only offered off Windows and outside webui.
    """
    from src.utils.text2speech import TTSTalker
    return TTSTalker()


def _make_temp_path(suffix):
    """Create an empty, closed temporary file and return its absolute path.

    The handle is closed before returning so other code (PIL, SadTalker) can
    reopen the path, which Windows does not allow while it is still open.
    The caller is responsible for deleting the file.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        return os.path.abspath(tmp.name)


def _audio_suffix(audio_file):
    """Return the upload's extension ('.wav' or '.mp3'), defaulting to '.wav'."""
    name = getattr(audio_file, "name", "") or ""
    suffix = os.path.splitext(name)[1].lower()
    return suffix if suffix in (".wav", ".mp3") else ".wav"


def _generate_video(source_image, driven_audio, settings):
    """Run SadTalker on the given inputs and return the output video path.

    Args:
        source_image: Uploaded image file object (anything ``Image.open`` reads).
        driven_audio: Either a path string (text-to-speech output) or an
            uploaded audio file object exposing ``getvalue()``.
        settings: Keyword arguments forwarded to ``sad_talker.test``
            (preprocess_type, is_still_mode, enhancer, batch_size,
            size_of_image, pose_style).

    Returns:
        The generated video path, or None when generation failed (the error
        has already been reported in the UI).

    Temporary files created here are always removed before returning.
    """
    temp_paths = []
    tmp_image_path = None
    audio_path = None
    try:
        # Save the source image to a temporary PNG.
        tmp_image_path = _make_temp_path(".png")
        temp_paths.append(tmp_image_path)
        image = Image.open(source_image)
        if image.mode not in _PNG_WRITABLE_MODES:
            image = image.convert("RGB")
        image.save(tmp_image_path)

        # Choose by what is actually stored rather than by the radio option:
        # the session may hold a TTS path while "Upload audio file" is
        # selected, or an upload while "Text-to-speech" is selected.
        if isinstance(driven_audio, str):
            audio_path = os.path.abspath(driven_audio)
        else:
            audio_path = _make_temp_path(_audio_suffix(driven_audio))
            temp_paths.append(audio_path)
            # getvalue() returns the full payload regardless of how far earlier
            # reads (st.audio, a previous generation) advanced the stream.
            with open(audio_path, "wb") as audio_out:
                audio_out.write(driven_audio.getvalue())

        try:
            video_path = sad_talker.test(
                source_image=tmp_image_path,
                driven_audio=audio_path,
                **settings,
            )
            # Verify the output
            if not video_path or not os.path.exists(video_path):
                raise FileNotFoundError(f"Output video not created at {video_path}")
        except Exception as e:
            st.error(f"Generation failed: {e}")
            # Debug information
            st.text("Parameters used:")
            st.json({
                "source_image": tmp_image_path,
                "driven_audio": audio_path,
                **settings,
            })
            return None
        return video_path
    except Exception as e:
        st.error(f"An error occurred during generation: {e}")
        st.error("Please check your inputs and try again")
        return None
    finally:
        # Clean up temp files (best effort; a leftover temp file is harmless).
        for path in temp_paths:
            with contextlib.suppress(OSError):
                os.unlink(path)


sad_talker = load_sadtalker()

# Check if running in webui (only whether the module is importable matters)
try:
    import webui  # noqa: F401
    in_webui = True
except ImportError:
    in_webui = False

# Header section
st.markdown("""

😭 SadTalker

Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

📄 Arxiv 🌐 Homepage 💻 GitHub
""", unsafe_allow_html=True)

# Initialize session state
for _key in _SESSION_KEYS:
    if _key not in st.session_state:
        st.session_state[_key] = None

# Main columns layout
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    st.markdown("### Input Settings")

    # Source Image Upload
    with st.expander("🎨 Source Image", expanded=True):
        uploaded_image = st.file_uploader(
            "Upload a clear frontal face image",
            type=["jpg", "jpeg", "png"],
            key="source_image_upload",
        )
        if uploaded_image:
            st.session_state.source_image = uploaded_image
            image = Image.open(uploaded_image)
            st.image(image, caption="Source Image", use_container_width=True)
        elif st.session_state.source_image:
            image = Image.open(st.session_state.source_image)
            st.image(image, caption="Source Image (from session)", use_container_width=True)
        else:
            st.warning("Please upload a source image")

    # Audio Input
    with st.expander("🎵 Audio Input", expanded=True):
        input_method = st.radio(
            "Select input method:",
            ["Upload audio file", "Text-to-speech"],
            index=0,
            key="audio_input_method",
            horizontal=True,
        )

        if input_method == "Upload audio file":
            audio_file = st.file_uploader(
                "Upload an audio file (WAV, MP3)",
                type=["wav", "mp3"],
                key="audio_file_upload",
            )
            if audio_file:
                st.session_state.driven_audio = audio_file
                st.audio(audio_file)
            elif isinstance(st.session_state.driven_audio, str) and st.session_state.driven_audio:
                # Previously synthesised TTS audio is still usable here.
                st.audio(st.session_state.driven_audio)
        elif sys.platform != 'win32' and not in_webui:
            tts_talker = load_tts_talker()
            input_text = st.text_area(
                "Enter text for speech synthesis:",
                height=150,
                placeholder="Type what you want the face to say...",
                key="tts_input_text",
            )
            if st.button("Generate Speech", key="tts_generate_button"):
                if input_text.strip():
                    with st.spinner("Generating audio from text..."):
                        try:
                            audio_path = tts_talker.test(input_text)
                            st.session_state.driven_audio = audio_path
                            st.session_state.tts_audio = audio_path
                            st.audio(audio_path)
                            st.success("Audio generated successfully!")
                        except Exception as e:
                            st.error(f"Error generating audio: {e}")
                else:
                    st.warning("Please enter some text first")
        else:
            st.markdown("""
⚠️ Text-to-speech is not available on Windows or in webui mode. Please use audio upload instead.
""", unsafe_allow_html=True)

with col2:
    st.markdown("### Generation Settings")

    with st.container():
        st.markdown("""

⚙️ Animation Parameters

""", unsafe_allow_html=True)

        # First row of settings
        col_a, col_b = st.columns(2)

        with col_a:
            preprocess_type = st.radio(
                "Preprocessing Method",
                ['crop', 'resize', 'full', 'extcrop', 'extfull'],
                index=0,
                key="preprocess_type",
                help="How to handle the input image before processing",
            )
            size_of_image = st.radio(
                "Face Model Resolution",
                [256, 512],
                index=0,
                key="size_of_image",
                horizontal=True,
                help="Higher resolution (512) may produce better quality but requires more resources",
            )

        with col_b:
            is_still_mode = st.checkbox(
                "Still Mode",
                value=False,
                key="is_still_mode",
                help="Produces fewer head movements (works best with 'full' preprocessing)",
            )
            enhancer = st.checkbox(
                "Use GFPGAN Enhancer",
                value=False,
                key="enhancer",
                help="Improves face quality using GFPGAN (may slow down processing)",
            )

        # Second row of settings
        pose_style = st.slider(
            "Pose Style",
            min_value=0,
            max_value=46,
            value=0,
            step=1,
            key="pose_style",
            help="Different head poses and expressions",
        )
        batch_size = st.slider(
            "Batch Size",
            min_value=1,
            max_value=10,
            value=2,
            step=1,
            key="batch_size",
            help="Number of frames processed at once (higher may be faster but uses more memory)",
        )

        # NOTE(review): this markup looks stripped (probably a closing HTML
        # tag for the header above); as written it renders nothing.
        st.markdown("\n", unsafe_allow_html=True)

    # The results section's "Regenerate" button sets this flag and reruns so
    # generation runs again with the current settings. Read and clear it
    # unconditionally so it fires exactly once.
    regenerate_requested = st.session_state.get("regenerate_requested", False)
    st.session_state.regenerate_requested = False

    # Generate button
    generate_clicked = st.button(
        "✨ Generate Talking Face Animation",
        type="primary",
        use_container_width=True,
        key="generate_button",
    )
    if generate_clicked or regenerate_requested:
        driven_audio = st.session_state.driven_audio
        if not st.session_state.source_image:
            st.error("Please upload a source image first")
        elif input_method == "Upload audio file" and not driven_audio:
            st.error("Please upload an audio file first")
        elif input_method == "Text-to-speech" and not driven_audio:
            st.error("Please generate audio from text first")
        else:
            # Convert all parameters to the types SadTalker expects.
            generation_settings = {
                "preprocess_type": str(preprocess_type),
                "is_still_mode": bool(is_still_mode),
                "enhancer": bool(enhancer),
                "batch_size": int(batch_size),
                "size_of_image": int(size_of_image),
                "pose_style": int(pose_style),
            }
            with st.spinner("Generating talking face animation. This may take a few minutes..."):
                video_path = _generate_video(
                    st.session_state.source_image,
                    driven_audio,
                    generation_settings,
                )
            # Only record and announce success when a video was really produced.
            if video_path:
                st.session_state.generated_video = video_path
                st.success("Generation complete! View your result below.")

# Results section
generated_video = st.session_state.generated_video
if generated_video and not os.path.exists(generated_video):
    st.warning("The generated video file could not be found. Please generate it again.")
    st.session_state.generated_video = None
elif generated_video:
    st.markdown("---")
    st.markdown("### Generated Animation")

    # Display video and download options
    col_video, col_download = st.columns([3, 1])

    with col_video:
        st.video(generated_video)

    with col_download:
        # Download button
        with open(generated_video, "rb") as f:
            video_bytes = f.read()
        st.download_button(
            label="Download Video",
            data=video_bytes,
            file_name="sadtalker_animation.mp4",
            mime="video/mp4",
            use_container_width=True,
            key="download_button",
        )

        # Regenerate button: request a generation on the next run.
        if st.button(
            "🔄 Regenerate with Same Settings",
            use_container_width=True,
            key="regenerate_button",
        ):
            st.session_state.regenerate_requested = True
            _rerun()

        # New generation button
        if st.button(
            "🆕 Start New Generation",
            use_container_width=True,
            key="new_generation_button",
        ):
            for _key in _SESSION_KEYS:
                st.session_state[_key] = None
            _rerun()

# Footer
st.markdown("---")
st.markdown("""

SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation

CVPR 2023 | GitHub Repository | Project Page

""", unsafe_allow_html=True)