1inkusFace committed on
Commit
c0db3ab
·
verified ·
1 Parent(s): 37c56fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -196
app.py CHANGED
@@ -1,23 +1,24 @@
1
  import spaces
2
  import gradio as gr
3
- import argparse
4
  import sys
5
- import time
6
  import os
7
  import random
8
  import subprocess
9
- from PIL import Image # Import PIL.Image
 
 
10
 
11
- subprocess.run(['sh', './sky.sh']) # Keep this if needed for setup
12
  sys.path.append("./SkyReels-V1")
13
 
14
- from skyreelsinfer import TaskType
15
- from skyreelsinfer.offload import OffloadConfig
16
- # from skyreelsinfer.skyreels_video_infer import Predictor # Correct: No Predictor import.
 
17
  from diffusers.utils import export_to_video
18
- # from diffusers.utils import load_image # Removed: Use PIL directly
19
 
20
  import torch
 
21
 
22
  torch.backends.cuda.matmul.allow_tf32 = False
23
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
@@ -25,204 +26,58 @@ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
25
  torch.backends.cudnn.allow_tf32 = False
26
  torch.backends.cudnn.deterministic = False
27
  torch.backends.cudnn.benchmark = False
28
- # torch.backends.cuda.preferred_blas_library="cublas"
29
- # torch.backends.cuda.preferred_linalg_library="cusolver"
30
  torch.set_float32_matmul_precision("highest")
31
 
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33
 
34
- import logging # Correct: Keep logging
35
-
36
- # --- Dummy Classes (Keep these for standalone execution) ---
37
- class OffloadConfig:
38
- def __init__(self, high_cpu_memory=False, parameters_level=False, compiler_transformer=False, compiler_cache=""):
39
- self.high_cpu_memory = high_cpu_memory
40
- self.parameters_level = parameters_level
41
- self.compiler_transformer = compiler_transformer
42
- self.compiler_cache = compiler_cache
43
-
44
- class TaskType:
45
- T2V = 0
46
- I2V = 1
47
-
48
- class LlamaModel:
49
- @staticmethod
50
- def from_pretrained(*args, **kwargs):
51
- return LlamaModel()
52
- def to(self, device):
53
- return self
54
-
55
- class HunyuanVideoTransformer3DModel:
56
- @staticmethod
57
- def from_pretrained(*args, **kwargs):
58
- return HunyuanVideoTransformer3DModel()
59
- def to(self, device):
60
- return self
61
-
62
- class SkyreelsVideoPipeline:
63
- @staticmethod
64
- def from_pretrained(*args, **kwargs):
65
- return SkyreelsVideoPipeline()
66
- def to(self, device):
67
- return self
68
- def __call__(self, *args, **kwargs):
69
- frames = [torch.randn(1, 3, 512, 512)] # Dummy frames
70
- return type('obj', (object,), {'frames' : frames})()
71
- class vae:
72
- @staticmethod
73
- def enable_tiling():
74
- return
75
-
76
- def quantize_(*args, **kwargs):
77
- return
78
-
79
- def float8_weight_only():
80
- return
81
- # --- End of Dummy Classes/Functions ---
82
  logger = logging.getLogger(__name__)
83
 
