1inkusFace commited on
Commit
f7c613a
·
verified ·
1 Parent(s): 84c1e9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -232,6 +232,9 @@ def generate_video(prompt, seed, image=None):
232
  else:
233
  task_type = TaskType.I2V
234
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
 
 
 
235
  kwargs = {
236
  "prompt": prompt,
237
  "image": Image.open(image),
@@ -240,6 +243,7 @@ def generate_video(prompt, seed, image=None):
240
  "num_frames": 97,
241
  "num_inference_steps": 30,
242
  "seed": seed,
 
243
  "guidance_scale": 6.0,
244
  "embedded_guidance_scale": 1.0,
245
  "negative_prompt": "Aerial view, low quality, bad hands",
@@ -257,20 +261,20 @@ def generate_video(prompt, seed, image=None):
257
  parameters_level=True,
258
  compiler_transformer=False,
259
  ),
260
- )
261
  _predictor.initialize()
262
  logger.info("Predictor initialized")
263
-
264
- output = _predictor.infer(**kwargs)
265
-
266
- output = (output.cpu().numpy() * 255).astype(np.uint8)
267
- output = output.transpose(0, 2, 3, 4, 1)
268
 
269
  save_dir = f"./result"
270
  os.makedirs(save_dir, exist_ok=True)
271
  video_out_file = f"{save_dir}/{seed}.mp4"
272
  print(f"generate video, local path: {video_out_file}")
273
- export_to_video([output[0, t] for t in range(output.shape[1])], video_out_file, fps=24)
274
  return video_out_file, kwargs
275
 
276
  def create_gradio_interface():
 
232
  else:
233
  task_type = TaskType.I2V
234
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
235
+ seed = 43
236
+ generator = torch.Generator(device="cuda").manual_seed(seed)
237
+
238
  kwargs = {
239
  "prompt": prompt,
240
  "image": Image.open(image),
 
243
  "num_frames": 97,
244
  "num_inference_steps": 30,
245
  "seed": seed,
246
+ "generator": generator,
247
  "guidance_scale": 6.0,
248
  "embedded_guidance_scale": 1.0,
249
  "negative_prompt": "Aerial view, low quality, bad hands",
 
261
  parameters_level=True,
262
  compiler_transformer=False,
263
  ),
264
+ ).to("cuda")
265
  _predictor.initialize()
266
  logger.info("Predictor initialized")
267
+ with torch.no_grad():
268
+ output = _predictor.infer(**kwargs)
269
+ out_samples.extend(output.frames[0])
270
+ #output = (output.cpu().numpy() * 255).astype(np.uint8)
271
+ #output = output.transpose(0, 2, 3, 4, 1)
272
 
273
  save_dir = f"./result"
274
  os.makedirs(save_dir, exist_ok=True)
275
  video_out_file = f"{save_dir}/{seed}.mp4"
276
  print(f"generate video, local path: {video_out_file}")
277
+ export_to_video(out_samples, video_out_file, fps=24)
278
  return video_out_file, kwargs
279
 
280
  def create_gradio_interface():