Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,60 +1,82 @@
|
|
| 1 |
import os
|
| 2 |
-
import gradio as gr
|
| 3 |
-
import subprocess
|
| 4 |
import shutil
|
| 5 |
-
import
|
|
|
|
| 6 |
from huggingface_hub import snapshot_download
|
| 7 |
|
| 8 |
-
# Paths
|
| 9 |
-
PUSA_REPO = "./PusaV1"
|
| 10 |
-
PUSA_SCRIPT = os.path.join(PUSA_REPO, "examples/pusavideo/wan_14b_text_to_video_pusa.py")
|
| 11 |
-
MODEL_DIR = "./model_zoo/PusaV1"
|
| 12 |
-
MODEL_PATH = os.path.join(MODEL_DIR, "pusa_v1.pt")
|
| 13 |
-
OUTPUT_VIDEO_PATH = os.path.join(tempfile.gettempdir(), "output.mp4")
|
| 14 |
-
|
| 15 |
-
def setup_dependencies():
|
| 16 |
-
subprocess.run(["pip", "install", "xfuser>=0.4.3", "absl-py", "peft", "lightning", "pandas", "deepspeed", "wandb", "av"])
|
| 17 |
-
subprocess.run(
|
| 18 |
-
'pip install flash-attn --no-build-isolation',
|
| 19 |
-
shell=True,
|
| 20 |
-
env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
|
| 21 |
-
)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
return
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
for part in part_files:
|
| 31 |
-
with open(
|
| 32 |
-
shutil.copyfileobj(
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"--prompt", prompt,
|
| 39 |
-
"--lora_path",
|
|
|
|
| 40 |
]
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import shutil
|
| 3 |
+
import subprocess
|
| 4 |
+
import gradio as gr
|
| 5 |
from huggingface_hub import snapshot_download
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
|
| 9 |
+
HF_REPO = "RaphaelLiu/PusaV1"
|
| 10 |
+
MODEL_ZOO_DIR = "./model_zoo"
|
| 11 |
+
MODEL_PARTS_DIR = os.path.join(MODEL_ZOO_DIR, MODEL_SUBFOLDER)
|
| 12 |
+
FINAL_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
|
| 13 |
+
PUSA_SCRIPT_PATH = "PusaV1/examples/pusavideo/wan_14b_text_to_video_pusa.py"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def download_model_subset():
|
| 17 |
+
if os.path.exists(FINAL_MODEL_PATH):
|
| 18 |
+
print("β
Model already exists. Skipping download.")
|
| 19 |
return
|
| 20 |
+
|
| 21 |
+
print("β¬ Downloading model parts...")
|
| 22 |
+
snapshot_download(
|
| 23 |
+
repo_id=HF_REPO,
|
| 24 |
+
repo_type="model",
|
| 25 |
+
local_dir=MODEL_ZOO_DIR,
|
| 26 |
+
local_dir_use_symlinks=False,
|
| 27 |
+
allow_patterns=[f"{MODEL_SUBFOLDER}/*"]
|
| 28 |
+
)
|
| 29 |
+
os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
|
| 30 |
+
|
| 31 |
+
part_files = sorted([
|
| 32 |
+
os.path.join(MODEL_PARTS_DIR, f)
|
| 33 |
+
for f in os.listdir(MODEL_PARTS_DIR)
|
| 34 |
+
if f.startswith("pusa_v1.pt.part")
|
| 35 |
+
])
|
| 36 |
+
|
| 37 |
+
print("π§© Stitching model parts...")
|
| 38 |
+
with open(FINAL_MODEL_PATH, 'wb') as f_out:
|
| 39 |
for part in part_files:
|
| 40 |
+
with open(part, 'rb') as f_in:
|
| 41 |
+
shutil.copyfileobj(f_in, f_out)
|
| 42 |
+
|
| 43 |
+
print(f"β
Final model saved at {FINAL_MODEL_PATH}")
|
| 44 |
|
| 45 |
+
|
| 46 |
+
def generate_video(prompt):
|
| 47 |
+
download_model_subset()
|
| 48 |
+
|
| 49 |
+
temp_output_dir = "/tmp/pusa_video_output"
|
| 50 |
+
os.makedirs(temp_output_dir, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
command = [
|
| 53 |
+
"python", PUSA_SCRIPT_PATH,
|
| 54 |
"--prompt", prompt,
|
| 55 |
+
"--lora_path", FINAL_MODEL_PATH,
|
| 56 |
+
"--output_dir", temp_output_dir
|
| 57 |
]
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
print("π Running inference...")
|
| 61 |
+
subprocess.run(command, check=True)
|
| 62 |
+
|
| 63 |
+
# Return first mp4 video found
|
| 64 |
+
for file in os.listdir(temp_output_dir):
|
| 65 |
+
if file.endswith(".mp4"):
|
| 66 |
+
return os.path.join(temp_output_dir, file)
|
| 67 |
+
|
| 68 |
+
return "β No video generated."
|
| 69 |
+
|
| 70 |
+
except subprocess.CalledProcessError as e:
|
| 71 |
+
return f"β Inference failed: {str(e)}"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
with gr.Blocks() as demo:
|
| 75 |
+
gr.Markdown("## π§ββοΈ PusaV1 Text-to-Video Generator (Wan2.1-T2V-14B)")
|
| 76 |
+
prompt_input = gr.Textbox(label="Enter your prompt", lines=4, placeholder="A coral reef full of colorful fish...")
|
| 77 |
+
generate_button = gr.Button("Generate Video")
|
| 78 |
+
video_output = gr.Video(label="Generated Video")
|
| 79 |
+
|
| 80 |
+
generate_button.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
|
| 81 |
+
|
| 82 |
+
demo.launch()
|