rahul7star commited on
Commit
62c6eb6
·
verified ·
1 Parent(s): 0de6c30

Update app_g.py

Browse files
Files changed (1) hide show
  1. app_g.py +40 -1
app_g.py CHANGED
@@ -10,6 +10,42 @@ import numpy as np
10
  from diffsynth import ModelManager, PusaMultiFramesPipeline, PusaV2VPipeline, WanVideoPusaPipeline, save_video
11
  import tempfile
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class PusaVideoDemo:
14
  def __init__(self):
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -38,6 +74,8 @@ class PusaVideoDemo:
38
  torch_dtype=torch.bfloat16,
39
  )
40
  print("Models loaded successfully!")
 
 
41
 
42
  def load_lora_and_get_pipe(self, pipe_type, lora_path, lora_alpha):
43
  """Load LoRA and return appropriate pipeline"""
@@ -244,7 +282,7 @@ class PusaVideoDemo:
244
 
245
  except Exception as e:
246
  return None, f"Error: {str(e)}"
247
-
248
  def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
249
  negative_prompt, progress=gr.Progress()):
250
  """Generate video from text prompt"""
@@ -1191,6 +1229,7 @@ def create_demo():
1191
  return demo
1192
 
1193
  if __name__ == "__main__":
 
1194
  demo = create_demo()
1195
  demo.launch(
1196
  share=False,
 
10
  from diffsynth import ModelManager, PusaMultiFramesPipeline, PusaV2VPipeline, WanVideoPusaPipeline, save_video
11
  import tempfile
12
 
13
+
14
+
15
+ # Constants
16
+ import os
17
+ from huggingface_hub import snapshot_download
18
+
19
+ # Constants
20
+ MODEL_ZOO_DIR = "./model_zoo"
21
+ PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
22
+ WAN_SUBFOLDER = "Wan2.1-T2V-14B"
23
+ WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
24
+ LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
25
+
26
+ # Ensure model and weights are downloaded
27
+ def ensure_model_downloaded():
28
+ if not os.path.exists(PUSA_DIR):
29
+ print("Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...")
30
+ snapshot_download(
31
+ repo_id="RaphaelLiu/PusaV1",
32
+ local_dir=PUSA_DIR,
33
+ repo_type="model",
34
+ local_dir_use_symlinks=False,
35
+ )
36
+ print("✅ PusaV1 downloaded.")
37
+
38
+ if not os.path.exists(WAN_MODEL_PATH):
39
+ print("Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1/Wan2.1-T2V-14B ...")
40
+ snapshot_download(
41
+ repo_id="Wan-AI/Wan2.1-T2V-14B",
42
+ local_dir=WAN_MODEL_PATH, # Changed from WAN_DIR to WAN_MODEL_PATH
43
+ repo_type="model",
44
+ local_dir_use_symlinks=False,
45
+ )
46
+ print("✅ Wan2.1-T2V-14B downloaded.")
47
+
48
+
49
  class PusaVideoDemo:
50
  def __init__(self):
51
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
74
  torch_dtype=torch.bfloat16,
75
  )
76
  print("Models loaded successfully!")
77
+
78
+
79
 
80
  def load_lora_and_get_pipe(self, pipe_type, lora_path, lora_alpha):
81
  """Load LoRA and return appropriate pipeline"""
 
282
 
283
  except Exception as e:
284
  return None, f"Error: {str(e)}"
285
+
286
  def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
287
  negative_prompt, progress=gr.Progress()):
288
  """Generate video from text prompt"""
 
1229
  return demo
1230
 
1231
  if __name__ == "__main__":
1232
+ ensure_model_downloaded()
1233
  demo = create_demo()
1234
  demo.launch(
1235
  share=False,