TDN-M committed on
Commit 5729cc3 · verified · 1 Parent(s): f357c98

Update app.py

Files changed (1)
  app.py +163 -69
app.py CHANGED
@@ -1,86 +1,180 @@
- import cv2
- import numpy as np
  import gradio as gr
- from huggingface_hub import hf_hub_download
- from segment_anything import sam_model_registry, SamPredictor
  import torch

- # Load the SAM model from Hugging Face
- def load_sam_model():
-     # Download the checkpoint from Hugging Face with map_location=torch.device('cpu')
-     checkpoint_path = hf_hub_download(repo_id="facebook/sam-vit-huge", filename="pytorch_model.bin")
-
-     # Load the checkpoint with map_location=torch.device('cpu')
-     checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-
-     # Initialize the SAM model
-     model_type = "vit_h"
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     # Pass the checkpoint into the model
-     sam = sam_model_registry[model_type]()
-     sam.load_state_dict(checkpoint)
-     sam.to(device=device)
-     predictor = SamPredictor(sam)
-     return predictor
-
- predictor = load_sam_model()
-
- def generate_mask(image, event: gr.SelectData):
      """
-     Generate a binary mask for the selected object.
-     :param image: The input image (numpy array).
-     :param event: Gradio SelectData containing the click coordinates.
-     :return: A binary mask where the selected object is black, and the rest is white.
      """
-     # Preprocess the image for SAM
-     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     predictor.set_image(image_rgb)
-
-     # Get the click coordinates
-     x, y = event.index
-     input_point = np.array([[x, y]])
-     input_label = np.array([1])  # 1 indicates foreground
-
-     # Generate masks
-     masks, scores, logits = predictor.predict(
-         point_coords=input_point,
-         point_labels=input_label,
-         multimask_output=True,
-     )

-     # Select the best mask based on the score
-     best_mask = masks[np.argmax(scores)]

-     # Convert the mask to a binary image (black for the object, white for the background)
-     binary_mask = (best_mask * 255).astype(np.uint8)
-     binary_mask = cv2.bitwise_not(binary_mask)  # Invert colors (black for object)

      return binary_mask

- def app():
      """
-     Create the Gradio interface.
      """
-     with gr.Blocks() as demo:
-         gr.Markdown("# Image Segmentation with Segment Anything Model (SAM)")
-         gr.Markdown("Upload an image, click on an object to select it, and generate a binary mask.")

-         with gr.Row():
-             with gr.Column():
-                 input_image = gr.Image(label="Upload Image", type="numpy")
-                 output_mask = gr.Image(label="Generated Mask", type="numpy")

-             with gr.Column():
-                 gr.Markdown("### Instructions")
-                 gr.Markdown("1. Upload an image.")
-                 gr.Markdown("2. Click on the object you want to change.")
-                 gr.Markdown("3. The mask will be generated automatically.")

-         input_image.select(generate_mask, inputs=[input_image], outputs=output_mask)

-     return demo

- if __name__ == "__main__":
-     demo = app()
-     demo.launch()
  import gradio as gr
+ import numpy as np
+ from PIL import Image, ImageDraw
  import torch
+ from transformers import SamModel, SamProcessor
+ from diffusers import StableDiffusionInpaintPipeline
+
+ # Constants
+ IMG_SIZE = 512

+ def generate_mask(image, points):
      """
+     Generates a mask using SAM based on input points.
      """
+     if not points:
+         return None
+
+     # Initialize SAM model and processor on CPU
+     sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
+     sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+     # SamProcessor takes the clicks via `input_points`, nested per image
+     inputs = sam_processor(image, input_points=[points], return_tensors="pt").to("cpu")
+
+     with torch.no_grad():
+         outputs = sam_model(**inputs)
+
+     masks = sam_processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(),
+         inputs["original_sizes"].cpu(),
+         inputs["reshaped_input_sizes"].cpu()
+     )
+
+     if len(masks) == 0:
+         return None
+
+     best_mask = masks[0][0][outputs.iou_scores.argmax()]
+     # Keep the mask boolean and invert it, so the selected object is False and
+     # the background True; the "Invert Mask" checkbox flips this before inpainting
+     binary_mask = ~best_mask.numpy().astype(bool)
      return binary_mask

+
+ def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
+     """
+     Replaces the object in the image based on the mask and prompt.
+     """
+     if mask is None:
+         return image
+
+     # Initialize Inpainting pipeline on CPU with a compatible model
+     inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-inpainting",
+         torch_dtype=torch.float32
+     ).to("cpu")
+
+     mask_image = Image.fromarray((mask * 255).astype(np.uint8))
+
+     generator = torch.Generator("cpu").manual_seed(int(seed))  # gr.Number yields a float; manual_seed needs an int
+
+     try:
+         result = inpaint_pipeline(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             negative_prompt=negative_prompt if negative_prompt else None,
+             generator=generator,
+             guidance_scale=guidance_scale
+         ).images[0]
+         return result
+     except Exception as e:
+         print(f"Inpainting error: {e}")
+         return image
+
+
+ def visualize_mask(image, mask):
      """
+     Overlays the mask on the image for visualization.
      """
+     if mask is None:
+         return image
+
+     bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
+     bg_transparent[mask == 1] = [0, 255, 0, 127]  # Green with transparency
+     mask_rgba = Image.fromarray(bg_transparent)
+
+     overlay = Image.alpha_composite(image.convert("RGBA"), mask_rgba)
+     return overlay.convert("RGB")
+
+
+ def get_points(img, evt: gr.SelectData, input_points):
+     """
+     Captures points selected by the user on the image.
+     """
+     x, y = evt.index
+     input_points.append([x, y])
+
+     # Generate mask based on selected points
+     mask = generate_mask(img, input_points)
+
+     # Mark selected points with a green crossmark
+     draw = ImageDraw.Draw(img)
+     size = 10
+     for point in input_points:
+         px, py = point
+         draw.line((px - size, py, px + size, py), fill="green", width=5)
+         draw.line((px, py - size, px, py + size), fill="green", width=5)
+
+     # Visualize the mask overlay
+     masked_image = visualize_mask(img, mask)
+
+     return masked_image, input_points
+
+
+ def run_inpaint(prompt, negative_prompt, cfg, seed, invert, input_image, input_points):
+     """
+     Runs the inpainting process based on user inputs.
+     """
+     if input_image is None or len(input_points) == 0:
+         raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
+
+     mask = generate_mask(input_image, input_points)
+
+     if invert:
+         mask = ~mask  # flip the boolean mask
+
+     try:
+         inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
+     except Exception as e:
+         raise gr.Error(str(e))
+
+     return inpainted.resize((IMG_SIZE, IMG_SIZE))
+
+
+ def preprocess(input_img):
+     """
+     Preprocesses the uploaded image to ensure it is square and resized.
+     """
+     if input_img is None:
+         return None
+
+     width, height = input_img.size
+     if width != height:
+         # Add white padding to make the image square
+         new_size = max(width, height)
+         new_image = Image.new("RGB", (new_size, new_size), 'white')
+         left = (new_size - width) // 2
+         top = (new_size - height) // 2
+         new_image.paste(input_img, (left, top))
+         input_img = new_image
+
+     return input_img.resize((IMG_SIZE, IMG_SIZE))
+
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Object Replacement with SAM and Stable Diffusion Inpainting")
+     gr.Markdown("Upload an image, click on the object you want to replace, and generate a new image.")
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Upload Image", type="pil")
+             output_image = gr.Image(label="Generated Image", type="pil")
+             input_points = gr.State([])  # Store selected points
+
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt for Inpainting")
+             negative_prompt = gr.Textbox(label="Negative Prompt (Optional)")
+             cfg = gr.Slider(1, 20, value=7.5, label="Guidance Scale")
+             seed = gr.Number(value=42, label="Seed")
+             invert = gr.Checkbox(label="Invert Mask")
+             run_button = gr.Button("Run Inpainting")
+             reset_button = gr.Button("Reset Points")
+
+     input_image.select(get_points, inputs=[input_image, input_points], outputs=[output_image, input_points])
+     run_button.click(
+         run_inpaint,
+         inputs=[prompt, negative_prompt, cfg, seed, invert, input_image, input_points],
+         outputs=output_image
+     )
+     reset_button.click(lambda: (None, []), outputs=[output_image, input_points])

+ demo.launch()
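
For quick verification, the new SAM masking path can be exercised outside the Gradio UI. This is a minimal sketch that mirrors generate_mask under the same assumptions as the commit (CPU inference, facebook/sam-vit-huge via transformers); "example.jpg" and the click coordinates are placeholders, not part of the commit:

# Standalone sketch of the SAM masking step (not part of the commit).
# Assumes torch, transformers and pillow are installed; "example.jpg" is a
# placeholder image path and [256, 256] a placeholder click.
from PIL import Image
import torch
from transformers import SamModel, SamProcessor

image = Image.open("example.jpg").convert("RGB")
points = [[256, 256]]  # one foreground click, as get_points would collect

model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

inputs = processor(image, input_points=[points], return_tensors="pt").to("cpu")
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
best = masks[0][0][outputs.iou_scores.argmax()]
print(best.shape, int(best.sum()))  # mask resolution and number of selected pixels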