84
- class SkyReelsVideoSingleGpuInfer: # No more multiprocessing!
85
- def __init__(
86
- self,
87
- task_type: TaskType,
88
- model_id: str,
89
- quant_model: bool = True,
90
- is_offload: bool = True,
91
- offload_config: OffloadConfig = OffloadConfig(),
92
- enable_cfg_parallel: bool = True, # Remove world_size, local_rank
93
- ):
94
- self.task_type = task_type
95
- self.model_id = model_id
96
- self.quant_model = quant_model
97
- self.is_offload = is_offload
98
- self.offload_config = offload_config
99
- self.enable_cfg_parallel = enable_cfg_parallel # Keep this
100
- self.pipe = None
101
- self.is_initialized = False
102
- self.gpu_device = None
103
-
104
- def _load_model(self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True):
105
- logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
106
- text_encoder = LlamaModel.from_pretrained(
107
- base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
108
- ).to("cpu")
109
- transformer = HunyuanVideoTransformer3DModel.from_pretrained(
110
- model_id, torch_dtype=torch.bfloat16, device="cpu"
111
- ).to("cpu")
112
-
113
- if quant_model:
114
- quantize_(text_encoder, float8_weight_only())
115
- text_encoder.to("cpu")
116
- torch.cuda.empty_cache()
117
- quantize_(transformer, float8_weight_only())
118
- transformer.to("cpu")
119
- torch.cuda.empty_cache()
120
-
121
- pipe = SkyreelsVideoPipeline.from_pretrained(
122
- base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
123
- ).to("cpu")
124
- pipe.vae.enable_tiling()
125
- torch.cuda.empty_cache()
126
- return pipe
127
-
128
- def initialize(self):
129
- """Initializes the model and moves it to the GPU."""
130
- if self.is_initialized:
131
- return
132
-
133
- if not torch.cuda.is_available():
134
- raise RuntimeError("CUDA is not available. Cannot initialize model.")
135
-
136
- self.gpu_device = "cuda:0" # Always cuda:0 in single-GPU case
137
- self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
138
-
139
- # Simplified: No need for max_batch_dim_size with single GPU
140
- if self.is_offload:
141
- pass
142
- else:
143
- self.pipe.to(self.gpu_device)
144
-
145
- if self.offload_config.compiler_transformer:
146
- torch._dynamo.config.suppress_errors = True
147
- os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
148
- # No world_size in cache directory name
149
- os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
150
- self.pipe.transformer = torch.compile(
151
- self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
152
- )
153
- if self.offload_config.compiler_transformer: # Only warm up if compiling
154
- self.warm_up()
155
- self.is_initialized = True
156
 
157
- def warm_up(self):
158
- if not self.is_initialized:
159
- raise RuntimeError("Model must be initialized before warm-up.")
160
 
161
- init_kwargs = {
162
- "prompt": "A woman is dancing in a room",
163
- "height": 544,
164
- "width": 960,
165
- "guidance_scale": 6,
166
- "num_inference_steps": 1,
167
- "negative_prompt": "bad quality",
168
- "num_frames": 16,
169
- "generator": torch.Generator(self.gpu_device).manual_seed(42),
170
- "embedded_guidance_scale": 1.0,
171
- }
172
- if self.task_type == TaskType.I2V:
173
- init_kwargs["image"] = Image.new("RGB", (544,960), color="black") #Dummy
174
- self.pipe(**init_kwargs)
175
- logger.info("Warm-up complete.")
176
 
177
- def infer(self, **kwargs):
178
- """Handles inference requests."""
179
- if not self.is_initialized:
180
- self.initialize()
181
- if "seed" in kwargs:
182
- kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
183
- del kwargs["seed"]
184
- assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
185
- result = self.pipe(**kwargs).frames[0]
186
- return result
187
-
188
-
189
- # --- Spaces Integration ---
190
-
191
- _predictor = None # Global variable to hold the predictor
192
-
193
- @spaces.GPU(duration=90) # We DO need @spaces.GPU on init_predictor
194
  def init_predictor():
195
  global _predictor
196
- logger = logging.getLogger(__name__) # Correct: Logger inside function
 
197
 
198
  if _predictor is None:
 
 
 
 
 
 
 
199
  _predictor = SkyReelsVideoSingleGpuInfer(
200
- task_type=TaskType.I2V,
201
- model_id="Skywork/SkyReels-V1-Hunyuan-I2V", # Replace!
202
  quant_model=True,
203
  is_offload=True,
204
  offload_config=OffloadConfig(
205
  high_cpu_memory=True,
206
  parameters_level=True,
207
- compiler_transformer=False, # Set to True to enable compilation/warm-up
208
  ),
209
  )
210
- _predictor.initialize() # Initialize *after* creation
211
  logger.info("Predictor initialized")
