1inkusFace commited on
Commit
c459371
·
verified ·
1 Parent(s): a40e1b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
36
 
37
  _predictor = None
38
  task_type = TaskType.I2V # Default task type.
39
-
40
  @spaces.GPU(duration=90)
41
  def init_predictor():
42
  global _predictor
@@ -93,7 +92,7 @@ def generate_video(prompt, seed, image=None):
93
  assert image is not None, "Please input an image for I2V task."
94
  kwargs["image"] = Image.open(image)
95
  elif task_type == TaskType.T2V:
96
- pass # No image
97
  else:
98
  raise ValueError(f"Invalid task_type: {task_type}")
99
 
@@ -104,13 +103,13 @@ def generate_video(prompt, seed, image=None):
104
 
105
  # --- Convert to NumPy, move to CPU, scale, and change dtype ---
106
  output = (output.cpu().numpy() * 255).astype(np.uint8)
107
- # --- Convert from B, C, T, H, W to B, T, H, W, C
108
- output = output.transpose(0, 2, 3, 4, 1)
109
  save_dir = f"./result/{task_type.name}"
110
  os.makedirs(save_dir, exist_ok=True)
111
  video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
112
  print(f"generate video, local path: {video_out_file}")
113
- export_to_video(output, video_out_file, fps=24) # Pass fps
114
  return video_out_file, kwargs
115
 
116
 
@@ -143,7 +142,6 @@ if __name__ == "__main__":
143
  task_type = TaskType.T2V
144
  elif args.task_type == "i2v":
145
  task_type = TaskType.I2V
146
- # No else, default is already set.
147
 
148
  demo = create_gradio_interface()
149
  demo.queue().launch()
 
36
 
37
  _predictor = None
38
  task_type = TaskType.I2V # Default task type.
 
39
  @spaces.GPU(duration=90)
40
  def init_predictor():
41
  global _predictor
 
92
  assert image is not None, "Please input an image for I2V task."
93
  kwargs["image"] = Image.open(image)
94
  elif task_type == TaskType.T2V:
95
+ pass
96
  else:
97
  raise ValueError(f"Invalid task_type: {task_type}")
98
 
 
103
 
104
  # --- Convert to NumPy, move to CPU, scale, and change dtype ---
105
  output = (output.cpu().numpy() * 255).astype(np.uint8)
106
+ output = output.transpose(0, 2, 3, 4, 1) #Correct transpose.
107
+
108
  save_dir = f"./result/{task_type.name}"
109
  os.makedirs(save_dir, exist_ok=True)
110
  video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
111
  print(f"generate video, local path: {video_out_file}")
112
+ export_to_video(output, video_out_file, fps=24)
113
  return video_out_file, kwargs
114
 
115
 
 
142
  task_type = TaskType.T2V
143
  elif args.task_type == "i2v":
144
  task_type = TaskType.I2V
 
145
 
146
  demo = create_gradio_interface()
147
  demo.queue().launch()