1inkusFace commited on
Commit
58cc987
·
verified ·
1 Parent(s): b5ba988

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -15
app.py CHANGED
@@ -6,17 +6,16 @@ import time
6
  import os
7
  import random
8
  import subprocess
 
9
 
10
- subprocess.run(['sh', './sky.sh'])
11
  sys.path.append("./SkyReels-V1")
12
 
13
  from skyreelsinfer import TaskType
14
  from skyreelsinfer.offload import OffloadConfig
15
- from skyreelsinfer.skyreels_video_infer import Predictor
16
  from diffusers.utils import export_to_video
17
- from diffusers.utils import load_image
18
-
19
- task_type = None
20
 
21
  import torch
22
 
@@ -32,15 +31,169 @@ torch.set_float32_matmul_precision("highest")
32
 
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
 
35
- import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  _predictor = None # Global variable to hold the predictor
38
 
39
  @spaces.GPU(duration=90) # We DO need @spaces.GPU on init_predictor
40
  def init_predictor():
41
  global _predictor
42
- import logging
43
- logger = logging.getLogger(__name__) #Logger
44
 
45
  if _predictor is None:
46
  _predictor = SkyReelsVideoSingleGpuInfer(
@@ -61,9 +214,11 @@ def init_predictor():
61
 
62
 
63
 
64
- @spaces.GPU(duration=90)
65
  def generate_video(prompt, seed, image=None):
66
  global task_type
 
 
67
  print(f"image:{type(image)}")
68
  if seed == -1:
69
  random.seed(time.time())
@@ -81,19 +236,20 @@ def generate_video(prompt, seed, image=None):
81
  "cfg_for": False,
82
  }
83
  assert image is not None, "please input image"
84
- kwargs["image"] = load_image(image=image)
85
- global _predictor
86
 
87
  if _predictor is None:
88
- init_predictor()
 
 
89
 
90
- output = predictor.infer(**kwargs)
91
  save_dir = f"./result/{task_type}"
92
  os.makedirs(save_dir, exist_ok=True)
93
  video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
94
  print(f"generate video, local path: {video_out_file}")
95
  export_to_video(output, video_out_file, fps=24)
96
- return video_out_file, kwargs
97
 
98
  def create_gradio_interface():
99
  with gr.Blocks() as demo:
@@ -113,4 +269,4 @@ def create_gradio_interface():
113
 
114
  if __name__ == "__main__":
115
  demo = create_gradio_interface()
116
- demo.launch()
 
6
  import os
7
  import random
8
  import subprocess
9
+ from PIL import Image # Import PIL.Image
10
 
11
+ # subprocess.run(['sh', './sky.sh']) # Keep this if needed for setup
12
  sys.path.append("./SkyReels-V1")
13
 
14
  from skyreelsinfer import TaskType
15
  from skyreelsinfer.offload import OffloadConfig
16
+ # from skyreelsinfer.skyreels_video_infer import Predictor # Correct: No Predictor import.
17
  from diffusers.utils import export_to_video
18
+ # from diffusers.utils import load_image # Removed: Use PIL directly
 
 
19
 
20
  import torch
21
 
 
31
 
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33
 
