PusaV1 / app.py
rahul7star's picture
Update app.py
02f7f0d verified
raw
history blame
2.53 kB
import gradio as gr
import os
import tempfile
from huggingface_hub import snapshot_download
import sys, os
sys.path.insert(0, os.path.abspath("./PusaV1"))
import spaces
import sys, os
# Make the local PusaV1 checkout importable and check that it is usable
# before anything is imported from it.
import subprocess

# Add PusaV1 to sys.path if not already
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
    sys.path.insert(0, PUSA_PATH)

# Validate diffsynth presence: fail early with an actionable message rather
# than with an ImportError further down.
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        "Ensure PusaV1 is correctly cloned and folder structure is intact."
    )

# Best-effort install of the checkout. `setup_file` is presumably the
# checkout's own setup.py -- TODO confirm the intended path.
setup_file = os.path.join(PUSA_PATH, "setup.py")
if os.path.exists(setup_file):
    # Run from inside the checkout so relative paths used by setup.py resolve
    # against PusaV1 rather than the app directory. check=False keeps a
    # failed install from aborting app start-up.
    subprocess.run(
        [sys.executable, setup_file, "install"],
        cwd=PUSA_PATH,
        check=False,
    )
from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
# Constants
# Hub repository that ensure_model_downloaded() pulls from, and the local
# directory the files are written to.
MODEL_REPO_ID = "RaphaelLiu/PusaV1"
MODEL_ZOO_DIR = "./model_zoo"

# Sub-folder of the repository holding the base model, and where it ends up
# on disk.
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)

# LoRA weights file handed to the pipeline in generate_video().
LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
# Ensure model is downloaded
def ensure_model_downloaded():
    """Download the base-model sub-folder into MODEL_ZOO_DIR if it is missing.

    The only check is whether WAN_MODEL_PATH exists, so a directory left
    behind by an interrupted download is treated as already present.
    Only files matching ``WAN_SUBFOLDER/**`` are requested from the Hub.
    """
    if os.path.exists(WAN_MODEL_PATH):
        return

    print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
    download_kwargs = {
        "repo_id": MODEL_REPO_ID,
        "local_dir": MODEL_ZOO_DIR,
        "repo_type": "model",
        "allow_patterns": [f"{WAN_SUBFOLDER}/**"],
        "local_dir_use_symlinks": False,
    }
    snapshot_download(**download_kwargs)
    print("Model downloaded.")
# Video generation logic
@spaces.GPU
def generate_video(prompt: str) -> str:
    """Generate a video for *prompt* and return the path of the saved MP4.

    Args:
        prompt: Text description of the video, as typed into the UI.

    Returns:
        Path to ``video.mp4`` inside a freshly created temporary directory.
        The directory is not removed here because the caller still needs to
        read the file after this function returns.

    Raises:
        gr.Error: If *prompt* is empty or contains only whitespace.
    """
    # Reject blank prompts before paying for the download and model load.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the video to generate.")

    ensure_model_downloaded()

    # Load model
    # NOTE(review): the model and pipeline are rebuilt on every call; consider
    # caching them if load time dominates.
    manager = ModelManager(pretrained_model_dir=WAN_MODEL_PATH)
    model = manager.load_model()

    # Set up pipeline
    # NOTE(review): nothing in this file downloads LORA_PATH (the download
    # helper only requests the WAN_SUBFOLDER files) -- confirm it is
    # provisioned elsewhere.
    pipeline = WanVideoPusaPipeline(model=model)
    pipeline.set_lora_adapters(LORA_PATH)

    # Generate video
    result = pipeline(prompt)

    # Save video
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.mp4")
    save_video(result.frames, output_path, fps=8)
    return output_path
# Gradio UI
# One prompt box, one button, one video player; clicking the button sends the
# prompt text to generate_video() and shows the returned file.
with gr.Blocks() as demo:
    # The heading emoji had been stored as mojibake ("πŸŽ₯"); it is the
    # movie-camera character U+1F3A5.
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)",
    )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)

demo.launch()