# hh / app.py
# TDN-M's picture
# Update app.py
# 4f728c6 verified
import gradio as gr
from transformers import pipeline
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionUpscalePipeline
import torch
import numpy as np
from PIL import Image
import cv2
# Check device: pick the accelerator once so every model is placed consistently.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load models.
# The segmenter is given the device explicitly, like the two diffusion
# pipelines below, so it does not depend on the pipeline's default placement.
segmenter = pipeline(
    "mask-generation",
    model="facebook/sam-vit-huge",
    device=device,
)
upscaler = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler"
).to(device)
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting"
).to(device)
# Helper functions
def upscale_image(image):
    """Upscale ``image`` with the module-level ``upscaler`` pipeline.

    The pipeline is run for 20 inference steps with the fixed prompt
    "high quality interior".

    Args:
        image: Image handed straight to the upscaler pipeline.

    Returns:
        The first image of the pipeline output.
    """
    pipeline_output = upscaler(
        prompt="high quality interior",
        image=image,
        num_inference_steps=20,
    )
    return pipeline_output.images[0]
def _select_mask_for_point(candidates, point):
    """Return the smallest candidate mask that covers ``point`` (x, y)."""
    x, y = point
    # Masks are indexed [row, column], i.e. [y, x]; the explicit bounds test
    # also stops negative coordinates from wrapping around via numpy indexing.
    # Assumes 2-D (height, width) masks, which is how paste_by_mask uses them.
    covering = [
        mask
        for mask in candidates
        if 0 <= y < mask.shape[0] and 0 <= x < mask.shape[1] and mask[y, x]
    ]
    if not covering:
        raise ValueError(f"No segmentation mask covers point ({x}, {y}).")
    # Generated segments can be nested (object / part); the smallest one that
    # contains the point is taken as the most specific match.
    return min(covering, key=np.count_nonzero)


def segment_image(image, points):
    """Return one segmentation mask per requested point.

    The ``mask-generation`` pipeline produces masks for the whole image.
    NOTE(review): ``points`` is not among the pipeline's documented call
    parameters, so it appears to be ignored there; it is still forwarded
    unchanged, and the masks are matched to the points here instead.

    Args:
        image: Image to segment.
        points: Iterable of ``(x, y)`` pixel coordinates expressed in the
            coordinate system of ``image``.

    Returns:
        A list of numpy arrays, one mask per point, in the order of ``points``.

    Raises:
        ValueError: If a point lies outside the image or no generated mask
            covers it.
    """
    generated = segmenter(image, points=points)["masks"]
    candidates = [np.array(mask) for mask in generated]
    return [_select_mask_for_point(candidates, point) for point in points]
def refine_mask(mask):
    """Clean up a mask with one erosion followed by one dilation (5x5 kernel).

    Args:
        mask: Array-like mask; it is multiplied by 255 and cast to ``uint8``
            before the morphological operations.

    Returns:
        The processed mask divided by 255.0 (a float array).
    """
    mask_8bit = (mask * 255).astype(np.uint8)
    structuring_element = np.ones((5, 5), np.uint8)
    # Erode first to drop small specks, then dilate to restore the size of
    # what survived.
    eroded = cv2.erode(mask_8bit, structuring_element, iterations=1)
    reopened = cv2.dilate(eroded, structuring_element, iterations=1)
    return reopened / 255.0
def invert_mask(mask):
    """Return ``1 - mask``.

    For a mask whose values are in the 0..1 range this swaps the selected
    and unselected regions.
    """
    complement = 1 - mask
    return complement
