rahul7star commited on
Commit
4046baa
·
verified ·
1 Parent(s): 1a49bcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -87
app.py CHANGED
@@ -1,99 +1,73 @@
1
- import os
2
- import shutil
3
- import subprocess
4
- import gradio as gr
5
- from huggingface_hub import snapshot_download
6
- from huggingface_hub import snapshot_download
7
- import spaces
8
-
9
  import os
10
  import sys
11
  import gradio as gr
12
- import tempfile
13
- import torch
14
- import spaces
15
- # Add PusaV1 to path to resolve diffsynth imports
16
- sys.path.append(os.path.abspath("PusaV1"))
17
-
18
- # Import the actual model runner
19
- from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
20
-
21
- # Define paths
22
- WAN_MODEL_DIR = "./model_zoo/Wan2.1-T2V-14B"
23
- LORA_PATH = "./model_zoo/PusaV1/pusa_v1.pt"
24
-
25
- MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
26
- HF_REPO = "RaphaelLiu/PusaV1"
27
- MODEL_ZOO_DIR = "./model_zoo"
28
- MODEL_PARTS_DIR = os.path.join(MODEL_ZOO_DIR, MODEL_SUBFOLDER)
29
- FINAL_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
30
- PUSA_SCRIPT_PATH = "PusaV1/examples/pusavideo/wan_14b_text_to_video_pusa.py"
31
-
32
-
33
- def download_model_subset():
34
- if os.path.exists(FINAL_MODEL_PATH):
35
- print("✅ Model already exists. Skipping download.")
36
- return
37
-
38
- print("⏬ Downloading model parts...")
39
- snapshot_download(
40
- repo_id=HF_REPO,
41
- repo_type="model",
42
- local_dir=MODEL_ZOO_DIR,
43
- local_dir_use_symlinks=False,
44
- allow_patterns=[f"{MODEL_SUBFOLDER}/*"]
45
- )
46
- os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
47
-
48
- part_files = sorted([
49
- os.path.join(MODEL_PARTS_DIR, f)
50
- for f in os.listdir(MODEL_PARTS_DIR)
51
- if f.startswith("pusa_v1.pt.part")
52
- ])
53
-
54
- print("🧩 Stitching model parts...")
55
- with open(FINAL_MODEL_PATH, 'wb') as f_out:
56
- for part in part_files:
57
- with open(part, 'rb') as f_in:
58
- shutil.copyfileobj(f_in, f_out)
59
-
60
- print(f"✅ Final model saved at {FINAL_MODEL_PATH}")
61
-
62
-
63
- @spaces.GPU
64
- def generate_video(prompt: str):
65
- try:
66
- # Load model manager
67
- manager = ModelManager(base_model_dir=WAN_MODEL_DIR)
68
- model = manager.load_model()
69
-
70
- # Create video pipeline and apply LoRA
71
- pipeline = WanVideoPusaPipeline(model=model)
72
- pipeline.set_lora_adapters(LORA_PATH)
73
-
74
- # Generate video
75
- result = pipeline(prompt=prompt)
76
 
77
- # Save video to a temporary file
78
- tmp_dir = tempfile.mkdtemp()
79
- video_path = os.path.join(tmp_dir, "output.mp4")
80
- save_video(result, video_path)
81
 
82
- return video_path
 
83
 
84
- except Exception as e:
85
- print(f"[ERROR] {e}")
86
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Gradio UI
89
  with gr.Blocks() as demo:
90
- gr.Markdown("## 🎥 PusaV1 Text-to-Video Generator")
91
- gr.Markdown("Describe a scene and generate a short video using Wan2.1-T2V + Pusa LoRA!")
92
-
93
- prompt_input = gr.Textbox(label="Enter Prompt", lines=4, placeholder="E.g. A coral reef full of colorful fish...")
94
  generate_btn = gr.Button("Generate Video")
95
- video_output = gr.Video(label="Generated Video")
96
 
97
- generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
98
 
99
- demo.launch()
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import gradio as gr
4
+ import subprocess
5
+ from huggingface_hub import snapshot_download
6
+ from gradio import Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Use GPU
9
+ @Spaces.GPU
10
+ def dummy(): pass
 
11
 
12
+ # Add PusaV1 to Python path
13
+ sys.path.append(os.path.abspath("PusaV1"))
14
 
15
+ # Install flash-attn in a safe way
16
+ subprocess.run('pip install flash-attn --no-build-isolation', shell=True, env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"})
17
+
18
+ # Download Wan2.1-T2V-14B model only (not the full repo)
19
+ WAN_MODEL_DIR = "/tmp/model_zoo/Wan2.1-T2V-14B"
20
+ os.makedirs(WAN_MODEL_DIR, exist_ok=True)
21
+ snapshot_download(
22
+ repo_id="RaphaelLiu/PusaV1",
23
+ allow_patterns=["Wan2.1-T2V-14B/*"],
24
+ local_dir=WAN_MODEL_DIR,
25
+ local_dir_use_symlinks=False,
26
+ resume_download=True
27
+ )
28
+
29
+ # Pusa Inference imports
30
+ from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData
31
+
32
+ def generate_video(prompt, lora_upload):
33
+ # Prepare LoRA path
34
+ if lora_upload is not None:
35
+ lora_path = lora_upload
36
+ else:
37
+ # Default PusaV1 LoRA
38
+ default_lora_dir = "/tmp/model_zoo/PusaV1"
39
+ os.makedirs(default_lora_dir, exist_ok=True)
40
+
41
+ # Download all pusa_v1.pt.part* and merge them
42
+ snapshot_download(
43
+ repo_id="RaphaelLiu/PusaV1",
44
+ allow_patterns=["PusaV1/pusa_v1.pt.part*"],
45
+ local_dir=default_lora_dir,
46
+ local_dir_use_symlinks=False
47
+ )
48
+ merged_path = os.path.join(default_lora_dir, "pusa_v1.pt")
49
+ os.system(f"cat {default_lora_dir}/pusa_v1.pt.part* > {merged_path}")
50
+ lora_path = merged_path
51
+
52
+ # Run pipeline
53
+ model_manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
54
+ pipe = WanVideoPusaPipeline(model_manager=model_manager, lora_path=lora_path)
55
+
56
+ result: VideoData = pipe(prompt)
57
+ video_path = "/tmp/pusa_output.mp4"
58
+ save_video(result.frames, video_path, fps=8)
59
+ return video_path
60
 
61
  # Gradio UI
62
  with gr.Blocks() as demo:
63
+ gr.Markdown("# 🎬 Pusa Text-to-Video Generator")
64
+ with gr.Row():
65
+ prompt = gr.Textbox(label="Prompt", value="A vibrant coral reef with sea turtles and sunlight.")
66
+ lora_upload = gr.File(label="Upload .pt LoRA (optional)", file_types=[".pt"])
67
  generate_btn = gr.Button("Generate Video")
68
+ output_video = gr.Video(label="Output")
69
 
70
+ generate_btn.click(fn=generate_video, inputs=[prompt, lora_upload], outputs=output_video)
71
 
72
+ if __name__ == "__main__":
73
+ demo.launch()