212
  else:
213
  logger.warning("Predictor already initialized (should be rare).")
214
 
215
-
216
-
217
- @spaces.GPU(duration=90) # Now needed, because we write files.
218
  def generate_video(prompt, seed, image=None):
 
219
  global task_type
220
- global _predictor # Correct: Access global _predictor
221
 
222
- print(f"image:{type(image)}")
223
  if seed == -1:
224
- random.seed(time.time())
225
  seed = int(random.randrange(4294967294))
 
226
  kwargs = {
227
  "prompt": prompt,
228
  "height": 512,
@@ -235,38 +90,60 @@ def generate_video(prompt, seed, image=None):
235
  "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
236
  "cfg_for": False,
237
  }
238
- assert image is not None, "please input image"
239
- # kwargs["image"] = load_image(image=image) # Removed: load image directly with PIL
240
- kwargs["image"] = Image.open(image) # Use PIL.Image.open
 
 
 
 
 
241
 
242
  if _predictor is None:
243
  init_predictor()
244
 
245
- output = _predictor.infer(**kwargs) # Correct: Use _predictor
246
 
247
- save_dir = f"./result/{task_type}"
248
  os.makedirs(save_dir, exist_ok=True)
249
- video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
250
  print(f"generate video, local path: {video_out_file}")
251
  export_to_video(output, video_out_file, fps=24)
252
- return video_out_file, kwargs # Correct: Return filename, kwargs
 
253
 
254
  def create_gradio_interface():
255
- with gr.Blocks() as demo:
256
- with gr.Row():
 
257
  image = gr.Image(label="Upload Image", type="filepath")
258
  prompt = gr.Textbox(label="Input Prompt")
259
- seed = gr.Number(label="Random Seed", value=-1)
260
- submit_button = gr.Button("Generate Video")
261
- output_video = gr.Video(label="Generated Video")
262
- output_params = gr.Textbox(label="Output Parameters")
263
- submit_button.click(
264
- fn=generate_video,
265
- inputs=[prompt, seed, image],
266
- outputs=[output_video, output_params],
267
- )
268
- return demo
 
 
 
269
 
270
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
271
  demo = create_gradio_interface()
272
- demo.queue().launch() # Add queue for async
 
1
  import spaces
2
  import gradio as gr
3
+ import argparse # Import argparse
4
  import sys
 
5
  import os
6
  import random
7
  import subprocess
8
+ from PIL import Image # Keep PIL import
9
+
10
+ subprocess.run(['sh', './sky.sh']) # Keep if needed
11
 
 
12
  sys.path.append("./SkyReels-V1")
13
 
14
# "./SkyReels-V1" is appended to sys.path above, so the packages inside that
# checkout are imported by their bare name. The directory name itself contains
# a hyphen and therefore can never be written in an import statement.
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
18
  from diffusers.utils import export_to_video
 
19
 
20
  import torch
21
+ import logging
22
 
23
  torch.backends.cuda.matmul.allow_tf32 = False
24
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 
26
  torch.backends.cudnn.allow_tf32 = False
27
  torch.backends.cudnn.deterministic = False
28
  torch.backends.cudnn.benchmark = False
 
 
29
  torch.set_float32_matmul_precision("highest")
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  logger = logging.getLogger(__name__)
34
 
35
+ # --- Dummy Classes (Moved to skyreelsinfer/__init__.py) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # --- Global Variables and Argument Parsing ---
 
 
38
 
