1inkusFace committed
Commit 4d4355a · verified · 1 Parent(s): 5abf5ea

Update app.py
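Select the task type automatically instead of via CLI: generate_video now chooses T2V or I2V from whether an image is supplied, builds the matching kwargs and model id, and lazily initializes the predictor on first call. The module-level init_predictor() and the --task_type argparse flag are removed, and dummy classes plus a SkyReelsVideoSingleGpuInfer implementation are inlined so app.py can run standalone.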

Files changed (1)
  1. app.py +207 -70
app.py CHANGED
@@ -27,100 +27,247 @@ torch.backends.cudnn.allow_tf32 = False
 torch.backends.cudnn.deterministic = False
 torch.backends.cudnn.benchmark = False
 torch.set_float32_matmul_precision("highest")
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
 logger = logging.getLogger(__name__)
 
-_predictor = None
-task_type = TaskType.I2V
-
-
-def init_predictor():
-    global _predictor
-    global task_type  # Access global task_type
-    logger = logging.getLogger(__name__)  # Logger within function
-
-    if _predictor is None:
-        if task_type == TaskType.I2V:
-            model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
-        elif task_type == TaskType.T2V:
-            model_id = "your_t2v_model_id"  # REPLACE with your T2V model
-        else:
-            raise ValueError(f"Invalid task_type: {task_type}")
-
-        _predictor = SkyReelsVideoSingleGpuInfer(
-            task_type=task_type,  # Pass the task_type
-            model_id=model_id,
-            quant_model=True,
-            is_offload=True,
-            offload_config=OffloadConfig(
-                high_cpu_memory=True,
-                parameters_level=True,
-                compiler_transformer=False,
-            ),
-        )
-        _predictor.initialize()
-        logger.info("Predictor initialized")
-    else:
-        logger.warning("Predictor already initialized (should be rare).")
+# --- Dummy Classes (keep for standalone execution) ---
+from enum import Enum  # stdlib; an Enum member gives task_type.name, used below for save_dir
+
+class OffloadConfig:
+    def __init__(self, high_cpu_memory=False, parameters_level=False, compiler_transformer=False, compiler_cache=""):
+        self.high_cpu_memory = high_cpu_memory
+        self.parameters_level = parameters_level
+        self.compiler_transformer = compiler_transformer
+        self.compiler_cache = compiler_cache
+
+class TaskType(Enum):  # Keep here for infer
+    T2V = 0
+    I2V = 1
+
+class LlamaModel:
+    @staticmethod
+    def from_pretrained(*args, **kwargs):
+        return LlamaModel()
+    def to(self, device):
+        return self
+
+class HunyuanVideoTransformer3DModel:
+    @staticmethod
+    def from_pretrained(*args, **kwargs):
+        return HunyuanVideoTransformer3DModel()
+    def to(self, device):
+        return self
+
+class SkyreelsVideoPipeline:
+    @staticmethod
+    def from_pretrained(*args, **kwargs):
+        return SkyreelsVideoPipeline()
+    def to(self, device):
+        return self
+    def __call__(self, *args, **kwargs):
+        frames = torch.randn(1, 3, 16, 512, 512)  # dummy output with the expected shape
+        return type('obj', (object,), {'frames': [frames]})()
+    def __init__(self):
+        super().__init__()
+        self._modules = OrderedDict()
+        self.vae = self.VAE()
+        self._modules["vae"] = self.vae
+
+    def named_children(self):
+        return self._modules.items()
+
+    class VAE:
+        def enable_tiling(self):
+            pass
+
+def quantize_(*args, **kwargs):
+    return
+
+def float8_weight_only():
+    return
+
+# --- End Dummy Classes ---
+
+class SkyReelsVideoSingleGpuInfer:
+    def _load_model(self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True):
+        logger.info(f"load model model_id:{model_id} quant_model:{quant_model}")
+        text_encoder = LlamaModel.from_pretrained(
+            base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
+        ).to("cpu")
+        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+            model_id, torch_dtype=torch.bfloat16, device="cpu"
+        ).to("cpu")
+
+        if quant_model:
+            quantize_(text_encoder, float8_weight_only())
+            text_encoder.to("cpu")
+            torch.cuda.empty_cache()
+            quantize_(transformer, float8_weight_only())
+            transformer.to("cpu")
+            torch.cuda.empty_cache()
+
+        pipe = SkyreelsVideoPipeline.from_pretrained(
+            base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
+        ).to("cpu")
+        pipe.vae.enable_tiling()
+        torch.cuda.empty_cache()
+        return pipe
+
+    def __init__(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+        enable_cfg_parallel: bool = True,
+    ):
+        self.task_type = task_type
+        self.model_id = model_id
+        self.quant_model = quant_model
+        self.is_offload = is_offload
+        self.offload_config = offload_config
+        self.enable_cfg_parallel = enable_cfg_parallel
+        self.pipe = None
+        self.is_initialized = False
+        self.gpu_device = None
+
+    def initialize(self):
+        """Initializes the model and moves it to the GPU."""
+        if self.is_initialized:
+            return
+
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available. Cannot initialize model.")
+
+        self.gpu_device = "cuda:0"
+        self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
+
+        if self.is_offload:
+            pass
+        else:
+            self.pipe.to(self.gpu_device)
+
+        if self.offload_config.compiler_transformer:
+            torch._dynamo.config.suppress_errors = True
+            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
+            self.pipe.transformer = torch.compile(
+                self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
+            )
+        if self.offload_config.compiler_transformer:
+            self.warm_up()
+        self.is_initialized = True
+
+    def warm_up(self):
+        if not self.is_initialized:
+            raise RuntimeError("Model must be initialized before warm-up.")
+
+        init_kwargs = {
+            "prompt": "A woman is dancing in a room",
+            "height": 544,
+            "width": 960,
+            "guidance_scale": 6,
+            "num_inference_steps": 1,
+            "negative_prompt": "bad quality",
+            "num_frames": 16,
+            "generator": torch.Generator(self.gpu_device).manual_seed(42),
+            "embedded_guidance_scale": 1.0,
+        }
+        if self.task_type == TaskType.I2V:
+            init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+        self.pipe(**init_kwargs)
+        logger.info("Warm-up complete.")
+
+    def infer(self, **kwargs):
+        """Handles inference requests."""
+        if not self.is_initialized:
+            self.initialize()
+        if "seed" in kwargs:
+            kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
+            del kwargs["seed"]
+        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+        result = self.pipe(**kwargs).frames[0]
+        return result
+
+_predictor = None  # Global _predictor
 
 @spaces.GPU(duration=90)
 def generate_video(prompt, seed, image=None):
     global _predictor
-    global task_type
 
     if seed == -1:
         random.seed()
         seed = int(random.randrange(4294967294))
 
-    kwargs = {
-        "prompt": prompt,
-        "height": 512,
-        "width": 512,
-        "num_frames": 97,
-        "num_inference_steps": 30,
-        "seed": seed,
-        "guidance_scale": 6.0,
-        "embedded_guidance_scale": 1.0,
-        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
-        "cfg_for": False,
-    }
-
-    if task_type == TaskType.I2V:
-        assert image is not None, "Please input an image for I2V task."
-        kwargs["image"] = Image.open(image)
-    elif task_type == TaskType.T2V:
-        pass  # No image necessary
-    else:
-        raise ValueError(f"Invalid Tasktype: {task_type}")
+    if image is None:
+        task_type = TaskType.T2V
+        model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"  # Need to change this when you use the real model.
+        kwargs = {  # Text-to-Video kwargs
+            "prompt": prompt,
+            "height": 512,
+            "width": 512,
+            "num_frames": 16,  # Use a reasonable default
+            "num_inference_steps": 30,
+            "seed": seed,
+            "guidance_scale": 7.5,  # Adjust as needed
+            "negative_prompt": "bad quality, worst quality",  # Your negative prompt
+        }
+    else:
+        task_type = TaskType.I2V
+        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
+        kwargs = {  # Image-to-Video kwargs
+            "prompt": prompt,
+            "image": Image.open(image),
+            "height": 512,
+            "width": 512,
+            "num_frames": 97,
+            "num_inference_steps": 30,
+            "seed": seed,
+            "guidance_scale": 6.0,
+            "embedded_guidance_scale": 1.0,
+            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+            "cfg_for": False,
+        }
 
     if _predictor is None:
-        init_predictor()
+        # Initialize _predictor based on task type
+        _predictor = SkyReelsVideoSingleGpuInfer(
+            task_type=task_type,
+            model_id=model_id,
+            quant_model=True,
+            is_offload=True,
+            offload_config=OffloadConfig(
+                high_cpu_memory=True,
+                parameters_level=True,
+                compiler_transformer=False,  # Change to True for warm-up
+            ),
+        )
+        _predictor.initialize()
+        logger.info("Predictor initialized")
 
     output = _predictor.infer(**kwargs)
 
+    # Convert and save video
     output = (output.cpu().numpy() * 255).astype(np.uint8)
     output = output.transpose(0, 2, 3, 4, 1)
 
-    save_dir = f"./result/{task_type.name}"
+    save_dir = f"./result/{task_type.name}"  # Use task_type.name
     os.makedirs(save_dir, exist_ok=True)
     video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
     print(f"generate video, local path: {video_out_file}")
-    export_to_video(output, video_out_file, fps=24)
+    export_to_video(output, video_out_file, fps=24)  # Use a reasonable FPS
    return video_out_file, kwargs
 
-
 def create_gradio_interface():
     with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column():
                 image = gr.Image(label="Upload Image", type="filepath")
                 prompt = gr.Textbox(label="Input Prompt")
                 seed = gr.Number(label="Random Seed", value=-1)
             with gr.Column():
                 submit_button = gr.Button("Generate Video")
                 output_video = gr.Video(label="Generated Video")
                 output_params = gr.Textbox(label="Output Parameters")
+
         submit_button.click(
             fn=generate_video,
             inputs=[prompt, seed, image],
@@ -130,15 +277,5 @@ def create_gradio_interface():
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--task_type", type=str, default="i2v", choices=["t2v", "i2v"],
-                        help="Task type, 't2v' for text-to-video, 'i2v' for image-to-video.")
-    args = parser.parse_args()
-
-    if args.task_type == "t2v":
-        task_type = TaskType.T2V
-    elif args.task_type == "i2v":
-        task_type = TaskType.I2V
-
     demo = create_gradio_interface()
     demo.queue().launch()
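
For reference, a minimal smoke test of the new dispatch logic (a sketch, not part of the commit). It assumes app.py's dependencies (torch, gradio, spaces, PIL, the export_to_video helper) are installed, a CUDA GPU is available, and the file is importable as app; input.jpg is a hypothetical local file.

# Hypothetical local check of the updated generate_video (not in the commit).
from app import generate_video

# No image supplied -> T2V branch: Skywork/SkyReels-V1-Hunyuan-T2V, 16 frames.
video_path, params = generate_video("a cat walking on grass", seed=42, image=None)
print(video_path)  # ./result/T2V/a cat walking on grass_42.mp4

# Image filepath supplied -> I2V branch: Skywork/SkyReels-V1-Hunyuan-I2V, 97 frames.
video_path, params = generate_video("the person waves", seed=-1, image="input.jpg")
print(params["num_frames"])  # 97

Note that _predictor is created once and cached globally, so the first call pins the task type and model for the lifetime of the process; mixing T2V and I2V calls in one session would reuse whichever predictor was built first.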