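"""Gradio app: text-to-3D generation with Shap-E, multi-view refinement with a
Zero123Plus-based diffusion pipeline, and mesh reconstruction with InstantMesh.

A text prompt is turned into a Shap-E latent, rendered into a 3x2 grid of six
320x320 views, refined by the zero123plus pipeline, and reconstructed into an
OBJ mesh that is displayed in the Gradio UI.
"""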
import os
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, create_custom_cameras
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground

def load_models():
    """Initialize and load all required models"""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline
    print('Loading diffusion pipeline...')
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="zero123plus",
        torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Modify UNet to handle 8 input channels instead of 4
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet
    print('Loading custom UNet...')
    unet_path = "best_21.ckpt"
    state_dict = torch.load(unet_path, map_location='cpu')

    # Process the state dict to match the model keys
    if 'state_dict' in state_dict:
        new_state_dict = {key.replace('unet.unet.', ''): value
                          for key, value in state_dict['state_dict'].items()}
        pipeline.unet.load_state_dict(new_state_dict, strict=False)
    else:
        pipeline.unet.load_state_dict(state_dict, strict=False)
    # Move the pipeline to the target device and keep it in fp16
    # (the replacement conv_in above was created in fp32).
    pipeline = pipeline.to(device, torch.float16)
    # Load reconstruction model
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    model_path = hf_hub_download(
        repo_id="TencentARC/InstantMesh",
        filename="instant_nerf_large.ckpt",
        repo_type="model"
    )
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)

    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config

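# Layout convention used below: the six 320x320 views are tiled into a portrait
# grid of 3 rows x 2 columns, i.e. a 640 (W) x 960 (H) image, which is what the
# zero123plus refinement pipeline and create_mesh() operate on.
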
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement"""
    device = pipeline.device

    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check whether this is a pre-arranged 640x960 (W x H) layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # Already a layout, use it directly
                input_image = img
            else:
                # Single view - replicate it into 6 copies
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = [img_array] * 6
                images = np.stack(images)

                # Convert to tensor and tile into the 3x2 grid layout
                images = torch.from_numpy(images).float()   # (6, 320, 320, 3)
                images = images.permute(0, 3, 1, 2)          # (6, C, H, W)
                images = images.reshape(3, 2, 3, 320, 320)   # (rows, cols, C, H, W)
                images = images.permute(0, 2, 3, 1, 4)       # (rows, C, H, cols, W)
                images = images.reshape(3, 3, 320, 640)      # (rows, C, H, cols*W)
                images = images.permute(1, 0, 2, 3)          # (C, rows, H, cols*W)
                images = images.reshape(1, 3, 960, 640)      # (1, C, rows*H, cols*W)

                # Convert back to PIL
                images = images.permute(0, 2, 3, 1)[0]
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)

            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])

            # Convert to tensor and tile into the 3x2 grid layout
            images = torch.from_numpy(images).float()   # (6, 320, 320, 3)
            images = images.permute(0, 3, 1, 2)          # (6, C, H, W)
            images = images.reshape(3, 2, 3, 320, 320)   # (rows, cols, C, H, W)
            images = images.permute(0, 2, 3, 1, 4)       # (rows, C, H, cols, W)
            images = images.reshape(3, 3, 320, 640)      # (rows, C, H, cols*W)
            images = images.permute(1, 0, 2, 3)          # (C, rows, H, cols*W)
            images = images.reshape(1, 3, 960, 640)      # (1, C, rows*H, cols*W)

            # Convert back to PIL
            images = images.permute(0, 2, 3, 1)[0]
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]

    return output, input_image

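# Note: process_images() above is a standalone helper for file-based inputs; the
# Gradio demo defined below goes through RefinerInterface.refine_model() instead.
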
def create_mesh(refined_image, model, infer_config):
    """Generate mesh from refined image"""
    # Convert PIL image to tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)   # (C, 960, 640)

    # Split the 3x2 grid layout back into 6 views of 320x320
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)     # (C, rows, H, cols*W)
    image = image.permute(1, 0, 2, 3)          # (rows, C, H, cols*W)
    image = image.reshape(3, 3, 320, 2, 320)   # (rows, C, H, cols, W)
    image = image.permute(0, 3, 1, 2, 4)       # (rows, cols, C, H, W)
    image = image.reshape(6, 3, 320, 320)      # (6, C, H, W)

    # Add batch dimension
    image = image.unsqueeze(0)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors

class ShapERenderer:
    def __init__(self, device):
        print("Loading Shap-E models...")
        self.device = device
        self.xm = load_model('transmitter', device=device)
        self.model = load_model('text300M', device=device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models loaded!")

    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        # Generate latents using the text-to-3D model
        batch_size = 1
        guidance_scale = float(guidance_scale)
        latents = sample_latents(
            batch_size=batch_size,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        # Render the 6 views we need with specific viewing angles
        size = 320  # Size of each rendered image
        images = []

        # Define our 6 specific camera positions to match refine.py
        azimuths = [30, 90, 150, 210, 270, 330]
        elevations = [20, -10, 20, -10, 20, -10]

        for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
            cameras = create_custom_cameras(
                size, self.device, azimuths=[azimuth], elevations=[elevation],
                fov_degrees=30, distance=3.0
            )
            rendered_image = decode_latent_images(
                self.xm,
                latents[0],
                rendering_mode='stf',
                cameras=cameras
            )
            images.append(rendered_image.detach().cpu().numpy())
        # Convert images to uint8
        images = [image.astype(np.uint8) for image in images]

        # Tile the 6 views into the 3-rows x 2-columns grid (960 x 640, H x W)
        layout = np.zeros((960, 640, 3), dtype=np.uint8)
        for i, img in enumerate(images):
            row = i // 2  # 2 images per row
            col = i % 2
            layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

        return Image.fromarray(layout), images

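# Illustrative usage of ShapERenderer (not part of the Gradio flow below; it
# assumes a CUDA device and that the Shap-E checkpoints download successfully):
#   renderer = ShapERenderer(torch.device('cuda'))
#   layout, views = renderer.generate_views("a ceramic teapot", guidance_scale=15.0, num_steps=64)
#   layout.save("shape_views.png")
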
class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models loaded!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function"""
        # Process image and get refined output
        input_image = Image.fromarray(input_image)
        # If the layout arrives as a landscape 2x3 grid (960 px wide, 640 px tall),
        # rearrange it into the portrait 3x2 grid (640 wide, 960 tall) that the
        # pipeline expects.
        if input_image.width == 960 and input_image.height == 640:
            input_array = np.array(input_image)
            new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i in range(6):
                src_row = i // 3
                src_col = i % 3
                dst_row = i // 2
                dst_col = i % 2
                new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                    input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
            input_image = Image.fromarray(new_layout)
        # Run the refinement pipeline on the 640x960 (W x H) portrait layout
        refined_output = self.pipeline.refine(
            input_image,
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale
        ).images[0]

        # Generate mesh from the refined layout
        vertices, faces, vertex_colors = create_mesh(
            refined_output,
            self.model,
            self.infer_config
        )

        # Save temporary mesh file
        os.makedirs("temp", exist_ok=True)
        temp_obj = os.path.join("temp", "refined_mesh.obj")
        save_obj(vertices, faces, vertex_colors, temp_obj)

        # The refined layout is already in the portrait orientation used for
        # display, so return it as-is along with the mesh path.
        return refined_output, temp_obj

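# Illustrative usage of RefinerInterface (for local testing; the Gradio demo
# below performs the same calls):
#   refiner = RefinerInterface()
#   refined_img, obj_path = refiner.refine_model(np.array(layout), "a ceramic teapot")
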
def create_demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16,
                    maximum=128,
                    value=64,
                    step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")

        # Second row: image panels side by side
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,   # portrait layout: 640 px wide
                height=960   # portrait layout: 960 px tall
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh panel below
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                width=1280,  # full width
                height=600   # taller for better visualization
            )

        # Set up event handlers
        def generate(prompt, guidance_scale, num_steps):
            with torch.no_grad():
                layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                return layout

        def refine(input_image, prompt, steps, guidance_scale):
            refined_img, mesh_path = refiner.refine_model(
                input_image,
                prompt,
                steps,
                guidance_scale
            )
            return refined_img, mesh_path

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output]
        )

        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)