Spaces:
Paused
Paused
| import gradio as gr | |
| import os | |
| import torch | |
| import tempfile | |
| import sys | |
| from huggingface_hub import snapshot_download | |
| import spaces | |
# Make the local PusaV1 checkout importable before anything is imported from it.
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
    # Prepend so modules inside the checkout win over same-named installed ones.
    sys.path.insert(0, PUSA_PATH)

# Fail fast with an actionable message when the checkout has no `diffsynth` entry.
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        f"Ensure PusaV1 is correctly cloned and folder structure is intact."
    )
| # Import core modules from PusaV1 | |
| from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video | |
class PatchedModelManager(ModelManager):
    """ModelManager whose "WanModel" architecture entry points at WanModelPusa.

    NOTE(review): a second class named ``PatchedModelManager`` is defined
    further down in this file. That later definition rebinds the name, so
    this ``__init__`` patch is not in effect unless the two are merged.
    """

    def __init__(self, *args, **kwargs):
        """Build the manager as usual, then override its "WanModel" entry."""
        super().__init__(*args, **kwargs)
        # Entry layout is (module path, class name, None); the meaning of the
        # third element is defined by ModelManager -- TODO confirm upstream.
        self.architecture_dict.update(
            {"WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None)}
        )
| # Constants | |
| import os | |
| from huggingface_hub import snapshot_download | |
# Constants: on-disk layout of the downloaded weights.
# Everything lives under ./model_zoo/PusaV1, with the Wan base model nested
# inside that directory.
MODEL_ZOO_DIR = "./model_zoo"
PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
# Ensure model and weights are downloaded
def ensure_model_downloaded():
    """Download the PusaV1 and Wan2.1-T2V-14B snapshots that are missing on disk.

    Each snapshot is fetched only when its target directory does not exist
    yet; an existing directory is taken as "already downloaded" without any
    further validation. PusaV1 is handled first because the Wan directory is
    nested inside it.
    """
    # (target directory, hub repo id, message before, message after)
    pending_downloads = (
        (
            PUSA_DIR,
            "RaphaelLiu/PusaV1",
            "Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...",
            "✅ PusaV1 downloaded.",
        ),
        (
            WAN_MODEL_PATH,
            "Wan-AI/Wan2.1-T2V-14B",
            "Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1/Wan2.1-T2V-14B ...",
            "✅ Wan2.1-T2V-14B downloaded.",
        ),
    )

    for target_dir, repo_id, start_message, done_message in pending_downloads:
        if os.path.exists(target_dir):
            continue
        print(start_message)
        snapshot_download(
            repo_id=repo_id,
            local_dir=target_dir,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print(done_message)
# Subclass ModelManager to force WanModelPusa
class PatchedModelManager(ModelManager):
    """ModelManager whose ``load_model`` asks detectors for the WanModelPusa architecture.

    NOTE(review): an earlier class with this same name exists in this file;
    this definition replaces it, so only ``load_model`` below is customised.
    """

    def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
        """Load ``file_path`` through the first matching detector.

        Args:
            file_path: Path to load; defaults to ``self.file_path_list[0]``.
            model_names: Passed to the detector as ``allowed_model_names``.
            device: Target device; falls back to ``self.device`` when falsy.
            torch_dtype: Target dtype; falls back to ``self.torch_dtype`` when falsy.

        Returns:
            The first model produced by the matching detector, or ``None``
            when that detector yields nothing or no detector matches.
        """
        if file_path is None:
            file_path = self.file_path_list[0]
        print(f"[app.py] Forcing architecture: WanModelPusa for {file_path}")

        for detector in self.model_detector:
            # Detectors are probed with an empty state dict; skip non-matches.
            if not detector.match(file_path, {}):
                continue

            model_names, models = detector.load(
                file_path,
                state_dict={},
                device=device or self.device,
                torch_dtype=torch_dtype or self.torch_dtype,
                allowed_model_names=model_names,
                model_manager=self,
                forced_architecture="WanModelPusa"
            )
            # Register every loaded model in the manager's parallel lists.
            for loaded_name, loaded_model in zip(model_names, models):
                self.model.append(loaded_model)
                self.model_path.append(file_path)
                self.model_name.append(loaded_name)
            # Only the first matching detector is ever used.
            return models[0] if models else None

        print("No suitable model detector matched.")
        return None
# Video generation logic
def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
                       negative_prompt, progress=gr.Progress()):
    """Generate a video from a text prompt and save it as a timestamped MP4.

    NOTE(review): this function takes ``self`` and relies on
    ``self.load_lora_and_get_pipe`` and ``self.output_dir``, neither of which
    is defined in this file -- it appears to have been lifted from another
    class. Confirm where it is meant to live before wiring it into the UI.

    Args:
        self: Object providing ``load_lora_and_get_pipe`` and ``output_dir``.
        prompt: Text description of the video.
        lora_alpha: LoRA strength forwarded to ``load_lora_and_get_pipe``.
        num_inference_steps: Number of steps forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        progress: Gradio progress tracker used to report the stages.

    Returns:
        ``(video_path, status_message)`` on success, or
        ``(None, "Error: ...")`` when any step raises.
    """
    # `datetime` is not imported at file level; without this import the
    # timestamp line raised NameError on every run, after generation finished.
    import datetime

    try:
        progress(0.1, desc="Loading models...")
        lora_path = "./model_zoo/PusaV1/pusa_v1.pt"
        pipe = self.load_lora_and_get_pipe("t2v", lora_path, lora_alpha)

        progress(0.3, desc="Generating video...")
        video = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=720, width=1280, num_frames=81,
            seed=0, tiled=True
        )

        progress(0.9, desc="Saving video...")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
        save_video(video, video_filename, fps=25, quality=5)

        progress(1.0, desc="Complete!")
        return video_filename, f"Video generated successfully! Saved to {video_filename}"
    except Exception as e:
        # UI boundary: report the failure as a status string instead of raising.
        return None, f"Error: {str(e)}"
