import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
# Constants
IMG_SIZE = 512
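# Load the models once at startup rather than inside the handlers; re-downloading
# and re-initializing the SAM checkpoint and the inpainting pipeline on every
# click would dominate the runtime. Everything runs on CPU here; on a CUDA
# machine you could switch the device and use torch.float16 for a large speedup.
device = "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
).to(device)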
def generate_mask(image, points):
"""
Generates a mask using SAM based on input points.
"""
if not points:
return None
    # The processor nests the clicked points as one prompt set for a single
    # object: [points] -> (1 image, 1 prompt, N points)
    inputs = sam_processor(image, input_points=[points], return_tensors="pt").to(device)
with torch.no_grad():
outputs = sam_model(**inputs)
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
if len(masks) == 0:
return None
    # SAM proposes three candidate masks per prompt; keep the one with the
    # highest predicted IoU score
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    # Invert so 1 marks the background. Apply "~" to the boolean array before
    # casting to int; on an int array "~" is a bitwise NOT (0 -> -1, 1 -> -2).
    binary_mask = (~best_mask.numpy().astype(bool)).astype(int)
    return binary_mask
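# Example (hypothetical coordinates): two clicks on the same object, e.g.
#   mask = generate_mask(img, [[200, 150], [240, 180]])
# prompt SAM with both points at once and return a 0/1 array the size of img,
# with 1 everywhere outside the segmented object.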
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
"""
Replaces the object in the image based on the mask and prompt.
"""
if mask is None:
return image
    # Convert the 0/1 mask to the black-and-white PIL image the pipeline
    # expects (white = area to repaint)
mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    # The seed may arrive as a float from the UI; manual_seed requires an int
    generator = torch.Generator("cpu").manual_seed(int(seed))
try:
result = inpaint_pipeline(
prompt=prompt,
image=image,
mask_image=mask_image,
negative_prompt=negative_prompt if negative_prompt else None,
generator=generator,
guidance_scale=guidance_scale
).images[0]
return result
except Exception as e:
print(f"Inpainting error: {e}")
return image
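# Example (hypothetical prompt): with the default mask semantics this swaps
# the background while keeping the selected object, e.g.
#   out = replace_object(img, mask, "a sunny beach", "blurry", seed=42, guidance_scale=7.5)
# Higher guidance_scale values follow the prompt more literally at the cost
# of diversity; values around 7-8 are a common middle ground.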
def visualize_mask(image, mask):
"""
Overlays the mask on the image for visualization.
"""
if mask is None:
return image
bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
mask_rgba = Image.fromarray(bg_transparent)
overlay = Image.alpha_composite(image.convert("RGBA"), mask_rgba)
return overlay.convert("RGB")
def get_points(img, evt: gr.SelectData, input_points):
"""
Captures points selected by the user on the image.
"""
x, y = evt.index
input_points.append([x, y])
# Generate mask based on selected points
mask = generate_mask(img, input_points)
# Mark selected points with a green crossmark
draw = ImageDraw.Draw(img)
size = 10
for point in input_points:
px, py = point
draw.line((px - size, py, px + size, py), fill="green", width=5)
draw.line((px, py - size, px, py + size), fill="green", width=5)
# Visualize the mask overlay
masked_image = visualize_mask(img, mask)
return masked_image, input_points
def run_inpaint(prompt, negative_prompt, cfg, seed, invert, input_image, input_points):
"""
Runs the inpainting process based on user inputs.
"""
if input_image is None or len(input_points) == 0:
raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
mask = generate_mask(input_image, input_points)
    if invert:
        mask = 1 - mask  # flip the 0/1 mask; "~" would bitwise-NOT the int array
try:
inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
except Exception as e:
raise gr.Error(str(e))
return inpainted.resize((IMG_SIZE, IMG_SIZE))
def preprocess(input_img):
"""
Preprocesses the uploaded image to ensure it is square and resized.
"""
if input_img is None:
return None
width, height = input_img.size
if width != height:
# Add white padding to make the image square
new_size = max(width, height)
new_image = Image.new("RGB", (new_size, new_size), 'white')
left = (new_size - width) // 2
top = (new_size - height) // 2
new_image.paste(input_img, (left, top))
input_img = new_image
return input_img.resize((IMG_SIZE, IMG_SIZE))
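# Example: a 640x480 upload becomes a 640x640 white-padded square with the
# photo pasted at (0, 80), then is resized to 512x512, so clicks, masks, and
# generated output all share one coordinate system.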
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("# Object Replacement with SAM and Stable Diffusion Inpainting")
gr.Markdown("Upload an image, click on the object you want to replace, and generate a new image.")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Upload Image", type="pil")
output_image = gr.Image(label="Generated Image", type="pil")
input_points = gr.State([]) # Store selected points
with gr.Column():
prompt = gr.Textbox(label="Prompt for Inpainting")
negative_prompt = gr.Textbox(label="Negative Prompt (Optional)")
cfg = gr.Slider(1, 20, value=7.5, label="Guidance Scale")
seed = gr.Number(value=42, label="Seed")
invert = gr.Checkbox(label="Invert Mask")
run_button = gr.Button("Run Inpainting")
reset_button = gr.Button("Reset Points")
    # Square-pad and resize uploads first so click coordinates line up with
    # the 512x512 canvas that SAM and the inpainting pipeline work on
    input_image.upload(preprocess, inputs=[input_image], outputs=[input_image])
    input_image.select(get_points, inputs=[input_image, input_points], outputs=[output_image, input_points])
run_button.click(
run_inpaint,
inputs=[prompt, negative_prompt, cfg, seed, invert, input_image, input_points],
outputs=output_image
)
reset_button.click(lambda: (None, []), outputs=[output_image, input_points])
if __name__ == "__main__":
    demo.launch()