1inkusFace commited on
Commit
0e0805e
·
verified ·
1 Parent(s): ed58212

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -34
app.py CHANGED
@@ -1,22 +1,22 @@
1
  import spaces
2
  import gradio as gr
3
- import argparse # Import argparse
4
  import sys
5
  import os
6
  import random
7
  import subprocess
8
- from PIL import Image # Keep PIL import
 
9
 
10
- subprocess.run(['sh', './sky.sh']) # Keep if needed
11
 
 
12
  sys.path.append("./SkyReels-V1")
13
 
14
  # Corrected Relative Imports
15
- from skyreelsinfer import TaskType # Now imported correctly
16
- from skyreelsinfer.offload import OffloadConfig
17
- from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer # Import the class
18
  from diffusers.utils import export_to_video
19
-
20
  import torch
21
  import logging
22
 
@@ -30,26 +30,24 @@ torch.set_float32_matmul_precision("highest")
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
33
- logger = logging.getLogger(__name__)
34
-
35
  # --- Dummy Classes (Moved to skyreelsinfer/__init__.py) ---
36
-
37
  # --- Global Variables and Argument Parsing ---
38
 
39
  _predictor = None
40
- task_type = TaskType.I2V # Default task type. IMPORTANT: Set a default.
41
 
42
  @spaces.GPU(duration=90)
43
  def init_predictor():
44
  global _predictor
45
- global task_type # Access the global task_type
46
  logger = logging.getLogger(__name__)
47
 
48
  if _predictor is None:
49
  if task_type == TaskType.I2V:
50
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
51
  elif task_type == TaskType.T2V:
52
- model_id = "your_t2v_model_id" # Replace with your T2V model ID
53
  else:
54
  raise ValueError(f"Invalid task_type: {task_type}")
55
 
@@ -69,13 +67,13 @@ def init_predictor():
69
  else:
70
  logger.warning("Predictor already initialized (should be rare).")
71
 
72
- @spaces.GPU(duration=90) # Needed, because we are saving a file
73
  def generate_video(prompt, seed, image=None):
74
  global _predictor
75
  global task_type
76
 
77
  if seed == -1:
78
- random.seed() # Use system time for randomness if seed is -1
79
  seed = int(random.randrange(4294967294))
80
 
