import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Constants
IMG_SIZE = 512

# Load SAM and the inpainting pipeline once at startup (on CPU, in float32)
# instead of re-initializing them on every request.
sam_model = SamModel.from_pretrained(
    "facebook/sam-vit-huge", torch_dtype=torch.float32
).to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32
).to("cpu")


def generate_mask(image, points):
    """
    Generates a binary 0/1 mask with SAM from the user-selected points
    (1 marks the selected object).
    """
    if not points:
        return None

    # SAM expects a nested list of points: one list of [x, y] prompts per image.
    inputs = sam_processor(image, input_points=[points], return_tensors="pt").to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # Keep the candidate mask with the highest predicted IoU score.
    best_mask = masks[0][0][outputs.iou_scores.argmax()]

    # 1 where the object is, 0 elsewhere, so inpainting replaces the object by default
    # (the "Invert Mask" option switches to replacing the background instead).
    return best_mask.numpy().astype(bool).astype(int)


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the masked region of the image using Stable Diffusion inpainting.
    """
    if mask is None:
        return image

    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
        return result
    except Exception as e:
        print(f"Inpainting error: {e}")
        return image


def visualize_mask(image, mask):
    """
    Overlays the mask on the image for visualization.
    """
    if mask is None:
        return image

    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]  # Semi-transparent green
    mask_rgba = Image.fromarray(bg_transparent)
    overlay = Image.alpha_composite(image.convert("RGBA"), mask_rgba)
    return overlay.convert("RGB")


def get_points(img, evt: gr.SelectData, input_points):
    """
    Captures points selected by the user on the image.
    """
    x, y = evt.index
    input_points.append([x, y])

    # Generate a mask from all points selected so far
    mask = generate_mask(img, input_points)

    # Mark selected points with a green crossmark
    draw = ImageDraw.Draw(img)
    size = 10
    for px, py in input_points:
        draw.line((px - size, py, px + size, py), fill="green", width=5)
        draw.line((px, py - size, px, py + size), fill="green", width=5)

    # Visualize the mask overlay
    masked_image = visualize_mask(img, mask)
    return masked_image, input_points


def run_inpaint(prompt, negative_prompt, cfg, seed, invert, input_image, input_points):
    """
    Runs the inpainting process based on user inputs.
    """
    if input_image is None or len(input_points) == 0:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    mask = generate_mask(input_image, input_points)
    if invert:
        # Flip the 0/1 mask so the background is inpainted instead of the object
        mask = 1 - mask

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e))

    return inpainted.resize((IMG_SIZE, IMG_SIZE))


def preprocess(input_img):
    """
    Preprocesses the uploaded image to ensure it is square and resized.
""" if input_img is None: return None width, height = input_img.size if width != height: # Add white padding to make the image square new_size = max(width, height) new_image = Image.new("RGB", (new_size, new_size), 'white') left = (new_size - width) // 2 top = (new_size - height) // 2 new_image.paste(input_img, (left, top)) input_img = new_image return input_img.resize((IMG_SIZE, IMG_SIZE)) # Gradio Interface with gr.Blocks() as demo: gr.Markdown("# Object Replacement with SAM and Stable Diffusion Inpainting") gr.Markdown("Upload an image, click on the object you want to replace, and generate a new image.") with gr.Row(): with gr.Column(): input_image = gr.Image(label="Upload Image", type="pil") output_image = gr.Image(label="Generated Image", type="pil") input_points = gr.State([]) # Store selected points with gr.Column(): prompt = gr.Textbox(label="Prompt for Inpainting") negative_prompt = gr.Textbox(label="Negative Prompt (Optional)") cfg = gr.Slider(1, 20, value=7.5, label="Guidance Scale") seed = gr.Number(value=42, label="Seed") invert = gr.Checkbox(label="Invert Mask") run_button = gr.Button("Run Inpainting") reset_button = gr.Button("Reset Points") input_image.select(get_points, inputs=[input_image, input_points], outputs=[output_image, input_points]) run_button.click( run_inpaint, inputs=[prompt, negative_prompt, cfg, seed, invert, input_image, input_points], outputs=output_image ) reset_button.click(lambda: (None, []), outputs=[output_image, input_points]) demo.launch()