drkareemkamal commited on
Commit
326536e
Β·
verified Β·
1 Parent(s): 90d6ba5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import zipfile
4
+ import io
5
+ import os
6
+ import sys
7
+ import torch
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import cv2
11
+ from PIL import Image
12
+ from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
13
+ import copy
14
+
15
+ # ========== Download SAM Repo ==========
16
+ def download_sam_repo():
17
+ repo_url = "https://github.com/facebookresearch/segment-anything/archive/refs/heads/main.zip"
18
+ repo_dir = "segment-anything-main"
19
+ if not os.path.exists(repo_dir):
20
+ st.info("πŸ”½ Downloading Segment Anything repo...")
21
+ response = requests.get(repo_url)
22
+ if response.status_code == 200:
23
+ with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
24
+ zip_ref.extractall(".")
25
+ st.success("βœ… Segment Anything repo downloaded!")
26
+ else:
27
+ st.error(f"❌ Failed to download repo: {response.status_code}")
28
+
29
+ download_sam_repo()
30
+ sys.path.append(os.path.abspath("segment-anything-main"))
31
+
32
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
33
+
34
+ # ========== Download SAM Model Checkpoint ==========
35
+ def download_file(url, output_path):
36
+ if not os.path.exists(output_path):
37
+ st.info(f"πŸ”½ Downloading {os.path.basename(output_path)}...")
38
+ response = requests.get(url, stream=True)
39
+ with open(output_path, 'wb') as f:
40
+ for chunk in response.iter_content(chunk_size=8192):
41
+ if chunk:
42
+ f.write(chunk)
43
+ st.success(f"βœ… Downloaded {os.path.basename(output_path)}")
44
+
45
+ sam_url = "https://huggingface.co/camenduru/segment_anything/resolve/main/sam_vit_h_4b8939.pth"
46
+ download_file(sam_url, "sam_vit_h_4b8939.pth")
47
+
48
+ # ========== Load SAM Model ==========
49
+ @st.cache_resource
50
+ def load_sam_model():
51
+ sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
52
+ sam.to("cuda")
53
+ mask_generator = SamAutomaticMaskGenerator(
54
+ model=sam,
55
+ points_per_side=32,
56
+ pred_iou_thresh=0.99,
57
+ stability_score_thresh=0.92,
58
+ crop_n_layers=1,
59
+ crop_n_points_downscale_factor=2,
60
+ min_mask_region_area=100,
61
+ )
62
+ return mask_generator
63
+
64
+ # ========== Load SD Model ==========
65
+ @st.cache_resource
66
+ def load_sd_pipeline():
67
+ model_dir = 'stabilityai/stable-diffusion-2-inpainting'
68
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder='scheduler')
69
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
70
+ model_dir,
71
+ scheduler=scheduler,
72
+ torch_dtype=torch.float16,
73
+ revision="fp16"
74
+ ).to("cuda")
75
+ pipe.enable_xformers_memory_efficient_attention()
76
+ return pipe
77
+
78
+ # ========== Display masks ==========
79
+ def show_masks(image, masks):
80
+ fig, ax = plt.subplots(figsize=(10, 10))
81
+ ax.imshow(image)
82
+ for i, mask in enumerate(masks):
83
+ m = mask['segmentation']
84
+ color = np.random.random((1, 3)).tolist()[0]
85
+ img = np.ones((m.shape[0], m.shape[1], 3))
86
+ for j in range(3):
87
+ img[:, :, j] = color[j]
88
+ ax.imshow(np.dstack((img, m * 0.35)))
89
+ contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
90
+ if contours:
91
+ cnt = contours[0]
92
+ M = cv2.moments(cnt)
93
+ if M["m00"] != 0:
94
+ cx = int(M["m10"] / M["m00"])
95
+ cy = int(M["m01"] / M["m00"])
96
+ ax.text(cx, cy, str(i), color='white', fontsize=16, ha='center', va='center', fontweight='bold')
97
+ ax.axis('off')
98
+ st.pyplot(fig)
99
+
100
+ # ========== Image Grid ==========
101
+ def create_image_grid(original_image, images, names, rows, columns):
102
+ images = copy.copy(images)
103
+ names = copy.copy(names)
104
+
105
+ images.insert(0, original_image)
106
+ names.insert(0, "Original")
107
+
108
+ fig, axes = plt.subplots(rows, columns, figsize=(15, 15))
109
+ for idx, (img, name) in enumerate(zip(images, names)):
110
+ row, col = divmod(idx, columns)
111
+ axes[row, col].imshow(img)
112
+ axes[row, col].set_title(name)
113
+ axes[row, col].axis('off')
114
+ for idx in range(len(images), rows * columns):
115
+ row, col = divmod(idx, columns)
116
+ axes[row, col].axis('off')
117
+ plt.tight_layout()
118
+ st.pyplot(fig)
119
+
120
+ # ========== Streamlit UI ==========
121
+ st.title("🎨 Segment & Inpaint with Streamlit")
122
+
123
+ uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
124
+
125
+ if uploaded_file:
126
+ source_image = Image.open(uploaded_file).convert("RGB").resize((512, 512))
127
+ st.image(source_image, caption="Uploaded Image", use_column_width=True)
128
+
129
+ mask_gen = load_sam_model()
130
+ masks = mask_gen.generate(np.asarray(source_image))
131
+ st.write(f"Number of Segments Found: {len(masks)}")
132
+
133
+ show_masks(source_image, masks)
134
+
135
+ mask_idx = st.number_input(f"Select Mask Index (0 to {len(masks)-1})", min_value=0, max_value=len(masks)-1, value=0)
136
+ prompt = st.text_input("Enter Inpainting Prompt", "a skirt full of text")
137
+ generate = st.button("Generate Inpainting")
138
+
139
+ if generate:
140
+ segmentation_mask = masks[mask_idx]['segmentation']
141
+ stable_mask = Image.fromarray(segmentation_mask * 255).convert("RGB")
142
+
143
+ pipe = load_sd_pipeline()
144
+ generator = torch.Generator(device="cuda").manual_seed(77)
145
+
146
+ images = []
147
+ for i in range(4):
148
+ result = pipe(
149
+ prompt=prompt,
150
+ guidance_scale=7.5,
151
+ num_inference_steps=50,
152
+ generator=generator,
153
+ image=source_image,
154
+ mask_image=stable_mask
155
+ ).images[0]
156
+ images.append(result)
157
+
158
+ create_image_grid(source_image, images, [prompt]*4, 2, 3)