# segmask / app.py
# TDN-M's picture
# Update app.py
# 5729cc3 verified
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
# Constants
# Side length, in pixels, of the square images produced by the resize calls
# in run_inpaint() and preprocess().
IMG_SIZE = 512
# Process-wide cache so the SAM checkpoint is loaded once instead of on
# every click.
_SAM_CACHE = {}


def _load_sam():
    """Return the cached ``(model, processor)`` pair for SAM, loading on first use."""
    if "model" not in _SAM_CACHE:
        _SAM_CACHE["processor"] = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        _SAM_CACHE["model"] = SamModel.from_pretrained(
            "facebook/sam-vit-huge", torch_dtype=torch.float32
        ).to("cpu")
    return _SAM_CACHE["model"], _SAM_CACHE["processor"]


def generate_mask(image, points: list) -> "np.ndarray | None":
    """
    Generates a mask using SAM based on input points.

    Args:
        image: The image to segment (passed straight to ``SamProcessor``).
        points: A list of ``[x, y]`` pixel coordinates selected by the user.
            All points are treated as prompts for a single object.

    Returns:
        A 2-D integer NumPy array holding 1 on the selected object and 0
        elsewhere, or ``None`` when no points were given or SAM returned no
        masks.
    """
    if not points:
        return None

    sam_model, sam_processor = _load_sam()

    # SamProcessor expects prompts under ``input_points``, nested one level
    # per image: [[[x, y], ...]]. An unrecognized ``points=`` keyword is never
    # used as a prompt.
    prompt_points = [[[float(x), float(y)] for x, y in points]]
    inputs = sam_processor(
        image, input_points=prompt_points, return_tensors="pt"
    ).to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if not masks:
        return None

    # One image and one prompt group: pick the candidate with the best
    # predicted IoU among the masks SAM proposes.
    best_index = int(outputs.iou_scores.argmax())
    best_mask = masks[0][0][best_index]

    # Convert to a clean 0/1 array. Applying ``~`` to the integer array (as
    # the previous expression did, because attribute access binds tighter
    # than ``~``) yields -1/-2, which breaks ``mask == 1`` and ``mask * 255``
    # in the consumers of this mask.
    return best_mask.numpy().astype(bool).astype(int)
# Process-wide cache so the inpainting pipeline is loaded once instead of on
# every run.
_INPAINT_CACHE = {}


