rahul7star committed on
Commit 1717df0 · verified · 1 Parent(s): 8750a83

Update app_fast.py

Files changed (1)
  1. app_fast.py +203 -172
app_fast.py CHANGED
@@ -1,216 +1,247 @@
- # PyTorch 2.8 (temporary hack)
- import os
- os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
- #os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu128 "torch<2.9" spaces')
-
- # Actual demo code
  import spaces
  import torch
- from diffusers import WanPipeline, AutoencoderKLWan
- from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
- from diffusers.utils.export_utils import export_to_video
  import gradio as gr
  import tempfile
  import numpy as np
  from PIL import Image
  import random
- import gc
- from optimization import optimize_pipeline_
- from accelerate import Accelerator
-
-
- def get_memory_gb(device):
-     """Get current allocated memory in GB"""
-     return torch.cuda.memory_allocated(device) / 1024**3
-
- accelerator = Accelerator()
- device = accelerator.device
- # Set up device and dtype using accelerator
-
- initial_memory = get_memory_gb(device)
- print(f"Initial memory: {initial_memory:.2f}GB")
-
- MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
-
- LANDSCAPE_WIDTH = 832
- LANDSCAPE_HEIGHT = 480
-
  MAX_SEED = np.iinfo(np.int32).max
-
- FIXED_FPS = 16
- MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 81
-
- MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
- MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
-
- vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
- pipe = WanPipeline.from_pretrained(MODEL_ID,
-     transformer=WanTransformer3DModel.from_pretrained('linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
-         subfolder='transformer',
-         torch_dtype=torch.bfloat16,
-         device_map='cuda',
-     ),
-     transformer_2=WanTransformer3DModel.from_pretrained('linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
-         subfolder='transformer_2',
-         torch_dtype=torch.bfloat16,
-         device_map='cuda',
-     ),
-     vae=vae,
-     torch_dtype=torch.bfloat16,
- ).to('cuda')
-
-
- for i in range(3):
-     gc.collect()
-     torch.cuda.synchronize()
-     torch.cuda.empty_cache()
-
- optimize_pipeline_(pipe,
-     prompt='prompt',
-     height=LANDSCAPE_HEIGHT,
-     width=LANDSCAPE_WIDTH,
-     num_frames=MAX_FRAMES_MODEL,
- )
-
-
- default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
- default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
-
-
- def get_duration(
-     prompt,
-     negative_prompt,
-     duration_seconds,
-     guidance_scale,
-     guidance_scale_2,
-     steps,
-     seed,
-     randomize_seed,
-     progress,
- ):
-     return steps * 15
  @spaces.GPU(duration=get_duration)
- def generate_video(
-     prompt,
-     negative_prompt=default_negative_prompt,
-     duration_seconds = MAX_DURATION,
-     guidance_scale = 1,
-     guidance_scale_2 = 3,
-     steps = 4,
-     seed = 42,
-     randomize_seed = False,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     """
-     Generate a video from a text prompt using the Wan 2.2 14B T2V model with Lightning LoRA.
-
-     This function takes an input prompt and generates a video animation based on the provided
-     prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Text-to-Video model with Lightning LoRA
-     for fast generation in 4-8 steps.
-
-     Args:
-         prompt (str): Text prompt describing the desired animation or motion.
-         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
-             Defaults to default_negative_prompt (contains unwanted visual artifacts).
-         duration_seconds (float, optional): Duration of the generated video in seconds.
-             Defaults to MAX_DURATION. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
-         guidance_scale (float, optional): Controls adherence to the prompt in the high-noise stage. Higher values = more adherence.
-             Defaults to 1.0. Range: 0.0-20.0.
-         guidance_scale_2 (float, optional): Controls adherence to the prompt in the low-noise stage. Higher values = more adherence.
-             Defaults to 3.0. Range: 0.0-20.0.
-         steps (int, optional): Number of inference steps. More steps = higher quality but slower.
-             Defaults to 4. Range: 1-30.
-         seed (int, optional): Random seed for reproducible results. Defaults to 42.
-             Range: 0 to MAX_SEED (2147483647).
-         randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
-             Defaults to False.
-         progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
-
-     Returns:
-         tuple: A tuple containing:
-             - video_path (str): Path to the generated video file (.mp4)
-             - current_seed (int): The seed used for generation (useful when randomize_seed=True)
-
-     Note:
-         - Frame count is calculated as duration_seconds * FIXED_FPS (16)
-         - The function uses GPU acceleration via the @spaces.GPU decorator
-         - Generation time varies based on steps (see get_duration function)
-     """
-
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

-     output_frames_list = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         height=480,
-         width=832,
-         num_frames=num_frames,
-         guidance_scale=float(guidance_scale),
-         guidance_scale_2=float(guidance_scale_2),
-         num_inference_steps=int(steps),
-         generator=torch.Generator(device="cuda").manual_seed(current_seed),
-     ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
          video_path = tmpfile.name
-
      export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
-
      return video_path, current_seed

  with gr.Blocks() as demo:
-     gr.Markdown("# Fast 4 steps Wan 2.2 T2V (14B) with Lightning LoRA")
-     gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Wan 2.2 Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
      with gr.Row():
          with gr.Column():
-             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
-             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=MAX_DURATION, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
-
              with gr.Accordion("Advanced Settings", open=False):
                  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
-                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
-                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
-                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=3, label="Guidance Scale 2 - low noise stage")
-
              generate_button = gr.Button("Generate Video", variant="primary")
          with gr.Column():
              video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-
      ui_inputs = [
-         prompt_input,
          negative_prompt_input, duration_seconds_input,
-         guidance_scale_input, guidance_scale_2_input, steps_slider, seed_input, randomize_seed_checkbox
      ]
      generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

      gr.Examples(
          examples=[
-             [
-                 "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat’s face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
-             ],
-             [
-                 "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
-             ],
-             [
-                 "A cinematic shot of a boat sailing on a calm sea at sunset.",
-             ],
-             [
-                 "Drone footage flying over a futuristic city with flying cars.",
-             ],
          ],
-         inputs=[prompt_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
      )

  if __name__ == "__main__":
-     demo.queue().launch(mcp_server=True)

  import spaces
  import torch
+ from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler
+ from diffusers.utils import export_to_video
  import gradio as gr
  import tempfile
+ from huggingface_hub import snapshot_download, hf_hub_download
  import numpy as np
  from PIL import Image
  import random
+ import os
+
+ # ---------------------------
+ # Configuration - edit these
+ # ---------------------------
+ # This should be the HF repo id that contains the model files (configs + safetensors)
+ # e.g. "username/wan2.2-t2v-rapid-aio-v6" or the repo where v6/wan2.2... is stored.
+ MODEL_REPO = "Phr00t/WAN2.2-14B-Rapid-AllInOne"
+
+ # If your model files are nested under a folder in the repo (e.g. "v6/wan2.2-t2v-rapid-aio-v6.safetensors"),
+ # set MODEL_FILE to that relative path (optional). If the repo is already standard diffusers layout, leave None.
+ MODEL_FILE = "v6/wan2.2-t2v-rapid-aio-v6.safetensors"  # or None
+
+ # Preferred dtype for model weights on GPU (float16 or bfloat16). If your GPU doesn't support bfloat16 use float16.
+ PREFERRED_DTYPE = torch.bfloat16
+
+ # ---------------------------
+ # Helper: prepare local model
+ # ---------------------------
+ def prepare_model_local(repo_id: str, model_file: str | None = None):
+     """
+     Downloads the repo snapshot and returns local_model_dir.
+     If model_file is provided we ensure it's present (hf_hub_download fallback).
+     """
+     print(f"[prepare_model_local] snapshot_download(repo_id={repo_id})")
+     local_dir = snapshot_download(repo_id)
+     # if model_file is specified, ensure it's present locally (some repos may not be in root)
+     if model_file:
+         candidate = os.path.join(local_dir, model_file)
+         if not os.path.exists(candidate):
+             # try to download the single file into the local_dir
+             try:
+                 print(f"[prepare_model_local] hf_hub_download(repo_id={repo_id}, filename={model_file})")
+                 path = hf_hub_download(repo_id=repo_id, filename=model_file, cache_dir=None)
+                 # Move / copy into the snapshot folder for consistent load paths (optional)
+                 # but usually snapshot_download should have returned the file already
+                 print(f"[prepare_model_local] got file at {path}")
+             except Exception as e:
+                 print(f"[prepare_model_local] WARNING: couldn't find {model_file} in repo: {e}")
+         else:
+             print(f"[prepare_model_local] found {candidate}")
+     return local_dir
+
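The MODEL_FILE fallback above only matters when the repo is not a standard diffusers layout. As a point of reference, diffusers pipeline repos are recognizable by a model_index.json at their root; a minimal sketch of such a check (the helper name is hypothetical, not part of this commit):

import os

def looks_like_diffusers_layout(local_dir: str) -> bool:
    # diffusers pipelines ship a model_index.json next to their component subfolders
    return os.path.exists(os.path.join(local_dir, "model_index.json"))

# e.g. decide up front whether from_pretrained can consume the snapshot:
# local_dir = prepare_model_local(MODEL_REPO, MODEL_FILE)
# print(looks_like_diffusers_layout(local_dir))
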
+ # ---------------------------
+ # Download & load model
+ # ---------------------------
+ local_model_dir = prepare_model_local(MODEL_REPO, MODEL_FILE)
+
+ # Try to load a VAE if present in subfolder "vae"
+ vae = None
+ vae_dir = os.path.join(local_model_dir, "vae")
+ try:
+     if os.path.isdir(vae_dir):
+         print("[main] Loading VAE from", vae_dir)
+         vae = AutoencoderKLWan.from_pretrained(vae_dir, torch_dtype=torch.float32)
+     else:
+         # If repo layout packs everything into root, WanPipeline.from_pretrained will handle it
+         print("[main] No vae/ subfolder detected - will rely on pipeline repo layout.")
+ except Exception as e:
+     print("[main] VAE load failed, continuing without explicit VAE:", e)
+     vae = None
+
+ # Attempt model load with preferred dtype; fall back to float32 if necessary
+ def load_pipeline_with_fallback(pipeline_cls, repo_or_dir, vae_obj=None, dtype=PREFERRED_DTYPE):
+     try:
+         print(f"[load_pipeline_with_fallback] trying dtype={dtype}")
+         return pipeline_cls.from_pretrained(repo_or_dir, vae=vae_obj, torch_dtype=dtype)
+     except Exception as e:
+         print(f"[load_pipeline_with_fallback] failed with dtype={dtype}: {e}")
+         print("[load_pipeline_with_fallback] retrying with float32")
+         return pipeline_cls.from_pretrained(repo_or_dir, vae=vae_obj, torch_dtype=torch.float32)
+
+ # Load pipelines
+ text_to_video_pipe = load_pipeline_with_fallback(WanPipeline, local_model_dir, vae_obj=vae)
+ image_to_video_pipe = load_pipeline_with_fallback(WanImageToVideoPipeline, local_model_dir, vae_obj=vae)
+
+ # Adjust scheduler and move to GPU
+ for pipe in [text_to_video_pipe, image_to_video_pipe]:
+     try:
+         # keep your custom scheduler tweak
+         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+     except Exception as e:
+         print("[main] Warning: couldn't replace scheduler:", e)
+     # move to CUDA
+     pipe.to("cuda")
+     # optionally enable attention slicing / memory optimizations if needed:
+     try:
+         pipe.enable_attention_slicing()
+     except Exception:
+         pass
+
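If the repo only ships the merged single-file checkpoint named by MODEL_FILE (no model_index.json), from_pretrained above will fail for both dtypes. A hedged sketch of the single-file route, assuming the installed diffusers build exposes from_single_file on WanTransformer3DModel (verify against your version) and reusing the Wan-AI/Wan2.2-T2V-A14B-Diffusers repo from the removed code for the remaining components:

import os
import torch
from diffusers import WanPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel

def load_t2v_from_single_file(local_dir: str, model_file: str, dtype=torch.bfloat16):
    # Load only the merged transformer from the single .safetensors file
    # (assumes from_single_file support for WanTransformer3DModel in your diffusers version).
    ckpt_path = os.path.join(local_dir, model_file)
    transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=dtype)
    # Text encoder, VAE and scheduler still come from a diffusers-layout repo.
    return WanPipeline.from_pretrained(
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        transformer=transformer,
        torch_dtype=dtype,
    )
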
+ # -------------------------------------------------------------------------
+ # (rest of your code remains the same)
+ # -------------------------------------------------------------------------
+ # Constants
+ MOD_VALUE = 32
+ DEFAULT_H_SLIDER_VALUE = 896
+ DEFAULT_W_SLIDER_VALUE = 896
+ NEW_FORMULA_MAX_AREA = 720 * 1024
+ SLIDER_MIN_H, SLIDER_MAX_H = 256, 1024
+ SLIDER_MIN_W, SLIDER_MAX_W = 256, 1024
  MAX_SEED = np.iinfo(np.int32).max
+ FIXED_FPS = 24
+ MIN_FRAMES_MODEL = 25
+ MAX_FRAMES_MODEL = 193
+
+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
+
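At 24 fps the frame-count bounds above correspond to a duration range of roughly 1.0 to 8.0 seconds, which is what the duration slider later in the file exposes:

# Duration range implied by the constants above (same rounding the slider uses).
print(round(MIN_FRAMES_MODEL / FIXED_FPS, 1), round(MAX_FRAMES_MODEL / FIXED_FPS, 1))  # 1.0 8.0
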
+ def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area, min_slider_h, max_slider_h, min_slider_w, max_slider_w, default_h, default_w):
+     orig_w, orig_h = pil_image.size
+     if orig_w <= 0 or orig_h <= 0:
+         return default_h, default_w
+     aspect_ratio = orig_h / orig_w
+
+     calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
+     calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
+     calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
+     calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
+
+     new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
+     new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
+
+     return new_h, new_w
+
+ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
+     if uploaded_pil_image is None:
+         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+     try:
+         new_h, new_w = _calculate_new_dimensions_wan(
+             uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
+             SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
+             DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
+         )
+         return gr.update(value=new_h), gr.update(value=new_w)
+     except Exception as e:
+         gr.Warning("Error attempting to calculate new dimensions")
+         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+
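As a concrete check of the sizing math (using the constants defined above), a 1920x1080 upload is capped to roughly 720*1024 pixels of area, snapped down to multiples of 32, and clipped to the slider range:

from PIL import Image

img = Image.new("RGB", (1920, 1080))  # 16:9 test image
h, w = _calculate_new_dimensions_wan(
    img, MOD_VALUE, NEW_FORMULA_MAX_AREA,
    SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
    DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE,
)
print(h, w)  # 640 1024 (width is 1120 before clipping to the 1024 slider maximum)
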
+ def get_duration(input_image, prompt, height, width,
+                  negative_prompt, duration_seconds,
+                  guidance_scale, steps,
+                  seed, randomize_seed,
+                  progress):
+     if steps > 4 and duration_seconds > 4:
+         return 90
+     elif steps > 4 or duration_seconds > 4:
+         return 75
+     else:
+         return 60
+
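This callback is invoked by @spaces.GPU with the same arguments as generate_video (the ZeroGPU dynamic-duration convention) and returns the GPU time budget in seconds; only steps and duration_seconds affect the result:

# Quick check of the heuristic above (the other arguments are ignored).
print(get_duration(None, "p", 896, 896, "", 5.0, 0.0, 6, 0, False, None))  # 90: steps > 4 and duration > 4
print(get_duration(None, "p", 896, 896, "", 2.0, 0.0, 4, 0, False, None))  # 60: neither threshold exceeded
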
  @spaces.GPU(duration=get_duration)
+ def generate_video(input_image, prompt, height, width, negative_prompt=default_negative_prompt, duration_seconds=2, guidance_scale=0, steps=4, seed=44, randomize_seed=False, progress=gr.Progress(track_tqdm=True)):
+     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

+     if input_image is not None:
+         resized_image = input_image.resize((target_w, target_h))
+         with torch.inference_mode():
+             output_frames_list = image_to_video_pipe(
+                 image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+                 height=target_h, width=target_w, num_frames=num_frames,
+                 guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
+             ).frames[0]
+     else:
+         with torch.inference_mode():
+             output_frames_list = text_to_video_pipe(
+                 prompt=prompt, negative_prompt=negative_prompt,
+                 height=target_h, width=target_w, num_frames=num_frames,
+                 guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
+             ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
          video_path = tmpfile.name
      export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
      return video_path, current_seed

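Outside a ZeroGPU Space the @spaces.GPU decorator is effectively a pass-through, so on a local CUDA machine the handler can be exercised directly. A hedged usage sketch (not part of the commit; passing input_image=None takes the text-to-video branch):

# Hypothetical local smoke test; requires a CUDA GPU and the pipelines loaded above.
video_path, used_seed = generate_video(
    None,                                    # no input image -> text_to_video_pipe
    "a red fox running through fresh snow",  # prompt
    480, 832,                                # height, width (snapped to multiples of MOD_VALUE)
    duration_seconds=2,
    steps=4,
    randomize_seed=True,
)
print(video_path, used_seed)
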
+ # UI (unchanged)
  with gr.Blocks() as demo:
+     gr.Markdown("# Fast Wan 2.2 Rapid AllInOne (14B) Demo")
+     gr.Markdown("""This demo uses the Wan 2.2 model from the HF repo specified in MODEL_REPO.""")
      with gr.Row():
          with gr.Column():
+             input_image_component = gr.Image(type="pil", label="Input Image (optional, auto-resized to target H/W)")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+
              with gr.Accordion("Advanced Settings", open=False):
                  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 with gr.Row():
+                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
+                 steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.0, label="Guidance Scale")
              generate_button = gr.Button("Generate Video", variant="primary")
          with gr.Column():
              video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+
+     input_image_component.upload(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )
+
+     input_image_component.clear(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )
+
      ui_inputs = [
+         input_image_component, prompt_input, height_input, width_input,
          negative_prompt_input, duration_seconds_input,
+         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
      ]
      generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

      gr.Examples(
          examples=[
+             [None, "A person eating spaghetti", 1024, 720],
+             ["cat.png", "The cat removes the glasses from its eyes.", 1088, 800],
+             [None, "a penguin playfully dancing in the snow, Antarctica", 1024, 720],
+             ["peng.png", "a penguin running towards camera joyfully, Antarctica", 896, 512],
          ],
+         inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
      )

  if __name__ == "__main__":
+     demo.queue().launch()