Spaces:
Paused
Paused
import os
import subprocess
import sys
import tempfile

import gradio as gr
from huggingface_hub import snapshot_download
import spaces

# Make the local PusaV1 checkout importable (inserted once, at highest priority).
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
    sys.path.insert(0, PUSA_PATH)

# Fail fast with an actionable message if the diffsynth package is missing
# from the checkout, before the import below fails less clearly.
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        f"Ensure PusaV1 is correctly cloned and folder structure is intact."
    )

# Best-effort install of PusaV1; a failing install is ignored (check=False).
# NOTE(review): the setup.py location inside the checkout is an assumption —
# confirm against the PusaV1 repository layout.
setup_file = os.path.join(PUSA_PATH, "setup.py")
if os.path.exists(setup_file):
    # Run from inside the checkout so relative paths used by setup.py resolve.
    subprocess.run([sys.executable, setup_file, "install"], cwd=PUSA_PATH, check=False)

from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
# --- Model locations --------------------------------------------------------
# Hugging Face Hub repository the base-model snapshot is pulled from.
MODEL_REPO_ID = "RaphaelLiu/PusaV1"
# Local root under which all downloaded weights are stored.
MODEL_ZOO_DIR = "./model_zoo"
# Sub-folder of the repo (and of MODEL_ZOO_DIR) holding the Wan2.1 base model.
WAN_SUBFOLDER = "Wan2.1-T2V-14B"

# Derived on-disk paths.
WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)
LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
# Ensure model weights are present locally
def ensure_model_downloaded():
    """Download any missing weights from ``MODEL_REPO_ID`` into the model zoo.

    - Fetches the ``WAN_SUBFOLDER`` snapshot into ``MODEL_ZOO_DIR`` when
      ``WAN_MODEL_PATH`` does not exist.
    - Fetches the Pusa LoRA checkpoint into the directory of ``LORA_PATH``
      when that file does not exist.

    Raises:
        FileNotFoundError: if the LoRA file is still missing after the download.
    """
    if not os.path.exists(WAN_MODEL_PATH):
        print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=MODEL_ZOO_DIR,
            repo_type="model",
            allow_patterns=[f"{WAN_SUBFOLDER}/**"],
            local_dir_use_symlinks=False,
        )
        print("Model downloaded.")

    # Nothing else in this file fetches LORA_PATH, yet the pipeline loads it,
    # so download it here as well.
    if not os.path.exists(LORA_PATH):
        lora_dir, lora_name = os.path.split(LORA_PATH)
        # NOTE(review): assumes the LoRA checkpoint sits at the root of
        # MODEL_REPO_ID under the same file name — confirm against the repo.
        print("Downloading Pusa LoRA from HuggingFace Hub...")
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=lora_dir,
            repo_type="model",
            allow_patterns=[lora_name],
            local_dir_use_symlinks=False,
        )
        if not os.path.exists(LORA_PATH):
            raise FileNotFoundError(
                f"LoRA checkpoint '{lora_name}' was not found in {MODEL_REPO_ID}; "
                f"expected it at {LORA_PATH}."
            )
        print("LoRA downloaded.")
# Video generation logic

# Pipeline built on first use and reused by every later request (None until then).
_PIPELINE = None


def _get_pipeline():
    """Return the Wan + Pusa LoRA pipeline, building it only on the first call."""
    global _PIPELINE
    if _PIPELINE is None:
        ensure_model_downloaded()
        # NOTE(review): these constructor/loader signatures cannot be verified
        # from this file — confirm against PusaV1's diffsynth API.
        manager = ModelManager(pretrained_model_dir=WAN_MODEL_PATH)
        model = manager.load_model()
        pipeline = WanVideoPusaPipeline(model=model)
        pipeline.set_lora_adapters(LORA_PATH)
        _PIPELINE = pipeline
    return _PIPELINE


def generate_video(prompt: str):
    """Generate a video for ``prompt`` and return the path of the saved MP4.

    The heavy model/pipeline load happens once (see ``_get_pipeline``) instead
    of on every request. Each call writes ``video.mp4`` into a fresh temporary
    directory at 8 fps.

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    pipeline = _get_pipeline()
    result = pipeline(prompt)

    # A fresh directory per call so concurrent/successive outputs never collide.
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.mp4")
    save_video(result.frames, output_path, fps=8)
    return output_path
# Gradio UI: a prompt box and a button whose result (the MP4 path returned by
# generate_video) is shown in a video player.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)",
    )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)

# Video generation is long-running; route requests through Gradio's queue so
# they are not cut off by HTTP timeouts (the queue is on by default in Gradio 4+,
# so this is a no-op there).
demo.queue().launch()