def generate_video(prompt: str):
    """Generate a video for ``prompt`` and return the path of the saved MP4.

    The models are loaded from ``WAN_MODEL_PATH`` on every call: all
    ``.safetensors`` files in that directory (passed as one group), plus
    ``models_t5_umt5-xxl-enc-bf16.pth`` and ``Wan2.1_VAE.pth``.

    NOTE: the Pusa LoRA is not applied here; the pipeline runs on the base
    weights only.

    Args:
        prompt: Text description of the video to generate.

    Returns:
        Path to ``video.mp4`` inside a freshly created temporary directory.

    Raises:
        FileNotFoundError: If the model directory is missing or contains no
            ``.safetensors`` files.
    """
    # Load the base models with the stock (un-patched) ModelManager.
    model_manager = ModelManager(device="cuda")
    base_dir = WAN_MODEL_PATH
    model_files = sorted(
        os.path.join(base_dir, f)
        for f in os.listdir(base_dir)
        if f.endswith(".safetensors")
    )
    if not model_files:
        # An empty shard list would otherwise be passed on to load_models.
        raise FileNotFoundError(
            f"No .safetensors files found in {base_dir}; "
            "run ensure_model_downloaded() first."
        )

    model_manager.load_models(
        [
            model_files,
            os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
            os.path.join(base_dir, "Wan2.1_VAE.pth"),
        ],
        torch_dtype=torch.bfloat16,
    )

    # Set up pipeline
    pipeline = WanVideoPusaPipeline.from_model_manager(
        model_manager, torch_dtype=torch.bfloat16, device="cuda"
    )

    # Generate video
    result = pipeline(prompt)
    # generate_t2v_video in this file hands the pipeline output straight to
    # save_video, so accept both a bare frame sequence and an object that
    # exposes `.frames`.
    frames = getattr(result, "frames", result)

    # Save video
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.mp4")
    save_video(frames, output_path, fps=8)
    return output_path
# Gradio UI: one prompt box, one button, one video player.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")

    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)"
    )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")

    # Clicking the button runs generate_video on the prompt and shows the result.
    generate_btn.click(
        fn=generate_video,
        inputs=prompt_input,
        outputs=video_output,
    )


if __name__ == "__main__":
    # Fetch any missing weights before serving the UI.
    ensure_model_downloaded()
    demo.launch(share=True, show_error=True)