1inkusFace committed on
Commit
9493512
·
verified ·
1 Parent(s): d1f7e93

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import imageio
4
+ import os
5
+ import time
6
+ import random
7
+ import gc
8
+ from PIL import Image
9
+
10
+ # Import necessary components from the cloned repository
11
+ from skyreels_v2_infer.modules import download_model
12
+ from skyreels_v2_infer.pipelines import Image2VideoPipeline, resizecrop
13
+
14
# --- Global Configuration & Model Loading ---
MODEL_ID = "Skywork/SkyReels-V2-I2V-14B-720P"
HEIGHT, WIDTH = 720, 1280
OUTPUT_DIR = "video_out"

# Rendered videos are written here; make sure the folder exists up front.
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Downloading and loading model... This may take a while.")
# Download the model files to the cache; the returned value is reused below
# as both the model path and the DiT path of the pipeline.
cached_model_path = download_model(MODEL_ID)

# The pipeline is constructed a single time, when the Space starts.
# Offload stays enabled by default so the app remains compatible with GPUs
# like A10G-Large (24GB VRAM).
pipe = Image2VideoPipeline(
    model_path=cached_model_path,
    dit_path=cached_model_path,
    use_usp=False,
    offload=True,  # Enable CPU offload to save VRAM
)
print("Model loaded successfully.")
36
+
37
+ # --- Inference Function ---
38
def generate_video(input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed):
    """Generate a video from a starting image and a text prompt.

    Args:
        input_image: Initial frame as a numpy array (it is converted with
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        prompt: Text prompt describing the desired video.
        guidance_scale: Guidance scale forwarded to the pipeline.
        inference_steps: Number of inference steps forwarded to the pipeline.
        num_frames: Number of frames forwarded to the pipeline.
        fps: Frame rate used when encoding the output file.
        seed: Seed for the CUDA generator; ``-1`` (or an empty field, i.e.
            ``None``) selects a random seed.

    Returns:
        Path of the ``.mp4`` file written inside ``OUTPUT_DIR``.

    Raises:
        gr.Error: If no image was provided or the prompt is empty/blank.
    """
    if input_image is None:
        raise gr.Error("You must upload an initial image.")
    # A prompt made only of whitespace is as unusable as an empty one.
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # UI number widgets may hand over floats (e.g. 12345.0) or None when the
    # field is cleared; torch.Generator.manual_seed needs a real int.
    if seed is None or int(seed) == -1:
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)
    # Same normalisation for the integer-valued sliders.
    inference_steps = int(inference_steps)
    num_frames = int(num_frames)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Prepare the input image (resize and crop)
    image = Image.fromarray(input_image).convert("RGB")
    processed_image = resizecrop(image, HEIGHT, WIDTH)

    # Define a default negative prompt
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, worst quality, low quality, JPEG compression residue, ugly, deformed."

    # Set up generation parameters
    kwargs = {
        "image": processed_image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_frames": num_frames,
        "num_inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
        "shift": 8.0,  # Default value from original script
        "generator": generator,
        "height": HEIGHT,
        "width": WIDTH,
    }

    print(f"Generating video with seed: {seed}")
    start_time = time.perf_counter()

    try:
        # Run inference; torch.autocast("cuda", ...) is the non-deprecated
        # spelling of torch.cuda.amp.autocast(...).
        with torch.autocast(device_type="cuda", dtype=pipe.transformer.dtype), torch.no_grad():
            video_frames = pipe(**kwargs)[0]

        elapsed = time.perf_counter() - start_time
        print(f"Inference took {elapsed:.2f} seconds.")

        # Build a filesystem-friendly name from the prompt. Strip again after
        # truncating so the stem never ends in a space, and fall back to a
        # fixed stem when nothing usable is left.
        safe_prompt = "".join(c for c in prompt if c.isalnum() or c in " _-")
        safe_prompt = safe_prompt.strip()[:50].strip() or "video"
        output_filename = f"{safe_prompt}_{seed}.mp4"
        output_path = os.path.join(OUTPUT_DIR, output_filename)

        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])

        print(f"Video saved to {output_path}")
    finally:
        # Clean up memory even when inference or encoding failed, so a
        # failed request does not leave cached GPU memory behind.
        gc.collect()
        torch.cuda.empty_cache()

    return output_path
99
+
100
# --- Gradio UI ---
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # SkyReels-V2 Image-to-Video Generator
        ### Model: Skywork/SkyReels-V2-I2V-14B-720P
        This Space demonstrates the SkyReels V2 model for generating video from a single starting image and a text prompt.
        **Note:** This is a very large model. Generation can take several minutes, even on powerful GPUs.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Initial Image")
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., A cinematic shot of a car driving on a rainy street at night.")

            with gr.Accordion("Advanced Settings", open=False):
                guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=6.0, step=0.5, label="Guidance Scale")
                inference_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Inference Steps")
                num_frames = gr.Slider(minimum=25, maximum=145, value=97, step=8, label="Number of Frames")
                fps = gr.Slider(minimum=8, maximum=30, value=24, step=1, label="Frames Per Second (FPS)")
                # precision=0 makes the field deliver an int, which is what
                # the seed is used as by generate_video.
                seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)

        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            run_button = gr.Button("Generate Video", variant="primary")

    # Single source of truth for the callback inputs, shared by the examples
    # table and the button so the two cannot drift apart.
    generation_inputs = [input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed]

    # Each row: image path followed by one value per remaining input.
    example_rows = [
        ["./examples/car.png", "A cinematic shot of a car driving on a rainy street at night, neon lights reflecting on the wet pavement.", 7.0, 30, 97, 24, 12345],
        ["./examples/castle.png", "An epic fantasy castle in the mountains, dragons flying in the sky, cinematic lighting.", 6.0, 40, 97, 12, 54321],
    ]
    # The example images are optional assets: only rows whose image file is
    # present are registered, so the app still starts when the 'examples'
    # folder has not been added to the Space.
    available_examples = [row for row in example_rows if os.path.isfile(row[0])]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=generation_inputs,
            outputs=output_video,
            fn=generate_video,
            cache_examples=False,  # Set to True if you have GPU and want to pre-process examples
        )

    run_button.click(fn=generate_video, inputs=generation_inputs, outputs=output_video)

if __name__ == "__main__":
    demo.launch()