import os
import random
from typing import Any, Union

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from gradio_image_prompter import ImagePrompter
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
from skimage import measure
from transformers import AutoModelForMaskGeneration, AutoProcessor

from midi.pipelines.pipeline_midi import MIDIPipeline
from midi.utils.smoothing import smooth_gpu
from scripts.grounding_sam import plot_segmentation, segment
from scripts.inference_midi import preprocess_image, split_rgb_mask

# Constants
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/MIDI-3D"
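# Weights load in bfloat16 to halve memory relative to float32; when CUDA is
# unavailable, everything falls back to CPU (where inference will be slow).
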
MARKDOWN = """
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/814c046e-f5c3-47cf-bb56-60154be8374c)!
1. Upload an image, then draw a bounding box around each instance by holding and dragging the mouse. Click "Run Segmentation" to generate the segmentation result. <b>Make sure the instances are not too small and that each bounding box fits snugly around its instance.</b>
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
"""
EXAMPLES = [
    [
        {"image": "assets/example_data/Cartoon-Style/03_rgb.png"},
        "assets/example_data/Cartoon-Style/03_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/01_rgb.png"},
        "assets/example_data/Cartoon-Style/01_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/02_rgb.png"},
        "assets/example_data/Realistic-Style/02_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/00_rgb.png"},
        "assets/example_data/Cartoon-Style/00_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/00_rgb.png"},
        "assets/example_data/Realistic-Style/00_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/01_rgb.png"},
        "assets/example_data/Realistic-Style/01_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/05_rgb.png"},
        "assets/example_data/Realistic-Style/05_seg.png",
        42,
        False,
        False,
    ],
]
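# Each example row maps onto the gr.Examples inputs defined below:
# [image_prompts, seg_image, seed, randomize_seed, do_image_padding].
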
os.makedirs(TMP_DIR, exist_ok=True)

# Prepare models
## Grounding SAM
segmenter_id = "facebook/sam-vit-base"
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
    DEVICE, DTYPE
)

## MIDI-3D
local_dir = "pretrained_weights/MIDI-3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
pipe.init_custom_adapter(
    set_self_attn_module_names=[
        "blocks.8",
        "blocks.9",
        "blocks.10",
        "blocks.11",
        "blocks.12",
    ]
)
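# The adapter patches self-attention in transformer blocks 8-12 so tokens from
# all instances in a scene can attend to each other (MIDI's multi-instance
# attention); see MIDIPipeline.init_custom_adapter for the exact mechanism.
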
# Utils
def get_random_hex():
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex


@spaces.GPU()  # ZeroGPU: attach a GPU for the duration of this call
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
    rgb_image = image_prompts["image"].convert("RGB")

    # Pre-process the prompter layers and get the xyxy box of each layer.
    # ImagePrompter encodes each drawn box as [x1, y1, label, x2, y2, label],
    # so the box corners sit at indices 0, 1, 3 and 4.
    if len(image_prompts["points"]) == 0:
        raise gr.Error("Please draw bounding boxes for each instance on the image.")
    boxes = [
        [
            [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
            for box in image_prompts["points"]
        ]
    ]

    # run the segmentation (boxes is already batched: one list per image)
    detections = segment(
        sam_processor,
        sam_segmentator,
        rgb_image,
        boxes=boxes,
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()
    return seg_map_pil


def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> Any:  # raw pipeline outputs; meshing happens in run_generation
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    # Split the segmentation into per-instance RGB crops and masks;
    # scene_rgbs carries the full scene image as global context.
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
    num_instances = len(instance_rgbs)
    outputs = pipe(
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        attention_kwargs={"num_instances": num_instances},
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
    )
    return outputs


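# With return_dict=False, the pipeline returns parallel sequences of
# per-instance logits volumes plus grid/bounding-box metadata; run_generation
# below unpacks them with zip(*outputs).
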
@spaces.GPU()  # ZeroGPU: attach a GPU for the duration of this call
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # ImagePrompter hands the handler a dict; unwrap the PIL image.
    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]

    outputs = run_midi(
        pipe,
        rgb_image,
        seg_image,
        seed,
        num_inference_steps,
        guidance_scale,
        do_image_padding,
    )

    # Marching cubes: turn each instance's logits grid into a mesh.
    trimeshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
        grid_logits = logits_.view(grid_size)
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
        # Map vertices from grid-index coordinates back into the instance's
        # bounding box within the shared scene frame.
        vertices = vertices / grid_size * bbox_size + bbox_min
        mesh = trimesh.Trimesh(
            vertices.astype(np.float32), np.ascontiguousarray(faces)
        )
        trimeshes.append(mesh)

    # Compose the per-instance meshes into one scene and export it as GLB.
    scene = trimesh.Scene(trimeshes)
    tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
    scene.export(tmp_path)

    torch.cuda.empty_cache()
    # The GLB path feeds both the 3D viewer and the download button.
    return tmp_path, tmp_path, seed


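# Headless usage sketch (paths taken from EXAMPLES above; the call pattern is
# illustrative, not part of the original app):
#   glb_path, _, used_seed = run_generation(
#       Image.open("assets/example_data/Cartoon-Style/00_rgb.png"),
#       Image.open("assets/example_data/Cartoon-Style/00_seg.png"),
#       seed=42,
#   )
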
# Demo
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_prompts = ImagePrompter(label="Input Image", type="pil")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png"
                )

            with gr.Accordion("Segmentation Settings", open=False):
                polygon_refinement = gr.Checkbox(
                    label="Polygon Refinement", value=False
                )
            seg_button = gr.Button("Run Segmentation")

            with gr.Accordion("Generation Settings", open=False):
                do_image_padding = gr.Checkbox(label="Do image padding", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
            gen_button = gr.Button("Run Generation", variant="primary")

        with gr.Column():
            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=run_generation,
            inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
            outputs=[model_output, download_glb, seed],
            cache_examples=False,
        )
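        # cache_examples=False skips pre-computing outputs at startup; with the
        # default run_on_click=False, selecting an example only fills the inputs.
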
    # Chain the buttons: enable "Run Generation" after segmentation finishes,
    # and enable "Download GLB" after generation finishes.
    seg_button.click(
        run_segmentation,
        inputs=[image_prompts, polygon_refinement],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])

    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            do_image_padding,
        ],
        outputs=[model_output, download_glb, seed],
    ).then(lambda: gr.DownloadButton(interactive=True), outputs=[download_glb])

demo.launch()