Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,99 +1,73 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            import shutil
         | 
| 3 | 
            -
            import subprocess
         | 
| 4 | 
            -
            import gradio as gr
         | 
| 5 | 
            -
            from huggingface_hub import snapshot_download
         | 
| 6 | 
            -
            from huggingface_hub import snapshot_download
         | 
| 7 | 
            -
            import spaces
         | 
| 8 | 
            -
             | 
| 9 | 
             
            import os
         | 
| 10 | 
             
            import sys
         | 
| 11 | 
             
            import gradio as gr
         | 
| 12 | 
            -
            import  | 
| 13 | 
            -
            import  | 
| 14 | 
            -
            import  | 
| 15 | 
            -
            # Add PusaV1 to path to resolve diffsynth imports
         | 
| 16 | 
            -
            sys.path.append(os.path.abspath("PusaV1"))
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            # Import the actual model runner
         | 
| 19 | 
            -
            from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            # Define paths
         | 
| 22 | 
            -
            WAN_MODEL_DIR = "./model_zoo/Wan2.1-T2V-14B"
         | 
| 23 | 
            -
            LORA_PATH = "./model_zoo/PusaV1/pusa_v1.pt"
         | 
| 24 | 
            -
             | 
| 25 | 
            -
            MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
         | 
| 26 | 
            -
            HF_REPO = "RaphaelLiu/PusaV1"
         | 
| 27 | 
            -
            MODEL_ZOO_DIR = "./model_zoo"
         | 
| 28 | 
            -
            MODEL_PARTS_DIR = os.path.join(MODEL_ZOO_DIR, MODEL_SUBFOLDER)
         | 
| 29 | 
            -
            FINAL_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
         | 
| 30 | 
            -
            PUSA_SCRIPT_PATH = "PusaV1/examples/pusavideo/wan_14b_text_to_video_pusa.py"
         | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
            def download_model_subset():
         | 
| 34 | 
            -
                if os.path.exists(FINAL_MODEL_PATH):
         | 
| 35 | 
            -
                    print("✅ Model already exists. Skipping download.")
         | 
| 36 | 
            -
                    return
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                print("⏬ Downloading model parts...")
         | 
| 39 | 
            -
                snapshot_download(
         | 
| 40 | 
            -
                    repo_id=HF_REPO,
         | 
| 41 | 
            -
                    repo_type="model",
         | 
| 42 | 
            -
                    local_dir=MODEL_ZOO_DIR,
         | 
| 43 | 
            -
                    local_dir_use_symlinks=False,
         | 
| 44 | 
            -
                    allow_patterns=[f"{MODEL_SUBFOLDER}/*"]
         | 
| 45 | 
            -
                )
         | 
| 46 | 
            -
                os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                part_files = sorted([
         | 
| 49 | 
            -
                    os.path.join(MODEL_PARTS_DIR, f)
         | 
| 50 | 
            -
                    for f in os.listdir(MODEL_PARTS_DIR)
         | 
| 51 | 
            -
                    if f.startswith("pusa_v1.pt.part")
         | 
| 52 | 
            -
                ])
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                print("🧩 Stitching model parts...")
         | 
| 55 | 
            -
                with open(FINAL_MODEL_PATH, 'wb') as f_out:
         | 
| 56 | 
            -
                    for part in part_files:
         | 
| 57 | 
            -
                        with open(part, 'rb') as f_in:
         | 
| 58 | 
            -
                            shutil.copyfileobj(f_in, f_out)
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                print(f"✅ Final model saved at {FINAL_MODEL_PATH}")
         | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
            @spaces.GPU
         | 
| 64 | 
            -
            def generate_video(prompt: str):
         | 
| 65 | 
            -
                try:
         | 
| 66 | 
            -
                    # Load model manager
         | 
| 67 | 
            -
                    manager = ModelManager(base_model_dir=WAN_MODEL_DIR)
         | 
| 68 | 
            -
                    model = manager.load_model()
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                    # Create video pipeline and apply LoRA
         | 
| 71 | 
            -
                    pipeline = WanVideoPusaPipeline(model=model)
         | 
| 72 | 
            -
                    pipeline.set_lora_adapters(LORA_PATH)
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    # Generate video
         | 
| 75 | 
            -
                    result = pipeline(prompt=prompt)
         | 
| 76 |  | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
                    save_video(result, video_path)
         | 
| 81 |  | 
| 82 | 
            -
             | 
|  | |
| 83 |  | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 87 |  | 
| 88 | 
             
            # Gradio UI
         | 
| 89 | 
             
            with gr.Blocks() as demo:
         | 