def _load_inpaint_pipeline():
    """Return the cached Stable Diffusion inpainting pipeline, loading on first use."""
    if "pipeline" not in _INPAINT_CACHE:
        _INPAINT_CACHE["pipeline"] = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float32
        ).to("cpu")
    return _INPAINT_CACHE["pipeline"]


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the object in the image based on the mask and prompt.

    Args:
        image: The source image handed to the inpainting pipeline.
        mask: A 2-D array holding 1 where the image should be repainted and 0
            elsewhere, or ``None`` to skip inpainting.
        prompt: Text prompt describing the replacement content.
        negative_prompt: Optional negative prompt; empty values are passed to
            the pipeline as ``None``.
        seed: Seed for the random generator. Numeric values are converted to
            ``int`` (UI number fields may deliver floats); ``None`` leaves
            generation unseeded.
        guidance_scale: Classifier-free guidance scale for the pipeline.

    Returns:
        The inpainted image, or the unchanged input image when ``mask`` is
        ``None`` or the pipeline call fails.
    """
    if mask is None:
        return image

    inpaint_pipeline = _load_inpaint_pipeline()

    # Scale the 0/1 mask to 0/255 for the mask image.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # torch.Generator.manual_seed only accepts integers, so coerce the value
    # coming from the UI before seeding.
    generator = None
    if seed is not None:
        generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
    except Exception as e:
        # Best effort: report the failure and hand back the original image.
        print(f"Inpainting error: {e}")
        return image
    return result
def visualize_mask(image, mask):
    """
    Returns the image with the masked region tinted green.

    When ``mask`` is ``None`` the image is returned untouched. Otherwise every
    position where ``mask == 1`` receives a half-transparent green pixel that
    is alpha-composited over the image, and the result is converted to RGB.
    """
    if mask is None:
        return image

    # Fully transparent RGBA canvas matching the mask's shape.
    tint_layer = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    tint_layer[selected] = [0, 255, 0, 127]  # Green with transparency

    tint_image = Image.fromarray(tint_layer)
    base_rgba = image.convert("RGBA")
    blended = Image.alpha_composite(base_rgba, tint_image)
    return blended.convert("RGB")
def get_points(img, evt: gr.SelectData, input_points):
    """
    Records the clicked point and returns a preview of the resulting mask.

    The click position from ``evt.index`` is appended to ``input_points``, a
    mask is generated from all recorded points, a green cross is drawn on the
    image at each point, and the mask overlay is applied on top.

    Returns a tuple of the preview image and the updated list of points.
    """
    click_x, click_y = evt.index
    input_points.append([click_x, click_y])

    # The mask is computed on the image as received, before any markers are
    # drawn onto it.
    current_mask = generate_mask(img, input_points)

    # Draw a green cross centred on every selected point.
    arm = 10
    pen = ImageDraw.Draw(img)
    for cx, cy in input_points:
        pen.line((cx - arm, cy, cx + arm, cy), fill="green", width=5)
        pen.line((cx, cy - arm, cx, cy + arm), fill="green", width=5)

    preview = visualize_mask(img, current_mask)
    return preview, input_points
def run_inpaint(prompt, negative_prompt, cfg, seed, invert, input_image, input_points):
    """
    Runs the inpainting process based on user inputs.

    Args:
        prompt: Text prompt describing the replacement content.
        negative_prompt: Optional negative prompt.
        cfg: Guidance scale forwarded to the inpainting step.
        seed: Seed forwarded to the inpainting step.
        invert: When true, the complement of the generated mask is repainted
            instead of the mask itself.
        input_image: The uploaded image.
        input_points: The ``[x, y]`` points selected on the image.

    Returns:
        The inpainted image resized to ``IMG_SIZE`` x ``IMG_SIZE``.

    Raises:
        gr.Error: If no image or no points were provided, or if the
            replacement step raises.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    mask = generate_mask(input_image, input_points)

    if invert and mask is not None:
        # ``~`` on an integer 0/1 array is a bitwise NOT (0 -> -1, 1 -> -2),
        # not a mask inversion. Use a logical NOT and keep the mask's dtype so
        # the result is again a 0/1 (or boolean) mask.
        mask = np.logical_not(mask).astype(mask.dtype)

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def preprocess(input_img):
    """
    Pads the uploaded image to a square and resizes it to the working size.

    A ``None`` input yields ``None``. Non-square images are centred on a white
    RGB canvas whose side equals the longer edge; the result (or the original
    image when already square) is resized to ``IMG_SIZE`` x ``IMG_SIZE``.
    """
    if input_img is None:
        return None

    target = (IMG_SIZE, IMG_SIZE)
    w, h = input_img.size
    if w == h:
        return input_img.resize(target)

    # Centre the image on a white square canvas.
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), 'white')
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(input_img, offset)
    return canvas.resize(target)
# Gradio Interface
# Two-column layout: the images on the left, the inpainting controls on the
# right. Event handlers are wired up after the layout is declared.
with gr.Blocks() as demo:
    gr.Markdown("# Object Replacement with SAM and Stable Diffusion Inpainting")
    gr.Markdown("Upload an image, click on the object you want to replace, and generate a new image.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            output_image = gr.Image(label="Generated Image", type="pil")
            input_points = gr.State([])  # Store selected points
        with gr.Column():
            prompt = gr.Textbox(label="Prompt for Inpainting")
            negative_prompt = gr.Textbox(label="Negative Prompt (Optional)")
            cfg = gr.Slider(1, 20, value=7.5, label="Guidance Scale")
            # precision=0 makes the component deliver an integer, which is
            # what torch.Generator.manual_seed requires.
            seed = gr.Number(value=42, label="Seed", precision=0)
            invert = gr.Checkbox(label="Invert Mask")
            run_button = gr.Button("Run Inpainting")
            reset_button = gr.Button("Reset Points")

    # Clicking the input image records a point and previews the mask.
    input_image.select(get_points, inputs=[input_image, input_points], outputs=[output_image, input_points])
    run_button.click(
        run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert, input_image, input_points],
        outputs=output_image
    )
    # Clear the preview and forget all selected points.
    reset_button.click(lambda: (None, []), outputs=[output_image, input_points])

# Only start the server when executed as a script, so the module can be
# imported without launching the app.
if __name__ == "__main__":
    demo.launch()