# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Gradio demo for VACE-WAN (Wan2.1-VACE-1.3B): all-in-one video creation and editing.

Importing this module downloads the ``Wan-AI/Wan2.1-VACE-1.3B`` snapshot into
``./models/Wan2.1-VACE-1.3B``. Running it as a script builds the UI and launches
the Gradio server.
"""
import argparse
import datetime
import os
import sys

import imageio
import numpy as np
import torch
import gradio as gr
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Wan-AI/Wan2.1-VACE-1.3B",
    local_dir="./models/Wan2.1-VACE-1.3B"
)

# True when running inside the shared "fffiloni/Wan2.1-VACE-1.3B" Space; used to
# lock the sample_steps slider. SPACE_ID is only set on Hugging Face Spaces, so
# fall back to "" instead of raising KeyError when running locally.
is_shared_ui = "fffiloni/Wan2.1-VACE-1.3B" in os.environ.get('SPACE_ID', '')

# Make the `wan` package (two directory levels above this file) importable.
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS


class FixedSizeQueue:
    """Bounded most-recent-first list: new items go to the front, the oldest fall off."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []

    def add(self, item):
        """Insert ``item`` at the front, dropping the oldest entry if over capacity."""
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        """Return the underlying list (newest first)."""
        return self.queue

    def __repr__(self):
        return str(self.queue)


class VACEInference:
    """Holds the VACE pipeline and the Gradio components / callbacks that drive it.

    Args:
        cfg: parsed argparse namespace; reads ``save_dir``, ``model_name``,
            ``ckpt_dir``, ``mp``, ``ulysses_size`` and ``ring_size``.
        skip_load: if True, do not construct the pipeline (``self.pipe`` is unset).
        gallery_share: if True, the refresh button shows the shared gallery queue.
        gallery_share_limit: capacity of the shared gallery queue.
    """

    def __init__(self, cfg, skip_load=False, gallery_share=False, gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            # Read the multi-GPU flag from cfg (not the script-level `args` global)
            # so the class also works when imported from another module.
            if not getattr(cfg, 'mp', False):
                self.pipe = WanVace(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    device_id=0,
                    rank=0,
                    t5_fsdp=False,
                    dit_fsdp=False,
                    use_usp=False,
                )
            else:
                self.pipe = WanVaceMP(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    use_usp=True,
                    ulysses_size=cfg.ulysses_size,
                    ring_size=cfg.ring_size
                )

    def create_ui(self, *args, **kwargs):
        """Build the Gradio components; must be called inside a ``gr.Blocks`` context."""
        gr.Markdown("# VACE-WAN 1.3B Demo")
        gr.Markdown("All-in-One Video Creation and Editing")
        gr.HTML("""
""")
        # Source video and its mask.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        # Up to three optional reference images.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_1",
                                                    format='png')
                    self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_2",
                                                    format='png')
                    self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_3",
                                                    format='png')
        # Positive / negative prompts.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value="Bright and saturated tones, overexposed, static, unclear details, subtitles, style, work, painting, frame, still, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, misshapen limbs, fused fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards.",
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        # Sampling hyper-parameters.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        # Locked on the shared Space to bound per-request GPU time.
                        interactive=not is_shared_ui)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1,
                                                maximum=10000000,
                                                value=2025,
                                                label="Seed")
        # Output geometry / timing (used directly when no source video is given).
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    value=480,  # alternative: 720
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    value=832,  # alternative: 1280
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate',
                    value=16,
                    interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames',
                    value=81,
                    interactive=True)
        # Action buttons.
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        # Results.
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    def _make_output_path(self):
        """Create a fresh timestamped directory under ``self.save_dir`` and return the mp4 path in it."""
        # Portable strftime codes (no glibc-only `%-H`); microseconds avoid name
        # collisions between runs that finish within the same second.
        name = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
        save_dir_path = os.path.join(self.save_dir, name)
        os.makedirs(save_dir_path, exist_ok=True)
        print(f"✅ Folder created: {save_dir_path}")
        return os.path.join(save_dir_path, f'cur_gallery_{name}.mp4')

    @staticmethod
    def _write_video(video, video_path, frame_rate):
        """Encode the generated tensor to an H.264 mp4 at ``video_path``.

        Assumes ``video`` is (C, T, H, W) with values in [-1, 1] — inferred from the
        permute and the ``/ 2 + 0.5`` rescale below; TODO confirm against WanVace.generate.
        """
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
        try:
            # The context manager closes the writer even if append_data raises.
            with imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8,
                                    macro_block_size=1) as writer:
                for frame in video_frames:
                    writer.append_data(frame)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}") from e
        print(video_path)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3,
                 prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed,
                 output_height, output_width, frame_rate, num_frames, progress=gr.Progress(track_tqdm=True)):
        """Run one VACE generation and return a one-element list with the saved mp4 path.

        ``progress`` is unused directly; its ``track_tqdm`` default lets Gradio mirror the
        pipeline's tqdm bars. Raises ``gr.Error`` for invalid numeric fields, unsupported
        resolutions, or failures while writing the video.
        """
        try:
            output_height, output_width = int(output_height), int(output_width)
            frame_rate, num_frames = int(frame_rate), int(num_frames)
        except (TypeError, ValueError) as e:
            raise gr.Error(f"Invalid resolution/frame settings: {e}") from e

        size_key = f"{output_width}*{output_height}"
        if size_key not in SIZE_CONFIGS:
            raise gr.Error(f"Unsupported resolution {output_width}x{output_height}; "
                           f"supported width*height: {', '.join(map(str, SIZE_CONFIGS))}")

        src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if x is not None]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
                                                                       [src_mask],
                                                                       [src_ref_images],
                                                                       num_frames=num_frames,
                                                                       image_size=SIZE_CONFIGS[size_key],
                                                                       device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            # Slider values may arrive as floats; the pipeline expects integers here.
            sampling_steps=int(sample_steps),
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=int(infer_seed),
            offload_model=True)

        video_path = self._make_output_path()
        self._write_video(video, video_path, frame_rate)

        # NOTE: publishing to the shared gallery (self.gallery_share_data.add(video_path))
        # is intentionally disabled; each user only sees their own result.
        return [video_path]

    def _refresh_gallery(self, current_gallery):
        """Return the shared gallery when sharing is on and it has entries, else keep the current one."""
        shared = self.gallery_share_data.get()
        # Sharing is currently never populated by generate(); returning the empty
        # shared list would wipe the user's results, so keep them instead.
        return shared if self.gallery_share and shared else current_gallery

    def set_callbacks(self, **kwargs):
        """Wire the Run and refresh buttons; call after ``create_ui``."""
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
                           self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed,
                           self.output_height, self.output_width, self.frame_rate, self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate,
                                   inputs=self.gen_inputs,
                                   outputs=self.gen_outputs,
                                   queue=True)
        self.refresh_button.click(self._refresh_gallery,
                                  inputs=[self.output_gallery],
                                  outputs=[self.output_gallery])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
    parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
    parser.add_argument('--root_path', dest='root_path', help='', default=None)
    parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
    parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
    parser.add_argument("--model_name", type=str, default="vace-1.3B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
    parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
    parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-1.3B/',
        help="The path to the checkpoint directory.",
    )
    # NOTE(review): parsed but not read anywhere in this file; generate() always
    # passes offload_model=True.
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        # Generated videos are written under save_dir, so Gradio must be allowed to serve it.
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(server_name=args.server_name,
                                                server_port=args.server_port,
                                                root_path=args.root_path,
                                                allowed_paths=allowed_paths,
                                                show_error=True, debug=True)