import os
import gc
import time
import random

import torch
import imageio
import gradio as gr
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop


def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=True,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
    """Generate a video with the SkyReels-V2 Diffusion Forcing pipeline and return the saved MP4 path."""
    model_id = download_model(model_id)

    # Map the resolution preset to (height, width).
    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    # Pick a random seed when none is provided.
    if seed is None:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    # Long videos are generated chunk by chunk and need an overlap between chunks.
    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    # Optional image conditioning: switch to portrait orientation if the image is taller than wide.
    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)

    # Standard SkyReels-V2 negative prompt (Chinese): excludes oversaturation, overexposure,
    # blurry detail, subtitles, JPEG artifacts, malformed limbs/hands/faces, cluttered
    # backgrounds, and static or backwards motion.
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    # Optionally rewrite the prompt with the prompt enhancer (text-to-video only),
    # then free its memory before loading the main pipeline.
    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
        # With asynchronous (AR) generation, each causal block adds `ar_step` extra denoising steps.
        if ar_step > 0:
            num_steps = (
                inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    # Save the result as an MP4 named after the prompt, seed, and timestamp.
    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# SkyReels V2")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                model_id = gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID")
                resolution = gr.Radio(choices=resolution_options, value="540P", label="Resolution", interactive=False)
                num_frames = gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames")
                image = gr.Image(type="filepath", label="Input Image (optional)")
                with gr.Accordion("Advanced Settings", open=False):
                    ar_step = gr.Number(label="AR Step", value=0, precision=0)
                    causal_attention = gr.Checkbox(label="Causal Attention")
                    causal_block_size = gr.Number(label="Causal Block Size", value=1, precision=0)
                    base_num_frames = gr.Number(label="Base Num Frames", value=97, precision=0)
                    overlap_history = gr.Number(label="Overlap History (set for long videos)", value=None, precision=0)
                    addnoise_condition = gr.Number(label="AddNoise Condition", value=0, precision=0)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale")
                    shift = gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift")
                    inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
                    use_usp = gr.Checkbox(label="Use USP")
                    offload = gr.Checkbox(label="Offload", value=True, interactive=False)
                    fps = gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS")
                    seed = gr.Number(label="Seed (optional)", precision=0)
                    prompt_enhancer = gr.Checkbox(label="Prompt Enhancer")
                    use_teacache = gr.Checkbox(label="Use TeaCache")
                    teacache_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold")
                    use_ret_steps = gr.Checkbox(label="Use Retention Steps")
                submit_btn = gr.Button("Generate")
            with gr.Column():
                output_video = gr.Video(label="Generated Video")
                gr.Examples(
                    examples=[
                        [
                            "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed.",
                            "./examples/swan.jpeg",
                        ],
                        [
                            "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed.",
                            None,
                        ],
                        ["A sea turtle swimming near a shipwreck", "./examples/turtle.jpeg"],
                        ["A sea turtle swimming near a shipwreck", None],
                    ],
                    inputs=[prompt, image],
                )

    submit_btn.click(
        fn=generate_diffusion_forced_video,
        inputs=[
            prompt,
            model_id,
            resolution,
            num_frames,
            image,
            ar_step,
            causal_attention,
            causal_block_size,
            base_num_frames,
            overlap_history,
            addnoise_condition,
            guidance_scale,
            shift,
            inference_steps,
            use_usp,
            offload,
            fps,
            seed,
            prompt_enhancer,
            use_teacache,
            teacache_thresh,
            use_ret_steps,
        ],
        outputs=[output_video],
    )

demo.launch(show_error=True, show_api=False, share=False)