import os
import gc
import time
import random
import torch
import imageio
import gradio as gr

from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop


def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=True,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
    model_id = download_model(model_id)

    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    if seed is None:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    # Long videos are generated in overlapping chunks, so overlap_history is required
    # once the requested length exceeds base_num_frames.
    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        # Switch to a portrait canvas when the conditioning image is taller than it is wide.
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)

    # Default Chinese negative prompt used by the underlying model (lists common visual artifacts to avoid).
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    prompt_input = prompt
    if prompt_enhancer and image is None:
        # Prompt enhancement is only applied for text-to-video; free the enhancer afterwards.
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
        if ar_step > 0:
            # Asynchronous (autoregressive) mode adds ar_step extra denoising steps per causal block.
            num_steps = inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# SkyReels V2")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                model_id = gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID")
                resolution = gr.Radio(choices=resolution_options, value="540P", label="Resolution", interactive=False)
                num_frames = gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames")
                image = gr.Image(type="filepath", label="Input Image (optional)")
                with gr.Accordion("Advanced Settings", open=False):
                    # precision=0 makes the integer-valued fields return ints rather than floats.
                    ar_step = gr.Number(label="AR Step", value=0, precision=0)
                    causal_attention = gr.Checkbox(label="Causal Attention")
                    causal_block_size = gr.Number(label="Causal Block Size", value=1, precision=0)
                    base_num_frames = gr.Number(label="Base Num Frames", value=97, precision=0)
                    overlap_history = gr.Number(label="Overlap History (set for long videos)", value=None, precision=0)
                    addnoise_condition = gr.Number(label="AddNoise Condition", value=0, precision=0)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale")
                    shift = gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift")
                    inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
                    use_usp = gr.Checkbox(label="Use USP")
                    offload = gr.Checkbox(label="Offload", value=True, interactive=False)
                    fps = gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS")
                    seed = gr.Number(label="Seed (optional)", precision=0)
                    prompt_enhancer = gr.Checkbox(label="Prompt Enhancer")
                    use_teacache = gr.Checkbox(label="Use TeaCache")
                    teacache_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold")
                    use_ret_steps = gr.Checkbox(label="Use Retention Steps")
                submit_btn = gr.Button("Generate")
            with gr.Column():
                output_video = gr.Video(label="Generated Video")

        submit_btn.click(
            fn=generate_diffusion_forced_video,
            inputs=[
                prompt,
                model_id,
                resolution,
                num_frames,
                image,
                ar_step,
                causal_attention,
                causal_block_size,
                base_num_frames,
                overlap_history,
                addnoise_condition,
                guidance_scale,
                shift,
                inference_steps,
                use_usp,
                offload,
                fps,
                seed,
                prompt_enhancer,
                use_teacache,
                teacache_thresh,
                use_ret_steps,
            ],
            outputs=[output_video],
        )

demo.launch(show_error=True, show_api=False, share=False)
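# A minimal sketch of calling the generator directly, bypassing the Gradio UI. This is an
# illustrative example only: the prompt text, seed, and parameter choices below are assumptions,
# not part of the demo. It reuses the function and default model ID defined above and assumes a
# CUDA-capable GPU with the SkyReels-V2 dependencies installed.
#
#     video_path = generate_diffusion_forced_video(
#         prompt="A serene lake at dawn, mist drifting over the water",
#         model_id="Skywork/SkyReels-V2-DF-1.3B-540P",
#         resolution="540P",
#         num_frames=97,
#         inference_steps=30,
#         seed=42,
#     )
#     print(video_path)  # path to the saved MP4 under gradio_df_videos/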