81
  kwargs = {
@@ -93,37 +91,40 @@ def generate_video(prompt, seed, image=None):
93
 
94
  if task_type == TaskType.I2V:
95
  assert image is not None, "Please input an image for I2V task."
96
- kwargs["image"] = Image.open(image) # Use PIL.Image.open
97
  elif task_type == TaskType.T2V:
98
- pass # No image needed.
99
  else:
100
- raise ValueError("Invalid Tasktype")
101
 
102
  if _predictor is None:
103
  init_predictor()
104
 
105
  output = _predictor.infer(**kwargs)
106
 
107
- save_dir = f"./result/{task_type.name}" # Use task_type.name for directory
 
 
 
 
108
  os.makedirs(save_dir, exist_ok=True)
109
- video_out_file = f"{save_dir}/{prompt[:100].replace('/', '')}_{seed}.mp4"
110
  print(f"generate video, local path: {video_out_file}")
111
- export_to_video(output, video_out_file, fps=24)
112
- return video_out_file, kwargs # Return the file path
113
 
114
 
115
  def create_gradio_interface():
116
  with gr.Blocks() as demo:
117
  with gr.Row():
118
- with gr.Column():
119
- image = gr.Image(label="Upload Image", type="filepath")
120
- prompt = gr.Textbox(label="Input Prompt")
121
- seed = gr.Number(label="Random Seed", value=-1) # Default to -1
122
- with gr.Column():
123
- submit_button = gr.Button("Generate Video")
124
- output_video = gr.Video(label="Generated Video")
125
- output_params = gr.Textbox(label="Output Parameters")
126
-
127
  submit_button.click(
128
  fn=generate_video,
129
  inputs=[prompt, seed, image],
@@ -138,12 +139,11 @@ if __name__ == "__main__":
138
  help="Task type, 't2v' for text-to-video, 'i2v' for image-to-video.")
139
  args = parser.parse_args()
140
 
141
- # Set the global task_type based on command-line arguments
142
  if args.task_type == "t2v":
143
  task_type = TaskType.T2V
144
  elif args.task_type == "i2v":
145
  task_type = TaskType.I2V
146
- # No else needed, default is already set
147
 
148
  demo = create_gradio_interface()
149
- demo.queue().launch() # Add queue
 
1
  import spaces
2
  import gradio as gr
3
+ import argparse
4
  import sys
5
  import os
6
  import random
7
  import subprocess
8
+ from PIL import Image
9
+ import numpy as np # Import NumPy
10
 
 
11
 
12
+ # subprocess.run(['sh', './sky.sh']) # Keep if needed
13
  sys.path.append("./SkyReels-V1")
14
 
15
  # Corrected Relative Imports
16
+ from SkyReels-V1.skyreelsinfer import TaskType
17
+ from SkyReels-V1.skyreelsinfer.offload import OffloadConfig
18
+ from SkyReels-V1.skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer # Import the class
19
  from diffusers.utils import export_to_video
 
20
  import torch
21
  import logging
22
 
 
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
 
 
33
  # --- Dummy Classes (Moved to skyreelsinfer/__init__.py) ---
34
+ logger = logging.getLogger(__name__)
35
  # --- Global Variables and Argument Parsing ---
36
 
37
  _predictor = None
38
+ task_type = TaskType.I2V # Default task type.
39
 
40
  @spaces.GPU(duration=90)
41
  def init_predictor():
42
  global _predictor
43
+ global task_type
44
  logger = logging.getLogger(__name__)
45
 
46
  if _predictor is None:
47
  if task_type == TaskType.I2V:
48
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
49
  elif task_type == TaskType.T2V:
50
+ model_id = "your_t2v_model_id" # Replace
51
  else:
52
  raise ValueError(f"Invalid task_type: {task_type}")
53
 
 
67
  else:
68
  logger.warning("Predictor already initialized (should be rare).")
69
 
70
+ @spaces.GPU(duration=90)
71
  def generate_video(prompt, seed, image=None):
72
  global _predictor
73
  global task_type
74
 
75
  if seed == -1:
76
+ random.seed()
77
  seed = int(random.randrange(4294967294))
78
 
79
  kwargs = {
 
91
 
92
  if task_type == TaskType.I2V:
93
  assert image is not None, "Please input an image for I2V task."
94
+ kwargs["image"] = Image.open(image)
95
  elif task_type == TaskType.T2V:
96
+ pass # No image
97
  else:
98
+ raise ValueError(f"Invalid task_type: {task_type}")
99
 
100
  if _predictor is None:
101
  init_predictor()
102
 
103
  output = _predictor.infer(**kwargs)
104
 
105
+ # --- Convert to NumPy, move to CPU, scale, and change dtype ---
106
+ output = (output.cpu().numpy() * 255).astype(np.uint8)
107
+ # --- Convert from B, C, T, H, W to B, T, H, W, C
108
+ output = output.transpose(0, 2, 3, 4, 1)
109
+ save_dir = f"./result/{task_type.name}"
110
  os.makedirs(save_dir, exist_ok=True)
111
+ video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
112
  print(f"generate video, local path: {video_out_file}")
113
+ export_to_video(output, video_out_file, fps=24) # Pass fps
114
+ return video_out_file, kwargs
115
 
116
 
117
  def create_gradio_interface():
118
  with gr.Blocks() as demo:
119
  with gr.Row():
120
+ with gr.Column():
121
+ image = gr.Image(label="Upload Image", type="filepath")
122
+ prompt = gr.Textbox(label="Input Prompt")
123
+ seed = gr.Number(label="Random Seed", value=-1)
124
+ with gr.Column():
125
+ submit_button = gr.Button("Generate Video")
126
+ output_video = gr.Video(label="Generated Video")
127
+ output_params = gr.Textbox(label="Output Parameters")
 
128
  submit_button.click(
129
  fn=generate_video,
130
  inputs=[prompt, seed, image],
 
139
  help="Task type, 't2v' for text-to-video, 'i2v' for image-to-video.")
140
  args = parser.parse_args()
141
 
 
142
  if args.task_type == "t2v":
143
  task_type = TaskType.T2V
144
  elif args.task_type == "i2v":
145
  task_type = TaskType.I2V
146
+ # No else, default is already set.
147
 
148
  demo = create_gradio_interface()
149
+ demo.queue().launch()