Update app.py
Browse files
app.py
CHANGED
@@ -12,12 +12,8 @@ print(f"Running on device: {device}")
|
|
12 |
|
13 |
# Load models
|
14 |
segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")
|
15 |
-
upscaler = StableDiffusionUpscalePipeline.from_pretrained(
|
16 |
-
|
17 |
-
).to(device)
|
18 |
-
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
|
19 |
-
"runwayml/stable-diffusion-inpainting" # Bỏ torch_dtype=torch.float16
|
20 |
-
).to(device)
|
21 |
|
22 |
# Helper functions
|
23 |
def upscale_image(image):
|
@@ -56,14 +52,14 @@ def process_image(base_image, paste_images_input, points_input, objects_input):
|
|
56 |
points = [list(map(int, p.split(","))) for p in points_input.split(";")]
|
57 |
objects = objects_input.split(";")
|
58 |
|
59 |
-
if len(points) != len(objects) or len(points)
|
60 |
-
raise ValueError("Number of points, objects, and paste images must match!")
|
61 |
|
62 |
upscaled_base = upscale_image(base_image)
|
63 |
masks = segment_image(upscaled_base, points)
|
64 |
refined_masks = [refine_mask(mask) for mask in masks]
|
65 |
inverted_masks = [invert_mask(mask) for mask in refined_masks]
|
66 |
-
seamless_pastes = [make_seamless(img) for img in paste_images
|
67 |
pasted_image = paste_by_mask(upscaled_base, seamless_pastes, refined_masks)
|
68 |
|
69 |
combined_prompt = ", ".join([f"modern {obj}" for obj in objects]) + " in high quality interior, 2025 trends"
|
@@ -83,16 +79,24 @@ interface = gr.Interface(
|
|
83 |
fn=process_image,
|
84 |
inputs=[
|
85 |
gr.Image(type="pil", label="Base Image (Interior)"),
|
86 |
-
gr.File(file_count="multiple", label="Paste Images (
|
87 |
-
gr.Textbox(
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
],
|
90 |
outputs=[
|
91 |
gr.Image(label="Intermediate Image (After Pasting)"),
|
92 |
gr.Image(label="Final Enhanced Interior")
|
93 |
],
|
94 |
title="Interior Design Enhancer",
|
95 |
-
description="Upload a base image and new designs. Specify points and objects
|
96 |
)
|
97 |
|
98 |
interface.launch(share=True)
|
|
|
12 |
|
13 |
# Load models
|
14 |
segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")
|
15 |
+
upscaler = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler").to(device)
|
16 |
+
inpaint = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting").to(device)
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Helper functions
|
19 |
def upscale_image(image):
|
|
|
52 |
points = [list(map(int, p.split(","))) for p in points_input.split(";")]
|
53 |
objects = objects_input.split(";")
|
54 |
|
55 |
+
if len(points) != len(objects) or len(points) != len(paste_images):
|
56 |
+
raise ValueError(f"Number of points ({len(points)}), objects ({len(objects)}), and paste images ({len(paste_images)}) must match!")
|
57 |
|
58 |
upscaled_base = upscale_image(base_image)
|
59 |
masks = segment_image(upscaled_base, points)
|
60 |
refined_masks = [refine_mask(mask) for mask in masks]
|
61 |
inverted_masks = [invert_mask(mask) for mask in refined_masks]
|
62 |
+
seamless_pastes = [make_seamless(img) for img in paste_images]
|
63 |
pasted_image = paste_by_mask(upscaled_base, seamless_pastes, refined_masks)
|
64 |
|
65 |
combined_prompt = ", ".join([f"modern {obj}" for obj in objects]) + " in high quality interior, 2025 trends"
|
|
|
79 |
fn=process_image,
|
80 |
inputs=[
|
81 |
gr.Image(type="pil", label="Base Image (Interior)"),
|
82 |
+
gr.File(file_count="multiple", label="Paste Images (One per object)"),
|
83 |
+
gr.Textbox(
|
84 |
+
label="Points (x,y; separated by ';')",
|
85 |
+
value="500,500;600,600",
|
86 |
+
placeholder="e.g., '500,500;600,600' (one point per object)"
|
87 |
+
),
|
88 |
+
gr.Textbox(
|
89 |
+
label="Objects (separated by ';')",
|
90 |
+
value="counter;chair",
|
91 |
+
placeholder="e.g., 'counter;chair' (one object per image)"
|
92 |
+
)
|
93 |
],
|
94 |
outputs=[
|
95 |
gr.Image(label="Intermediate Image (After Pasting)"),
|
96 |
gr.Image(label="Final Enhanced Interior")
|
97 |
],
|
98 |
title="Interior Design Enhancer",
|
99 |
+
description="Upload a base image and new designs (one image per object). Specify points (x,y) and object names, separated by ';'. Ensure the number of points, objects, and images match (e.g., 2 points, 2 objects, 2 images)."
|
100 |
)
|
101 |
|
102 |
interface.launch(share=True)
|