def make_seamless(image):
    """Seamless-clone an image onto itself with OpenCV and return a PIL image.

    The source and destination are the same pixel array, the clone mask is
    an all-255 array of the same shape, and the clone is centred on the
    image midpoint using ``cv2.NORMAL_CLONE``.

    Args:
        image: Image convertible to a numpy array.

    Returns:
        The cloned result wrapped in a ``PIL.Image``.
    """
    pixels = np.array(image)
    height, width = pixels.shape[:2]
    full_mask = np.ones_like(pixels) * 255
    midpoint = (width // 2, height // 2)
    blended = cv2.seamlessClone(pixels, pixels, full_mask, midpoint, cv2.NORMAL_CLONE)
    return Image.fromarray(blended)
def paste_by_mask(base_image, paste_images, masks):
    """Composite paste images onto the base image through their masks.

    Each paste image is resized to the base image size; wherever its mask is
    truthy, the paste pixels replace the current result. Images and masks are
    paired positionally and applied in order, so later pairs overwrite
    earlier ones where masks overlap.

    Args:
        base_image: PIL image used as the background.
        paste_images: PIL images to paste, one per mask.
        masks: 2-D arrays matching the base image's height and width.

    Returns:
        A new ``PIL.Image`` with all pastes applied.
    """
    target_size = base_image.size
    canvas = np.array(base_image).copy()
    for overlay, region in zip(paste_images, masks):
        overlay_pixels = np.array(overlay.resize(target_size))
        # Add a trailing channel axis so the mask broadcasts over the colours.
        region_3d = region[:, :, None]
        canvas = np.where(region_3d, overlay_pixels, canvas).astype(np.uint8)
    return Image.fromarray(canvas)
def _parse_points(points_input):
    """Parse ``"x,y;x,y"`` text into a list of ``[x, y]`` integer pairs."""
    points = []
    for chunk in (points_input or "").split(";"):
        chunk = chunk.strip()
        if not chunk:
            # Tolerate a trailing or doubled ';' instead of failing on int("").
            continue
        coords = chunk.split(",")
        if len(coords) != 2:
            raise ValueError(f"Point '{chunk}' must have the form 'x,y'.")
        try:
            points.append([int(value) for value in coords])
        except ValueError as err:
            raise ValueError(
                f"Point '{chunk}' must contain integer coordinates."
            ) from err
    return points


def _load_paste_image(item):
    """Turn an uploaded item (PIL image, path or file wrapper) into RGB."""
    if isinstance(item, Image.Image):
        return item if item.mode == "RGB" else item.convert("RGB")
    # Depending on the Gradio version, a File component yields either a file
    # path or a temp-file wrapper exposing the path as ``.name``.
    path = getattr(item, "name", item)
    with Image.open(path) as opened:
        # convert() loads the pixels, so the file can be closed afterwards.
        return opened.convert("RGB")


def process_image(base_image, paste_images_input, points_input, objects_input):
    """Paste new designs into an interior photo and inpaint around them.

    Steps: upscale the base image, segment it, refine the masks, paste each
    design through its mask, then inpaint everything outside the pasted
    objects with a prompt built from the object names.

    Args:
        base_image: PIL image of the interior.
        paste_images_input: One uploaded file/image or a list of them, one
            per object. ``None`` is treated as "no uploads".
        points_input: Text of the form ``"x,y;x,y"``, one point per object.
        objects_input: Object names separated by ``';'``.

    Returns:
        A tuple ``(pasted_image, final_image)``: the intermediate composite
        and the inpainted result.

    Raises:
        ValueError: If the base image is missing, a point is malformed, the
            numbers of points, objects and paste images differ, or no point
            was given at all.
    """
    if base_image is None:
        raise ValueError("A base image is required.")

    if paste_images_input is None:
        raw_pastes = []
    elif isinstance(paste_images_input, (list, tuple)):
        raw_pastes = list(paste_images_input)
    else:
        raw_pastes = [paste_images_input]

    points = _parse_points(points_input)
    objects = [
        name.strip() for name in (objects_input or "").split(";") if name.strip()
    ]

    # Validate the counts before any file is opened or any model is run.
    if len(points) != len(objects) or len(points) != len(raw_pastes):
        raise ValueError(f"Number of points ({len(points)}), objects ({len(objects)}), and paste images ({len(raw_pastes)}) must match!")
    if not points:
        raise ValueError("At least one point, object and paste image is required.")

    paste_images = [_load_paste_image(item) for item in raw_pastes]

    upscaled_base = upscale_image(base_image)
    # NOTE(review): the points are forwarded unchanged although segmentation
    # runs on the upscaled image; if users enter base-image coordinates they
    # would need scaling here -- confirm the intended coordinate system.
    # Only the first len(paste_images) masks are kept because paste_by_mask
    # pairs masks with paste images positionally.
    masks = segment_image(upscaled_base, points)[: len(paste_images)]
    refined_masks = [refine_mask(mask) for mask in masks]
    inverted_masks = [invert_mask(mask) for mask in refined_masks]
    seamless_pastes = [make_seamless(img) for img in paste_images]
    pasted_image = paste_by_mask(upscaled_base, seamless_pastes, refined_masks)

    combined_prompt = ", ".join([f"modern {obj}" for obj in objects]) + " in high quality interior, 2025 trends"

    # The region to repaint is what lies outside *every* pasted object, i.e.
    # the complement of the union of the object masks. That is the
    # element-wise minimum of the inverted masks; a maximum would mark the
    # pasted objects themselves for repainting whenever the masks are disjoint.
    combined_mask = np.min(inverted_masks, axis=0)
    mask_pixels = np.clip(np.rint(combined_mask * 255), 0, 255).astype(np.uint8)

    final_image = inpaint(
        prompt=combined_prompt,
        image=pasted_image,
        mask_image=Image.fromarray(mask_pixels).convert("L"),
        num_inference_steps=20,
        guidance_scale=3.5,
    ).images[0]
    return pasted_image, final_image
# Gradio interface: four inputs (base image, paste files, points, object
# names) feeding process_image, and two image outputs.
input_components = [
    gr.Image(type="pil", label="Base Image (Interior)"),
    gr.File(file_count="multiple", label="Paste Images (One per object)"),
    gr.Textbox(
        label="Points (x,y; separated by ';')",
        value="500,500;600,600",
        placeholder="e.g., '500,500;600,600' (one point per object)",
    ),
    gr.Textbox(
        label="Objects (separated by ';')",
        value="counter;chair",
        placeholder="e.g., 'counter;chair' (one object per image)",
    ),
]
output_components = [
    gr.Image(label="Intermediate Image (After Pasting)"),
    gr.Image(label="Final Enhanced Interior"),
]

interface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    title="Interior Design Enhancer",
    description="Upload a base image and new designs (one image per object). Specify points (x,y) and object names, separated by ';'. Ensure the number of points, objects, and images match (e.g., 2 points, 2 objects, 2 images).",
)

# Start the web app; share=True asks Gradio for a public share link.
interface.launch(share=True)