File size: 5,731 Bytes
51d89a8
5729cc3
 
3c568cd
5729cc3
 
 
 
 
3c568cd
5729cc3
51d89a8
5729cc3
51d89a8
5729cc3
 
 
 
 
 
51d89a8
5729cc3
51d89a8
5729cc3
 
51d89a8
5729cc3
 
 
 
 
 
 
 
 
 
 
51d89a8
 
5729cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d89a8
5729cc3
51d89a8
5729cc3
 
51d89a8
5729cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
51d89a8
5729cc3
 
51d89a8
5729cc3
 
 
 
 
 
 
51d89a8
5729cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d89a8
5729cc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import functools

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Edge length, in pixels, of the square images produced by ``preprocess``
# and returned by ``run_inpaint``.
IMG_SIZE: int = 512

@functools.lru_cache(maxsize=1)
def _load_sam():
    """Load the SAM ViT-H model and processor once and reuse them for every call."""
    # Re-loading vit-huge on every click dominated the latency of each interaction.
    model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    return model, processor


def generate_mask(image, points):
    """
    Segment the object indicated by ``points`` in ``image`` using SAM.

    Args:
        image: PIL image the points were clicked on.
        points: list of ``[x, y]`` pixel coordinates; all of them are used as
            prompts for one single object.

    Returns:
        A boolean numpy array the size of ``image`` that is True on the selected
        object, or None when no points are given or SAM returns no masks.
    """
    if not points:
        return None

    sam_model, sam_processor = _load_sam()

    # SamProcessor takes prompts via ``input_points`` nested per image:
    # [image][point][xy]. Wrapping once gives one image whose clicks all
    # describe the same object. (The previous ``points=`` keyword is not a
    # SamProcessor argument, so the clicks never reached the model.)
    inputs = sam_processor(image, input_points=[points], return_tensors="pt").to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # SAM proposes several candidate masks for the prompt set; keep the
    # candidate with the highest iou score for our single image / prompt set.
    best_index = int(outputs.iou_scores[0, 0].argmax())
    best_mask = masks[0][0][best_index]

    # Return a plain boolean mask. The old ``~`` was applied to an int array
    # (bitwise NOT), which produced -1/-2 instead of a 0/1 mask. A boolean
    # ``~`` in callers is a correct logical inversion.
    return best_mask.numpy().astype(bool)


