Spaces: Running on L40S
| import os | |
# Install the prebuilt pytorch3d wheel from the index below. The index path
# names py310 / cu121 / pyt221, so it presumably has to match this Space's
# Python / CUDA / PyTorch build -- keep it in sync with the runtime.
# `sys.executable -m pip` installs into the interpreter running this app, and
# passing an argument list avoids going through a shell.
import subprocess
import sys

_PYTORCH3D_WHEEL_INDEX = (
    "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/"
    "py310_cu121_pyt221/download.html"
)
_pip_status = subprocess.run(
    [sys.executable, "-m", "pip", "install", "pytorch3d", "-f", _PYTORCH3D_WHEEL_INDEX],
    check=False,
).returncode
if _pip_status != 0:
    # Best-effort, like the original os.system call: warn instead of aborting.
    # Code that really needs pytorch3d will then fail when it imports it.
    print(f"WARNING: installing pytorch3d failed (pip exit code {_pip_status})")
| import shutil | |
| import math | |
| from huggingface_hub import snapshot_download | |
# Fetch every checkpoint the app needs into ./pretrained_models before the
# model-loading code below runs.
base_dir = "pretrained_models"
os.makedirs(base_dir, exist_ok=True)

snapshot_download(
    repo_id="multimodalart/diffposetalk",
    local_dir=f"{base_dir}/diffposetalk",
)