39
+ _predictor = None
40
+ task_type = TaskType.I2V # Default task type. IMPORTANT: Set a default.
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
@spaces.GPU(duration=90)
def init_predictor():
    """Create and initialize the module-level predictor on first use.

    Selects the model id matching the module-level ``task_type``, builds a
    ``SkyReelsVideoSingleGpuInfer`` with quantization and offloading enabled,
    calls its ``initialize()`` and only then stores it in ``_predictor``.
    When ``_predictor`` is already set, nothing is rebuilt and a warning is
    logged instead.

    Raises:
        ValueError: If ``task_type`` is neither ``TaskType.I2V`` nor
            ``TaskType.T2V``.
    """
    global _predictor
    log = logging.getLogger(__name__)

    if _predictor is not None:
        log.warning("Predictor already initialized (should be rare).")
        return

    if task_type == TaskType.I2V:
        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
    elif task_type == TaskType.T2V:
        # Public text-to-video checkpoint; the previous placeholder string
        # could never be resolved to a model.
        model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
    else:
        raise ValueError(f"Invalid task_type: {task_type}")

    predictor = SkyReelsVideoSingleGpuInfer(
        task_type=task_type,
        model_id=model_id,
        quant_model=True,
        is_offload=True,
        offload_config=OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
        ),
    )
    predictor.initialize()

    # Publish only after initialize() succeeded, so a failed start-up is
    # retried on the next request instead of leaving a half-built predictor
    # behind the "is None" check.
    _predictor = predictor
    log.info("Predictor initialized")
71
 
72
+ @spaces.GPU(duration=90) # Needed, because we are saving a file
 
 
73
  def generate_video(prompt, seed, image=None):
74
+ global _predictor
75
  global task_type
 
76
 
 
77
  if seed == -1:
78
+ random.seed() # Use system time for randomness if seed is -1
79
  seed = int(random.randrange(4294967294))
80
+
81
  kwargs = {
82
  "prompt": prompt,
83
  "height": 512,
 
90
  "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
91
  "cfg_for": False,
92
  }
93
+
94
+ if task_type == TaskType.I2V:
95
+ assert image is not None, "Please input an image for I2V task."
96
+ kwargs["image"] = Image.open(image) # Use PIL.Image.open
97
+ elif task_type == TaskType.T2V:
98
+ pass # No image needed.
99
+ else:
100
+ raise ValueError("Invalid Tasktype")
101
 
102
  if _predictor is None:
103
  init_predictor()
104
 
105
+ output = _predictor.infer(**kwargs)
106
 
107
+ save_dir = f"./result/{task_type.name}" # Use task_type.name for directory
108
  os.makedirs(save_dir, exist_ok=True)
109
+ video_out_file = f"{save_dir}/{prompt[:100].replace('/', '')}_{seed}.mp4"
110
  print(f"generate video, local path: {video_out_file}")
111
  export_to_video(output, video_out_file, fps=24)
112
+ return video_out_file, kwargs # Return the file path
113
+
114
 
115
  def create_gradio_interface():
116
+ with gr.Blocks() as demo:
117
+ with gr.Row():
118
+ with gr.Column():
119
  image = gr.Image(label="Upload Image", type="filepath")
120
  prompt = gr.Textbox(label="Input Prompt")
121
+ seed = gr.Number(label="Random Seed", value=-1) # Default to -1
122
+ with gr.Column():
123
+ submit_button = gr.Button("Generate Video")
124
+ output_video = gr.Video(label="Generated Video")
125
+ output_params = gr.Textbox(label="Output Parameters")
126
+
127
+ submit_button.click(
128
+ fn=generate_video,
129
+ inputs=[prompt, seed, image],
130
+ outputs=[output_video, output_params],
131
+ )
132
+ return demo
133
+
134
 
135
  if __name__ == "__main__":
136
+ parser = argparse.ArgumentParser()
137
+ parser.add_argument("--task_type", type=str, default="i2v", choices=["t2v", "i2v"],
138
+ help="Task type, 't2v' for text-to-video, 'i2v' for image-to-video.")
139
+ args = parser.parse_args()
140
+
141
+ # Set the global task_type based on command-line arguments
142
+ if args.task_type == "t2v":
143
+ task_type = TaskType.T2V
144
+ elif args.task_type == "i2v":
145
+ task_type = TaskType.I2V
146
+ # No else needed, default is already set
147
+
148
  demo = create_gradio_interface()
149
+ demo.queue().launch() # Add queue