@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline():
    """Load the Stable Diffusion 2 inpainting pipeline once and reuse it across runs."""
    return StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float32,
    ).to("cpu")


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Repaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: PIL image to edit.
        mask: array the size of ``image``; nonzero / True pixels are repainted.
            Boolean, 0/1 and 0/255 masks are all accepted.
        prompt: text describing what should appear in the masked region.
        negative_prompt: optional text to steer away from; empty means none.
        seed: random seed; may arrive as a float from ``gr.Number``.
        guidance_scale: classifier-free guidance strength.

    Returns:
        The inpainted PIL image. ``image`` itself is returned unchanged when
        ``mask`` is None or the pipeline call fails (the error is printed).
    """
    if mask is None:
        return image

    inpaint_pipeline = _load_inpaint_pipeline()

    # White (255) marks pixels to repaint. Thresholding instead of ``mask * 255``
    # avoids uint8 wrap-around for masks that are not strictly 0/1.
    mask_image = Image.fromarray(np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8))

    # torch.Generator.manual_seed requires an int; gr.Number yields floats like 42.0.
    generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt or None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
        return result
    except Exception as e:
        # Best-effort: fall back to the unmodified image rather than failing the UI.
        print(f"Inpainting error: {e}")
        return image


def visualize_mask(image, mask):
    """
    Tint the selected region of ``image`` with semi-transparent green.

    Pixels where ``mask == 1`` receive a green overlay at roughly 50% alpha;
    all other pixels are left as they are. Returns ``image`` untouched when
    ``mask`` is None, otherwise a new RGB image.
    """
    if mask is None:
        return image

    tint = np.zeros((*mask.shape, 4), dtype=np.uint8)
    tint[mask == 1] = (0, 255, 0, 127)

    base = image.convert("RGBA")
    blended = Image.alpha_composite(base, Image.fromarray(tint))
    return blended.convert("RGB")


def get_points(img, evt: gr.SelectData, input_points):
    """
    Record a click on the input image and preview the resulting SAM mask.

    The clicked pixel from ``evt.index`` is appended to ``input_points`` in
    place. The mask is recomputed from every recorded point, a green "+" is
    drawn at each point and the mask is overlaid.

    Returns:
        (preview image, the updated point list)
    """
    click_x, click_y = evt.index
    input_points.append([click_x, click_y])

    # Compute the mask first: the drawing below paints directly onto ``img``.
    mask = generate_mask(img, input_points)

    arm = 10
    pen = ImageDraw.Draw(img)
    for px, py in input_points:
        pen.line((px - arm, py, px + arm, py), fill="green", width=5)
        pen.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(img, mask), input_points


def run_inpaint(prompt, negative_prompt, cfg, seed, invert, input_image, input_points):
    """
    Segment the clicked object with SAM and inpaint it (or everything else).

    Args:
        prompt: text for the inpainting model.
        negative_prompt: optional negative text.
        cfg: guidance scale.
        seed: random seed forwarded to ``replace_object``.
        invert: when True, repaint everything except the selected object.
        input_image: PIL image the user clicked on.
        input_points: list of clicked ``[x, y]`` points.

    Returns:
        The inpainted image resized to ``IMG_SIZE`` x ``IMG_SIZE``.

    Raises:
        gr.Error: when no image/points are given, SAM yields no mask, or
            segmentation / inpainting raises.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    try:
        mask = generate_mask(input_image, input_points)
    except Exception as e:
        raise gr.Error(str(e)) from e

    # Without this guard, ``~None`` raised a TypeError (invert) or the input
    # image was silently returned unchanged.
    if mask is None:
        raise gr.Error("SAM did not return a mask for the selected points.")

    if invert:
        # Logical inversion; relies on ``generate_mask`` returning a boolean mask.
        mask = ~mask

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e)) from e

    return inpainted.resize((IMG_SIZE, IMG_SIZE))


def preprocess(input_img):
    """
    Square an uploaded image and scale it to ``IMG_SIZE`` x ``IMG_SIZE``.

    Non-square images are centred on a white RGB canvas whose side equals the
    longer edge before resizing. Returns None when no image is given.
    """
    if input_img is None:
        return None

    w, h = input_img.size
    if w == h:
        return input_img.resize((IMG_SIZE, IMG_SIZE))

    side = max(w, h)
    canvas = Image.new("RGB", (side, side), 'white')
    canvas.paste(input_img, ((side - w) // 2, (side - h) // 2))
    return canvas.resize((IMG_SIZE, IMG_SIZE))


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Replacement with SAM and Stable Diffusion Inpainting")
    gr.Markdown("Upload an image, click on the object you want to replace, and generate a new image.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            output_image = gr.Image(label="Generated Image", type="pil")
            input_points = gr.State([])  # Clicked [x, y] points for the current image

        with gr.Column():
            prompt = gr.Textbox(label="Prompt for Inpainting")
            negative_prompt = gr.Textbox(label="Negative Prompt (Optional)")
            cfg = gr.Slider(1, 20, value=7.5, label="Guidance Scale")
            # precision=0 makes the component deliver an int, which
            # torch.Generator.manual_seed requires.
            seed = gr.Number(value=42, label="Seed", precision=0)
            invert = gr.Checkbox(label="Invert Mask")
            run_button = gr.Button("Run Inpainting")
            reset_button = gr.Button("Reset Points")

    # On a new upload, pad/resize the image with ``preprocess`` (previously
    # never called) so the inpainting model receives a square IMG_SIZE image.
    # Also discard points clicked on the previous image.
    input_image.upload(
        lambda img: (preprocess(img), []),
        inputs=input_image,
        outputs=[input_image, input_points],
    )
    input_image.select(get_points, inputs=[input_image, input_points], outputs=[output_image, input_points])
    run_button.click(
        run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert, input_image, input_points],
        outputs=output_image,
    )
    reset_button.click(lambda: (None, []), outputs=[output_image, input_points])

if __name__ == "__main__":
    demo.launch()