def _move_dir_contents(src_dir, dst_dir):
    """Move every entry of ``src_dir`` into ``dst_dir``, replacing existing entries.

    ``shutil.move`` moves a source directory *inside* a destination directory
    that already exists, and fails if that nested path exists as well. Stale
    destinations are therefore removed first, so re-running the download with
    a persisted ``pretrained_models`` folder keeps the layout flat.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for item in os.listdir(src_dir):
        src = os.path.join(src_dir, item)
        dst = os.path.join(dst_dir, item)
        if os.path.isdir(dst) and not os.path.islink(dst):
            shutil.rmtree(dst)
        elif os.path.lexists(dst):
            os.remove(dst)
        shutil.move(src, dst)


# FLAME, mediapipe and smirk live under extra_models/<name>/ in the SkyReels-A1
# repo: download each into a temp folder, then flatten it to pretrained_models/<name>.
for model in ["FLAME", "mediapipe", "smirk"]:
    temp_dir = f"{base_dir}/{model}_temp"
    try:
        snapshot_download(
            repo_id="Skywork/SkyReels-A1",
            local_dir=temp_dir,
            allow_patterns=f"extra_models/{model}/**",
        )
        _move_dir_contents(f"{temp_dir}/extra_models/{model}", f"{base_dir}/{model}")
    finally:
        # Remove the temp folder even when a download fails part-way.
        shutil.rmtree(temp_dir, ignore_errors=True)

# Full SkyReels-A1 checkpoint. The loading code below reads transformer/, vae/,
# pose_guider/ and siglip-so400m-patch14-384/ from this folder.
# NOTE(review): this also re-downloads extra_models/; consider ignore_patterns
# if nothing reads SkyReels-A1-5B/extra_models -- verify first.
snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir=f"{base_dir}/SkyReels-A1-5B",
)
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import gc | |
| import tempfile | |
| import moviepy.editor as mp | |
| from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
| from diffusers.utils import export_to_video, load_image | |
| # Import required modules from SkyReels | |
| from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel | |
| from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline | |
| from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor | |
| from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor | |
| from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d | |
| from diffusers.models import AutoencoderKLCogVideoX | |
| from transformers import SiglipImageProcessor, SiglipVisionModel | |
| from diffposetalk.diffposetalk import DiffPoseTalk | |
| # Helper functions from the original script | |
def parse_video(driving_frames, max_frame_num, fps=25):
    """Resample ``driving_frames`` (recorded at ``fps``) to 12 fps and fit a fixed length.

    Let ``s`` be the 12 fps sample truncated to its first
    ``max(max_frame_num - 1, 1)`` frames. The result is
    ``[s[0], s[0], *s[1:-1]] + [s[-1]] * (max_frame_num - len(s) - 1)``.
    That is ``max_frame_num - 1`` frames whenever ``s`` has at least two.
    Raises IndexError when no frame can be sampled (empty input).
    """
    total = len(driving_frames)
    # Timestamps every 1/12 s, mapped back to source indices (truncated).
    picks = (np.arange(0, total / fps, 1 / 12) * fps).astype(np.int32)
    picks = picks[picks < total]

    # At least one frame is always kept, even if max_frame_num <= 1.
    limit = max(max_frame_num - 1, 1)
    sampled = [driving_frames[k] for k in picks[:limit]]

    fill_count = max_frame_num - len(sampled) - 1
    head = [sampled[0], sampled[0]]
    middle = sampled[1:len(sampled) - 1]
    tail = [sampled[-1]] * fill_count
    return head + middle + tail
def write_mp4(video_path, samples, fps=12):
    """Encode the frame list ``samples`` as an H.264 MP4 at ``video_path``.

    Uses moviepy's ``ImageSequenceClip`` and passes ``-crf 18 -preset slow``
    through to ffmpeg.
    """
    encoder_args = ["-crf", "18", "-preset", "slow"]
    sequence = mp.ImageSequenceClip(samples, fps=fps)
    sequence.write_videofile(
        video_path,
        codec="libx264",
        audio_codec="aac",
        ffmpeg_params=encoder_args,
    )
def save_video_with_audio(video_path, audio_path, save_path, fps=12):
    """Mux the audio from ``audio_path`` onto the video at ``video_path``.

    Audio longer than the video is trimmed to the video's duration. Shorter
    audio is used as-is. The output is written with libx264 video and AAC audio.

    Args:
        video_path: Path of the input video.
        audio_path: Path of the audio file to attach.
        save_path: Output path for the muxed video.
        fps: Frame rate of the written file. Defaults to 12, the rate the
            generated videos are exported at in this app.

    Returns:
        ``save_path``.
    """
    video_clip = mp.VideoFileClip(video_path)
    try:
        # Keep a handle on the originally opened clip so it is the one closed,
        # even when a trimmed subclip is used for muxing.
        audio_clip = mp.AudioFileClip(audio_path)
        try:
            track = audio_clip
            if track.duration > video_clip.duration:
                track = track.subclip(0, video_clip.duration)
            video_clip.set_audio(track).write_videofile(
                save_path, fps=fps, codec="libx264", audio_codec="aac"
            )
        finally:
            audio_clip.close()
    finally:
        # Release ffmpeg readers even if encoding raised.
        video_clip.close()
    return save_path
def pad_video(driving_frames, fps=25, target_fps=12, chunk_size=48):
    """Resample frames from ``fps`` to ``target_fps`` and pad to a multiple of ``chunk_size``.

    Each output frame is ``driving_frames[int(t * fps)]`` for timestamps ``t``
    spaced ``1 / target_fps`` apart. The last sampled frame is then repeated
    until the length is a multiple of ``chunk_size``, the number of driving
    frames consumed per generation chunk.

    Args:
        driving_frames: Indexable sequence of frames sampled at ``fps``.
        fps: Frame rate of ``driving_frames``.
        target_fps: Frame rate to resample to (default 12).
        chunk_size: Output length is rounded up to a multiple of this (default 48).

    Returns:
        ``(frames, pad_length)``, where ``pad_length`` is the number of
        repeated frames appended at the end. Empty input yields ``([], 0)``.
    """
    video_length = len(driving_frames)
    if video_length == 0:
        # Nothing to resample or to repeat as padding.
        return [], 0
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / target_fps)
    frame_indices = (target_times * fps).astype(np.int32)
    frame_indices = frame_indices[frame_indices < video_length]
    new_driving_frames = [driving_frames[idx] for idx in frame_indices]
    # Distance up to the next multiple of chunk_size (0 when already aligned).
    pad_length = -len(new_driving_frames) % chunk_size
    new_driving_frames.extend([new_driving_frames[-1]] * pad_length)
    return new_driving_frames, pad_length
| # Global parameters | |
model_name = "pretrained_models/SkyReels-A1-5B/"
siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
weight_dtype = torch.bfloat16
max_frame_num = 49  # frames per generated clip: 1 reference frame + 48 driving frames
sample_size = [480, 720]  # (height, width) of the generated video
device = "cuda"

# Preload all models once at import time so every request reuses them.
print("Loading models...")

# Face landmark extraction and the face-animation processor (SMIRK checkpoint).
lmk_extractor = LMKExtractor()
processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
# RetinaFace (resnet50) detector producing 512x512 aligned face crops.
face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device=device,
)

# SigLIP vision encoder and its image processor, handed to the pipeline below.
siglip = SiglipVisionModel.from_pretrained(siglip_name)
siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)

# Audio -> motion-coefficient model.
diffposetalk = DiffPoseTalk()

# SkyReels-A1 components, cast to the shared weight dtype.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_name,
    subfolder="transformer",
).to(weight_dtype)
vae = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="vae",
).to(weight_dtype)
lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="pose_guider",
).to(weight_dtype)

pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
    model_name,
    transformer=transformer,
    vae=vae,
    lmk_encoder=lmk_encoder,
    image_encoder=siglip,
    feature_extractor=siglip_normalize,
    torch_dtype=weight_dtype,
)
pipe.to(device)
# torch.compile compiles lazily, so expect the first generation to be slower.
pipe.transformer = torch.compile(pipe.transformer)
pipe.vae.enable_tiling()
pipe.vae = torch.compile(pipe.vae)
print("Models loaded successfully!")
def _new_temp_path(suffix=".mp4", directory=None):
    """Create an empty, uniquely named file and return its path (the caller owns it)."""
    with tempfile.NamedTemporaryFile(suffix=suffix, dir=directory, delete=False) as handle:
        return handle.name


def _remove_quietly(path):
    """Best-effort removal of a file; errors such as 'already gone' are ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def process_image_audio(image_path, audio_path, guidance_scale=3.0, steps=10, progress=gr.Progress()):
    """Animate the portrait at ``image_path`` so it follows the audio at ``audio_path``.

    Steps:
    - Crop the face from the portrait and analyse it.
    - Predict motion coefficients from the audio with DiffPoseTalk.
    - Turn those into per-frame motion images with
      ``processor.preprocess_lmk3d_from_coef``.
    - Run the SkyReels-A1 pipeline over them in 48-frame chunks, each prefixed
      with the reference landmark frame.
    - Mux the result with the audio, and also build a side-by-side
      (input | output) comparison video.

    Args:
        image_path: File path of the portrait image (from ``gr.Image``).
        audio_path: File path of the driving audio (from ``gr.Audio``).
        guidance_scale: Forwarded to the pipeline.
        steps: Number of inference steps, forwarded to the pipeline.
        progress: Gradio progress tracker.

    Returns:
        ``(result_path, comparison_path)``: the generated video with audio and
        the side-by-side comparison with audio.

    Raises:
        gr.Error: if an input is missing, the audio yields no motion frames,
            or no face is detected in the portrait.
    """
    if not image_path:
        raise gr.Error("Please upload a portrait image.")
    if not audio_path:
        raise gr.Error("Please upload a driving audio file.")
    # Slider values may arrive as floats; cast defensively so the step count is an int.
    steps = int(steps)

    progress(0.1, desc="Processing inputs...")
    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Unique per-request paths so concurrent users never overwrite each other's
    # files (the comparison videos previously went to fixed file names).
    temp_video_path = _new_temp_path()                            # silent result (intermediate)
    final_output_path = _new_temp_path()                          # result + audio (returned)
    comparison_path = _new_temp_path(directory=output_dir)        # silent comparison (intermediate)
    comparison_with_audio = _new_temp_path(directory=output_dir)  # comparison + audio (returned)

    height, width = sample_size
    # Driving frames per pipe() call. Each call also gets the reference frame in
    # front, for max_frame_num frames in total. pad_video() pads to a multiple
    # of 48, so this must stay 48 (= 49 - 1).
    chunk_len = max_frame_num - 1

    succeeded = False
    try:
        progress(0.2, desc="Processing image...")
        image = load_image(image=image_path)
        image = processor.crop_and_resize(image, height, width)
        # Face crop; (x1, y1) is its top-left corner inside `image`.
        ref_image, x1, y1 = processor.face_crop(np.array(image))
        face_h, face_w, _ = ref_image.shape

        progress(0.3, desc="Processing facial landmarks...")
        source_outputs, source_tform, image_original = processor.process_source_image(ref_image)

        progress(0.4, desc="Processing audio...")
        driving_outputs = diffposetalk.infer_from_file(
            audio_path,
            source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy(),
        )

        progress(0.5, desc="Processing landmarks from coefficients...")
        out_frames = processor.preprocess_lmk3d_from_coef(
            source_outputs, source_tform, image_original.shape, driving_outputs
        )
        if len(out_frames) == 0:
            raise gr.Error("No motion could be derived from the audio. Please try a longer clip.")
        out_frames, pad_length = pad_video(out_frames)

        # Paste each face-sized motion frame into a full-size black canvas.
        rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0)
        for canvas, frame in zip(rescale_motions, out_frames):
            canvas[y1:y1 + face_h, x1:x1 + face_w] = frame

        # Landmark drawing of the reference face; prepended to every chunk.
        ref_image_resized = cv2.resize(ref_image, (512, 512))
        ref_lmk = lmk_extractor(ref_image_resized[:, :, ::-1])
        ref_lmk_img = vis.draw_landmarks_v3(
            (512, 512), (face_w, face_h),
            ref_lmk['lmks'].astype(np.float32), normed=True,
        )
        first_motion = np.zeros_like(np.array(image))
        first_motion[y1:y1 + face_h, x1:x1 + face_w] = ref_lmk_img
        first_motion = first_motion[np.newaxis, :]

        # Aligned face crop for the pipeline's `image_face` input. Channels are
        # reversed going into face_helper and reversed back on the crop.
        face_helper.clean_all()
        face_helper.read_image(np.array(image)[:, :, ::-1])
        face_helper.get_face_landmarks_5(only_center_face=True)
        face_helper.align_warp_face()
        if not face_helper.cropped_faces:
            raise gr.Error("No face detected in the portrait image. Please try a clearer, frontal photo.")
        image_face = face_helper.cropped_faces[0][:, :, ::-1]

        progress(0.6, desc="Generating animation (this may take a while)...")
        out_samples = []
        for start in range(0, len(rescale_motions), chunk_len):
            motions = np.concatenate([first_motion, rescale_motions[start:start + chunk_len]])
            # (frames, H, W, C) -> (1, C, frames, H, W), scaled by 1/255.
            input_video = torch.from_numpy(motions).permute([3, 0, 1, 2]).unsqueeze(0) / 255
            with torch.no_grad():
                sample = pipe(
                    image=image,
                    image_face=image_face,
                    control_video=input_video,
                    prompt="",
                    negative_prompt="",
                    height=height,
                    width=width,
                    num_frames=max_frame_num,
                    guidance_scale=guidance_scale,
                    num_inference_steps=steps,
                )
            chunk_frames = sample.frames[0]
            # Every chunk starts with the reference frame; keep it only from the first chunk.
            out_samples.extend(chunk_frames if start == 0 else chunk_frames[1:])

        # Drop the leading reference frame and the frames generated from padding.
        out_samples = out_samples[1:len(out_samples) - pad_length]

        progress(0.8, desc="Creating output video...")
        export_to_video(out_samples, temp_video_path, fps=12)

        progress(0.9, desc="Adding audio to video...")
        result_path = save_video_with_audio(temp_video_path, audio_path, final_output_path)

        # Side-by-side comparison: input portrait left, generated frame right.
        final_images = []
        for frame in out_samples:
            panel = Image.new('RGB', (width * 2, height))
            panel.paste(image, (0, 0))
            panel.paste(Image.fromarray(np.array(frame)).convert("RGB"), (width, 0))
            final_images.append(np.array(panel))
        write_mp4(comparison_path, final_images, fps=12)
        comparison_with_audio = save_video_with_audio(comparison_path, audio_path, comparison_with_audio)

        progress(1.0, desc="Done!")
        succeeded = True
        return result_path, comparison_with_audio
    finally:
        # Always drop the silent intermediates; on failure also drop the
        # would-be outputs so no empty files accumulate.
        leftovers = [temp_video_path, comparison_path]
        if not succeeded:
            leftovers += [final_output_path, comparison_with_audio]
        for path in leftovers:
            _remove_quietly(path)
        torch.cuda.empty_cache()
        gc.collect()
