rahul7star committed
Commit 65426a8 · verified
Parent: 2e2e472

Update app.py

Files changed (1): app.py (+52, -50)
app.py CHANGED
@@ -1,73 +1,75 @@
  import os
  import sys
  import gradio as gr
- import subprocess
+ import tempfile
  from huggingface_hub import snapshot_download
- from gradio import Spaces
+ import spaces

- # Use GPU
- @Spaces.GPU
- def dummy(): pass
-
- # Add PusaV1 to Python path
  sys.path.append(os.path.abspath("PusaV1"))
+ from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData

- # Install flash-attn in a safe way
- subprocess.run('pip install flash-attn --no-build-isolation', shell=True, env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"})
-
- # Download Wan2.1-T2V-14B model only (not the full repo)
  WAN_MODEL_DIR = "/tmp/model_zoo/Wan2.1-T2V-14B"
- os.makedirs(WAN_MODEL_DIR, exist_ok=True)
- snapshot_download(
-     repo_id="RaphaelLiu/PusaV1",
-     allow_patterns=["Wan2.1-T2V-14B/*"],
-     local_dir=WAN_MODEL_DIR,
-     local_dir_use_symlinks=False,
-     resume_download=True
- )
-
- # Pusa Inference imports
- from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData
+ LORA_DIR = "/tmp/model_zoo/PusaV1"
+ LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")

+ @spaces.GPU
  def generate_video(prompt, lora_upload):
-     # Prepare LoRA path
+     # Download Wan2.1 model only if missing
+     if not os.path.exists(WAN_MODEL_DIR):
+         snapshot_download(
+             repo_id="RaphaelLiu/PusaV1",
+             allow_patterns=["Wan2.1-T2V-14B/*"],
+             local_dir=WAN_MODEL_DIR,
+             local_dir_use_symlinks=False,
+             resume_download=True,
+         )
+
+     # Handle LoRA file (upload or default download + stitch)
      if lora_upload is not None:
          lora_path = lora_upload
      else:
-         # Default PusaV1 LoRA
-         default_lora_dir = "/tmp/model_zoo/PusaV1"
-         os.makedirs(default_lora_dir, exist_ok=True)
-
-         # Download all pusa_v1.pt.part* and merge them
-         snapshot_download(
-             repo_id="RaphaelLiu/PusaV1",
-             allow_patterns=["PusaV1/pusa_v1.pt.part*"],
-             local_dir=default_lora_dir,
-             local_dir_use_symlinks=False
-         )
-         merged_path = os.path.join(default_lora_dir, "pusa_v1.pt")
-         os.system(f"cat {default_lora_dir}/pusa_v1.pt.part* > {merged_path}")
-         lora_path = merged_path
+         if not os.path.exists(LORA_PATH):
+             os.makedirs(LORA_DIR, exist_ok=True)
+             snapshot_download(
+                 repo_id="RaphaelLiu/PusaV1",
+                 allow_patterns=["PusaV1/pusa_v1.pt.part*"],
+                 local_dir=LORA_DIR,
+                 local_dir_use_symlinks=False,
+             )
+             # Stitch parts
+             part_files = sorted(
+                 f for f in os.listdir(LORA_DIR) if f.startswith("pusa_v1.pt.part")
+             )
+             with open(LORA_PATH, "wb") as wfd:
+                 for part in part_files:
+                     with open(os.path.join(LORA_DIR, part), "rb") as fd:
+                         wfd.write(fd.read())
+
+         lora_path = LORA_PATH

-     # Run pipeline
-     model_manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
-     pipe = WanVideoPusaPipeline(model_manager=model_manager, lora_path=lora_path)
+     # Load model and pipeline
+     manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
+     pipe = WanVideoPusaPipeline(model_manager=manager)
+     pipe.set_lora_adapters(lora_path)

+     # Run generation
      result: VideoData = pipe(prompt)
-     video_path = "/tmp/pusa_output.mp4"
+
+     # Save video to temp file
+     tmp_dir = tempfile.mkdtemp()
+     video_path = os.path.join(tmp_dir, "output.mp4")
      save_video(result.frames, video_path, fps=8)
+
      return video_path

- # Gradio UI
+
  with gr.Blocks() as demo:
-     gr.Markdown("# 🎬 Pusa Text-to-Video Generator")
-     with gr.Row():
-         prompt = gr.Textbox(label="Prompt", value="A vibrant coral reef with sea turtles and sunlight.")
-         lora_upload = gr.File(label="Upload .pt LoRA (optional)", file_types=[".pt"])
-     generate_btn = gr.Button("Generate Video")
-     output_video = gr.Video(label="Output")
+     gr.Markdown("# 🎥 Pusa Text-to-Video (Wan2.1-T2V-14B)")
+     prompt = gr.Textbox(label="Prompt", lines=4)
+     lora_file = gr.File(label="Upload LoRA .pt (optional)", file_types=[".pt"])
+     generate_btn = gr.Button("Generate")
+     output_video = gr.Video(label="Generated Video")

-     generate_btn.click(fn=generate_video, inputs=[prompt, lora_upload], outputs=output_video)
+     generate_btn.click(fn=generate_video, inputs=[prompt, lora_file], outputs=output_video)

- if __name__ == "__main__":
-     demo.launch()
+ demo.launch()
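
Note on the LoRA stitching step: the updated generate_video reads each pusa_v1.pt.part* file fully into memory with fd.read() before appending it to pusa_v1.pt. For multi-gigabyte checkpoint parts, a streaming concatenation keeps peak memory flat. Below is a minimal sketch under the same /tmp/model_zoo/PusaV1 layout used in app.py, standard library only; the helper name stitch_lora_parts and the size check are illustrative and not part of this commit.

import os
import shutil

LORA_DIR = "/tmp/model_zoo/PusaV1"                   # same default directory as app.py
LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")

def stitch_lora_parts(lora_dir=LORA_DIR, merged_path=LORA_PATH):
    # Collect the split checkpoint pieces in sorted name order, as app.py does.
    part_files = sorted(
        f for f in os.listdir(lora_dir) if f.startswith("pusa_v1.pt.part")
    )
    with open(merged_path, "wb") as merged:
        for part in part_files:
            with open(os.path.join(lora_dir, part), "rb") as src:
                # Stream in fixed-size chunks instead of loading a whole part at once.
                shutil.copyfileobj(src, merged)
    # Sanity check: the merged file should equal the sum of the part sizes.
    expected = sum(os.path.getsize(os.path.join(lora_dir, p)) for p in part_files)
    assert os.path.getsize(merged_path) == expected, "merged pusa_v1.pt size mismatch"
    return merged_path

Because the parts are concatenated in the same sorted order, the output is byte-for-byte identical to the read()-based loop in app.py; only the peak memory use differs.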