|
import gradio as gr |
|
from transformers import pipeline |
|
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionUpscalePipeline |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import cv2 |
|
|
|
|
|
# Pick the compute device once at import time: use CUDA when PyTorch reports
# it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Running on device: {device}")
|
|
|
|
|
# Load the three models once at import time so every request reuses them.
# All of them are placed on the same `device`; previously the segmenter was the
# only one that was not told which device to use.
segmenter = pipeline(
    "mask-generation",
    model="facebook/sam-vit-huge",
    device=device,
)

upscaler = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler"
).to(device)

inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting"
).to(device)
|
|
|
|
|
def upscale_image(image):
    """Upscale ``image`` with the module-level ``upscaler`` pipeline.

    The pipeline is called with the fixed prompt "high quality interior" and
    20 inference steps.

    Args:
        image: The image handed to the upscaler (the caller in this file
            passes the PIL base image).

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    outcome = upscaler(
        prompt="high quality interior",
        image=image,
        num_inference_steps=20,
    )
    return outcome.images[0]
|
|
|
def segment_image(image, points):
    """Return one segmentation mask per requested point.

    NOTE(review): per the Transformers documentation the "mask-generation"
    pipeline produces masks automatically and has no point-prompt parameter,
    so passing ``points=`` to it had no effect and the returned masks were
    unrelated to the points. This version runs the automatic generator once
    and then picks, for every point, the best-scoring mask that covers it.

    Args:
        image: Image to segment, passed straight to ``segmenter``.
        points: Iterable of ``[x, y]`` pixel coordinates, one per object.

    Returns:
        A list of 2-D boolean NumPy arrays, in the same order as ``points``.

    Raises:
        ValueError: If a point is outside the image or no generated mask
            covers it.
    """
    output = segmenter(image)
    raw_masks = output["masks"]

    # Scores are used only for ranking; fall back to "first hit wins" if the
    # pipeline output does not carry one score per mask.
    scores = output.get("scores")
    if scores is None or len(scores) != len(raw_masks):
        scores = [0.0] * len(raw_masks)

    candidates = [
        (float(score), np.asarray(mask, dtype=bool))
        for mask, score in zip(raw_masks, scores)
    ]

    selected = []
    for x, y in points:
        # Masks are indexed [row, column], i.e. [y, x].
        hits = [
            (score, mask)
            for score, mask in candidates
            if 0 <= y < mask.shape[0] and 0 <= x < mask.shape[1] and mask[y, x]
        ]
        if not hits:
            raise ValueError(f"No segmentation mask covers point ({x},{y})")
        _, best_mask = max(hits, key=lambda pair: pair[0])
        selected.append(best_mask)
    return selected
|
|
|
def refine_mask(mask):
    """Clean up a mask with one erosion followed by one dilation.

    The mask is multiplied by 255 and cast to ``uint8`` for OpenCV, eroded and
    then dilated once each with the same 5x5 all-ones kernel, and finally
    divided by 255.0 so a floating-point array is returned.

    Args:
        mask: NumPy array mask (presumably 0/1 or boolean - confirm against
            the segmenter output).

    Returns:
        The processed mask as a float array.
    """
    structuring_element = np.ones((5, 5), np.uint8)
    as_uint8 = (mask * 255).astype(np.uint8)
    eroded = cv2.erode(as_uint8, structuring_element, iterations=1)
    reopened = cv2.dilate(eroded, structuring_element, iterations=1)
    return reopened / 255.0
|
|
|
def invert_mask(mask):
    """Return the complement of ``mask``, computed as ``1 - mask``.

    Args:
        mask: A number or NumPy array; values of 1 become 0 and vice versa.

    Returns:
        The element-wise result of subtracting ``mask`` from 1.
    """
    complement = 1 - mask
    return complement
|
|
|
def make_seamless(image):
    """Run OpenCV's ``seamlessClone`` of ``image`` onto itself.

    The image is converted to a NumPy array and cloned onto itself in
    ``cv2.NORMAL_CLONE`` mode, using a mask of all 255s and the image centre
    as the placement point.

    Args:
        image: Source image (anything ``np.array`` can turn into an image
            array; the caller in this file passes paste images).

    Returns:
        The cloned result wrapped in a PIL ``Image``.
    """
    pixels = np.array(image)
    height, width = pixels.shape[:2]
    centre = (width // 2, height // 2)
    full_mask = np.ones_like(pixels) * 255
    blended = cv2.seamlessClone(pixels, pixels, full_mask, centre, cv2.NORMAL_CLONE)
    return Image.fromarray(blended)
|
|
|
def paste_by_mask(base_image, paste_images, masks):
    """Composite each paste image onto the base image through its mask.

    Images and masks are paired in order with ``zip``. Every paste image is
    resized to the base image's size; wherever the paired mask is truthy the
    paste pixel replaces the current composite pixel. Later pairs overwrite
    earlier ones where their masks overlap.

    Args:
        base_image: PIL image used as the background.
        paste_images: Iterable of PIL images to paste.
        masks: Iterable of 2-D NumPy masks (a channel axis is appended to
            each so it broadcasts over the colour channels).

    Returns:
        The composite as a PIL ``Image``.
    """
    composite = np.array(base_image).copy()
    for overlay, mask in zip(paste_images, masks):
        overlay_pixels = np.array(overlay.resize(base_image.size))
        selector = mask[:, :, np.newaxis]
        composite = np.where(selector, overlay_pixels, composite)
        composite = composite.astype(np.uint8)
    return Image.fromarray(composite)
|
|
|
def _load_paste_image(item):
    """Return ``item`` as an RGB PIL image, opening it from disk if needed."""
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    # A file-upload component hands over paths or temp-file wrappers whose
    # path is exposed as ``.name`` rather than decoded images.
    path = getattr(item, "name", item)
    with Image.open(path) as opened:
        return opened.convert("RGB")


def _parse_points(points_input):
    """Parse ``"x,y;x,y"`` text into a list of ``[x, y]`` integer pairs."""
    points = []
    for chunk in (points_input or "").split(";"):
        chunk = chunk.strip()
        if not chunk:
            continue  # tolerate a trailing or doubled ';'
        try:
            coords = [int(value) for value in chunk.split(",")]
        except ValueError as err:
            raise ValueError(
                f"Invalid point '{chunk}': expected two integers in the form 'x,y'"
            ) from err
        if len(coords) != 2:
            raise ValueError(
                f"Invalid point '{chunk}': expected two integers in the form 'x,y'"
            )
        points.append(coords)
    return points


def process_image(base_image, paste_images_input, points_input, objects_input):
    """Paste new object images into an interior photo and inpaint around them.

    Pipeline: upscale the base image, obtain one mask per point, refine the
    masks, paste each (seamless-cloned) paste image through its mask, then
    run the inpainting model on everything outside the pasted regions.

    Args:
        base_image: PIL image of the interior.
        paste_images_input: One upload or a list of uploads (PIL images, file
            paths, or objects exposing the path as ``.name``).
        points_input: Text of the form ``"x,y;x,y"``, one point per object.
        objects_input: Object names separated by ``';'``, one per point.

    Returns:
        A tuple ``(pasted_image, final_image)``: the intermediate composite
        and the inpainted result.

    Raises:
        ValueError: If an input is missing or malformed, or if the numbers of
            points, objects and paste images differ.
    """
    if base_image is None:
        raise ValueError("A base image is required.")
    if paste_images_input is None:
        raise ValueError("At least one paste image is required.")

    if isinstance(paste_images_input, (list, tuple)):
        raw_pastes = list(paste_images_input)
    else:
        raw_pastes = [paste_images_input]
    paste_images = [_load_paste_image(item) for item in raw_pastes]

    points = _parse_points(points_input)
    objects = [name.strip() for name in (objects_input or "").split(";") if name.strip()]

    if len(points) != len(objects) or len(points) != len(paste_images):
        raise ValueError(f"Number of points ({len(points)}), objects ({len(objects)}), and paste images ({len(paste_images)}) must match!")
    if not points:
        raise ValueError("At least one point is required.")

    upscaled_base = upscale_image(base_image)

    # Only the masks that are actually paired with a paste image may take part
    # in the inpaint mask below, so drop any surplus masks up front.
    masks = list(segment_image(upscaled_base, points))[:len(paste_images)]
    refined_masks = [refine_mask(mask) for mask in masks]
    inverted_masks = [invert_mask(mask) for mask in refined_masks]

    seamless_pastes = [make_seamless(img) for img in paste_images]
    pasted_image = paste_by_mask(upscaled_base, seamless_pastes, refined_masks)

    combined_prompt = ", ".join([f"modern {obj}" for obj in objects]) + " in high quality interior, 2025 trends"

    # The region to repaint is what lies outside EVERY pasted object, i.e. the
    # element-wise minimum of the inverted masks (complement of the union).
    # Taking the maximum instead marks a pixel as soon as it is outside ANY
    # single object, which covers the other pasted objects as well.
    combined_mask = np.min(np.stack(inverted_masks), axis=0)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8)).convert("L")

    final_image = inpaint(
        prompt=combined_prompt,
        image=pasted_image,
        mask_image=mask_image,
        num_inference_steps=20,
        guidance_scale=3.5,
    ).images[0]

    return pasted_image, final_image
|
|
|
|
|
# Input components, in the order process_image receives its arguments.
_base_image_component = gr.Image(type="pil", label="Base Image (Interior)")
_paste_files_component = gr.File(
    file_count="multiple",
    label="Paste Images (One per object)",
)
_points_component = gr.Textbox(
    label="Points (x,y; separated by ';')",
    value="500,500;600,600",
    placeholder="e.g., '500,500;600,600' (one point per object)",
)
_objects_component = gr.Textbox(
    label="Objects (separated by ';')",
    value="counter;chair",
    placeholder="e.g., 'counter;chair' (one object per image)",
)

# Output components, matching the (pasted_image, final_image) return value.
_intermediate_component = gr.Image(label="Intermediate Image (After Pasting)")
_final_component = gr.Image(label="Final Enhanced Interior")

interface = gr.Interface(
    fn=process_image,
    inputs=[
        _base_image_component,
        _paste_files_component,
        _points_component,
        _objects_component,
    ],
    outputs=[_intermediate_component, _final_component],
    title="Interior Design Enhancer",
    description="Upload a base image and new designs (one image per object). Specify points (x,y) and object names, separated by ';'. Ensure the number of points, objects, and images match (e.g., 2 points, 2 objects, 2 images).",
)

# NOTE(review): share=True asks Gradio for a publicly reachable link - confirm
# that this is intended wherever the script is run.
interface.launch(share=True)