| # Create Gradio interface | |
# Gradio UI. Components render in the order they are created inside the
# `with` contexts, so statement order here defines the page layout.
with gr.Blocks(title="SkyReels A1 Talking Head") as app:
    # Header and promotional / link text.
    gr.Markdown("# SkyReels A1 Talking Head")
    gr.Markdown('''Upload a portrait image and an audio file to animate the face. 💡 Enjoying this demo? Share your feedback or review, and you might earn exclusive rewards! 🚀✨
📩 [Contact us on Discord](https://discord.com/invite/PwM6NYtccQ) for details. 🔥 [Code](https://github.com/SkyworkAI/SkyReels-A1) [Huggingface](https://huggingface.co/Skywork/SkyReels-A1)''')
    gr.Markdown('''✨ Try our **AI Office** for more productivity tools! [Visit Skywork AI Agent](https://skywork.ai/?utm_source=skyworkspace) ✨''')
    # Two columns: inputs and controls on the left, result videos on the right.
    with gr.Row():
        with gr.Column():
            # type="filepath": the handler receives file paths, not decoded data.
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Portrait Image")
                audio_input = gr.Audio(type="filepath", label="Driving Audio")
            # NOTE(review): these names become module globals; process_image_audio's
            # own `guidance_scale` parameter shadows this slider inside the function.
            with gr.Row():
                guidance_scale = gr.Slider(minimum=1.0, maximum=7.0, value=3.0, step=0.1, label="Guidance Scale")
                inference_steps = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Inference Steps")
            generate_button = gr.Button("Generate Animation", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Animation Result")
            comparison_video = gr.Video(label="Side-by-Side Comparison")
    # Inputs map positionally onto process_image_audio(image_path, audio_path,
    # guidance_scale, steps); its two returned paths fill the two video outputs.
    generate_button.click(
        fn=process_image_audio,
        inputs=[image_input, audio_input, guidance_scale, inference_steps],
        outputs=[output_video, comparison_video]
    )
    gr.Markdown("""
## Instructions
1. Upload a portrait image (frontal face works best)
2. Upload an audio file (wav format recommended)
3. Adjust parameters if needed
4. Click "Generate Animation" to create the video
Note: Processing may take several minutes depending on your hardware.
""")
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside Spaces -- confirm that exposure is intended.
    app.launch(share=True)