| 90 | 
            -
                gr.Markdown(" | 
| 91 | 
            -
                gr. | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
             
                generate_btn = gr.Button("Generate Video")
         | 
| 95 | 
            -
                 | 
| 96 |  | 
| 97 | 
            -
                generate_btn.click(fn=generate_video, inputs= | 
| 98 |  | 
| 99 | 
            -
             | 
|  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import sys
         | 
| 3 | 
             
            import gradio as gr
         | 
| 4 | 
            +
            import subprocess
         | 
| 5 | 
            +
            from huggingface_hub import snapshot_download
         | 
| 6 | 
            +
            from gradio import Spaces
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 7 |  | 
| 8 | 
            +
            # Use GPU
         | 
| 9 | 
            +
            @Spaces.GPU
         | 
| 10 | 
            +
            def dummy(): pass
         | 
|  | |
| 11 |  | 
| 12 | 
            +
            # Add PusaV1 to Python path
         | 
| 13 | 
            +
            sys.path.append(os.path.abspath("PusaV1"))
         | 
| 14 |  | 
| 15 | 
            +
            # Install flash-attn in a safe way
         | 
| 16 | 
            +
            subprocess.run('pip install flash-attn --no-build-isolation', shell=True, env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"})
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Download Wan2.1-T2V-14B model only (not the full repo)
         | 
| 19 | 
            +
            WAN_MODEL_DIR = "/tmp/model_zoo/Wan2.1-T2V-14B"
         | 
| 20 | 
            +
            os.makedirs(WAN_MODEL_DIR, exist_ok=True)
         | 
| 21 | 
            +
            snapshot_download(
         | 
| 22 | 
            +
                repo_id="RaphaelLiu/PusaV1",
         | 
| 23 | 
            +
                allow_patterns=["Wan2.1-T2V-14B/*"],
         | 
| 24 | 
            +
                local_dir=WAN_MODEL_DIR,
         | 
| 25 | 
            +
                local_dir_use_symlinks=False,
         | 
| 26 | 
            +
                resume_download=True
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            # Pusa Inference imports
         | 
| 30 | 
            +
            from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            def generate_video(prompt, lora_upload):
         | 
| 33 | 
            +
                # Prepare LoRA path
         | 
| 34 | 
            +
                if lora_upload is not None:
         | 
| 35 | 
            +
                    lora_path = lora_upload
         | 
| 36 | 
            +
                else:
         | 
| 37 | 
            +
                    # Default PusaV1 LoRA
         | 
| 38 | 
            +
                    default_lora_dir = "/tmp/model_zoo/PusaV1"
         | 
| 39 | 
            +
                    os.makedirs(default_lora_dir, exist_ok=True)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    # Download all pusa_v1.pt.part* and merge them
         | 
| 42 | 
            +
                    snapshot_download(
         | 
| 43 | 
            +
                        repo_id="RaphaelLiu/PusaV1",
         | 
| 44 | 
            +
                        allow_patterns=["PusaV1/pusa_v1.pt.part*"],
         | 
| 45 | 
            +
                        local_dir=default_lora_dir,
         | 
| 46 | 
            +
                        local_dir_use_symlinks=False
         | 
| 47 | 
            +
                    )
         | 
| 48 | 
            +
                    merged_path = os.path.join(default_lora_dir, "pusa_v1.pt")
         | 
| 49 | 
            +
                    os.system(f"cat {default_lora_dir}/pusa_v1.pt.part* > {merged_path}")
         | 
| 50 | 
            +
                    lora_path = merged_path
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                # Run pipeline
         | 
| 53 | 
            +
                model_manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
         | 
| 54 | 
            +
                pipe = WanVideoPusaPipeline(model_manager=model_manager, lora_path=lora_path)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                result: VideoData = pipe(prompt)
         | 
| 57 | 
            +
                video_path = "/tmp/pusa_output.mp4"
         | 
| 58 | 
            +
                save_video(result.frames, video_path, fps=8)
         | 
| 59 | 
            +
                return video_path
         | 
| 60 |  | 
| 61 | 
             
            # Gradio UI
         | 
| 62 | 
             
            with gr.Blocks() as demo:
         | 
| 63 | 
            +
                gr.Markdown("# 🎬 Pusa Text-to-Video Generator")
         | 
| 64 | 
            +
                with gr.Row():
         | 
| 65 | 
            +
                    prompt = gr.Textbox(label="Prompt", value="A vibrant coral reef with sea turtles and sunlight.")
         | 
| 66 | 
            +
                    lora_upload = gr.File(label="Upload .pt LoRA (optional)", file_types=[".pt"])
         | 
| 67 | 
             
                generate_btn = gr.Button("Generate Video")
         | 
| 68 | 
            +
                output_video = gr.Video(label="Output")
         | 
| 69 |  | 
| 70 | 
            +
                generate_btn.click(fn=generate_video, inputs=[prompt, lora_upload], outputs=output_video)
         | 
| 71 |  | 
| 72 | 
            +
            if __name__ == "__main__":
         | 
| 73 | 
            +
                demo.launch()
         | 