34
+ import logging # Correct: Keep logging
35
+
36
+ # --- Dummy Classes (Keep these for standalone execution) ---
37
+ class OffloadConfig:
38
+ def __init__(self, high_cpu_memory=False, parameters_level=False, compiler_transformer=False, compiler_cache=""):
39
+ self.high_cpu_memory = high_cpu_memory
40
+ self.parameters_level = parameters_level
41
+ self.compiler_transformer = compiler_transformer
42
+ self.compiler_cache = compiler_cache
43
+
44
+ class TaskType:
45
+ T2V = 0
46
+ I2V = 1
47
+
48
+ class LlamaModel:
49
+ @staticmethod
50
+ def from_pretrained(*args, **kwargs):
51
+ return LlamaModel()
52
+ def to(self, device):
53
+ return self
54
+
55
+ class HunyuanVideoTransformer3DModel:
56
+ @staticmethod
57
+ def from_pretrained(*args, **kwargs):
58
+ return HunyuanVideoTransformer3DModel()
59
+ def to(self, device):
60
+ return self
61
+
62
+ class SkyreelsVideoPipeline:
63
+ @staticmethod
64
+ def from_pretrained(*args, **kwargs):
65
+ return SkyreelsVideoPipeline()
66
+ def to(self, device):
67
+ return self
68
+ def __call__(self, *args, **kwargs):
69
+ frames = [torch.randn(1, 3, 512, 512)] # Dummy frames
70
+ return type('obj', (object,), {'frames' : frames})()
71
+ class vae:
72
+ @staticmethod
73
+ def enable_tiling():
74
+ return
75
+
76
+ def quantize_(*args, **kwargs):
77
+ return
78
+
79
+ def float8_weight_only():
80
+ return
81
+ # --- End of Dummy Classes/Functions ---
82
+ logger = logging.getLogger(__name__)
83
+
84
+ class SkyReelsVideoSingleGpuInfer: # No more multiprocessing!
85
+ def __init__(
86
+ self,
87
+ task_type: TaskType,
88
+ model_id: str,
89
+ quant_model: bool = True,
90
+ is_offload: bool = True,
91
+ offload_config: OffloadConfig = OffloadConfig(),
92
+ enable_cfg_parallel: bool = True, # Remove world_size, local_rank
93
+ ):
94
+ self.task_type = task_type
95
+ self.model_id = model_id
96
+ self.quant_model = quant_model
97
+ self.is_offload = is_offload
98
+ self.offload_config = offload_config
99
+ self.enable_cfg_parallel = enable_cfg_parallel # Keep this
100
+ self.pipe = None
101
+ self.is_initialized = False
102
+ self.gpu_device = None
103
+
104
+ def _load_model(self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True):
105
+ logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
106
+ text_encoder = LlamaModel.from_pretrained(
107
+ base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
108
+ ).to("cpu")
109
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
110
+ model_id, torch_dtype=torch.bfloat16, device="cpu"
111
+ ).to("cpu")
112
+
113
+ if quant_model:
114
+ quantize_(text_encoder, float8_weight_only())
115
+ text_encoder.to("cpu")
116
+ torch.cuda.empty_cache()
117
+ quantize_(transformer, float8_weight_only())
118
+ transformer.to("cpu")
119
+ torch.cuda.empty_cache()
120
+
121
+ pipe = SkyreelsVideoPipeline.from_pretrained(
122
+ base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
123
+ ).to("cpu")
124
+ pipe.vae.enable_tiling()
125
+ torch.cuda.empty_cache()
126
+ return pipe
127
+
128
+ def initialize(self):
129
+ """Initializes the model and moves it to the GPU."""
130
+ if self.is_initialized:
131
+ return
132
+
133
+ if not torch.cuda.is_available():
134
+ raise RuntimeError("CUDA is not available. Cannot initialize model.")
135
+
136
+ self.gpu_device = "cuda:0" # Always cuda:0 in single-GPU case
137
+ self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
138
+
139
+ # Simplified: No need for max_batch_dim_size with single GPU
140
+ if self.is_offload:
141
+ pass
142
+ else:
143
+ self.pipe.to(self.gpu_device)
144
+
145
+ if self.offload_config.compiler_transformer:
146
+ torch._dynamo.config.suppress_errors = True
147
+ os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
148
+ # No world_size in cache directory name
149
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
150
+ self.pipe.transformer = torch.compile(
151
+ self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
152
+ )
153
+ if self.offload_config.compiler_transformer: # Only warm up if compiling
154
+ self.warm_up()
155
+ self.is_initialized = True
156
+
157
+ def warm_up(self):
158
+ if not self.is_initialized:
159
+ raise RuntimeError("Model must be initialized before warm-up.")
160
+
161
+ init_kwargs = {
162
+ "prompt": "A woman is dancing in a room",
163
+ "height": 544,
164
+ "width": 960,
165
+ "guidance_scale": 6,
166
+ "num_inference_steps": 1,
167
+ "negative_prompt": "bad quality",
168
+ "num_frames": 16,
169
+ "generator": torch.Generator(self.gpu_device).manual_seed(42),
170
+ "embedded_guidance_scale": 1.0,
171
+ }
172
+ if self.task_type == TaskType.I2V:
173
+ init_kwargs["image"] = Image.new("RGB", (544,960), color="black") #Dummy
174
+ self.pipe(**init_kwargs)
175
+ logger.info("Warm-up complete.")
176
+
177
+ def infer(self, **kwargs):
178
+ """Handles inference requests."""
179
+ if not self.is_initialized:
180
+ self.initialize()
181
+ if "seed" in kwargs:
182
+ kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
183
+ del kwargs["seed"]
184
+ assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
185
+ result = self.pipe(**kwargs).frames[0]
186
+ return result
187
+
188
+
189
+ # --- Spaces Integration ---
190
 
191
  _predictor = None # Global variable to hold the predictor
192
 
193
  @spaces.GPU(duration=90) # We DO need @spaces.GPU on init_predictor
194
  def init_predictor():
195
  global _predictor
196
+ logger = logging.getLogger(__name__) # Correct: Logger inside function
 
197
 
198
  if _predictor is None:
199
  _predictor = SkyReelsVideoSingleGpuInfer(
 
214
 
215
 
216
 
217
+ @spaces.GPU(duration=90) # Now needed, because we write files.
218
  def generate_video(prompt, seed, image=None):
219
  global task_type
220
+ global _predictor # Correct: Access global _predictor
221
+
222
  print(f"image:{type(image)}")
223
  if seed == -1:
224
  random.seed(time.time())
 
236
  "cfg_for": False,
237
  }
238
  assert image is not None, "please input image"
239
+ # kwargs["image"] = load_image(image=image) # Removed: load image directly with PIL
240
+ kwargs["image"] = Image.open(image) # Use PIL.Image.open
241
 
242
  if _predictor is None:
243
+ init_predictor()
244
+
245
+ output = _predictor.infer(**kwargs) # Correct: Use _predictor
246
 
 
247
  save_dir = f"./result/{task_type}"
248
  os.makedirs(save_dir, exist_ok=True)
249
  video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
250
  print(f"generate video, local path: {video_out_file}")
251
  export_to_video(output, video_out_file, fps=24)
252
+ return video_out_file, kwargs # Correct: Return filename, kwargs
253
 
254
  def create_gradio_interface():
255
  with gr.Blocks() as demo:
 
269
 
270
  if __name__ == "__main__":
271
  demo = create_gradio_interface()
272
+ demo.queue().launch() # Add queue for async