import requests
from tqdm import tqdm

def download_file(url, output_path):
    """Stream a file from `url` to `output_path` with a tqdm progress bar."""
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    t = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(output_path, 'wb') as f:
        for data in response.iter_content(block_size):
            t.update(len(data))
            f.write(data)
    t.close()
    if total_size != 0 and t.n != total_size:
        print("ERROR: Something went wrong in download")
    else:
        print(f"✅ Downloaded: {output_path}")

# Example: Download SAM checkpoint
sam_checkpoint_url = "https://huggingface.co/camenduru/segment_anything/resolve/main/sam_vit_h_4b8939.pth"
download_file(sam_checkpoint_url, "sam_vit_h_4b8939.pth")

# Example: Download the Stable Diffusion inpainting model index
# (the full pipeline weights are fetched later via from_pretrained)
sd_checkpoint_url = "https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/model_index.json"
download_file(sd_checkpoint_url, "model_index.json")
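
# A minimal convenience sketch (my addition, not part of the original snippet): skip the
# download when the file already exists on disk, so app restarts don't re-fetch the
# large checkpoint.
import os

def download_if_missing(url, output_path):
    if os.path.exists(output_path):
        print(f"Already present, skipping: {output_path}")
    else:
        download_file(url, output_path)

# e.g. download_if_missing(sam_checkpoint_url, "sam_vit_h_4b8939.pth")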
###########################

import os
import requests
import zipfile
import io
import streamlit as st

def download_sam_repo():
    """Download and extract the Segment Anything GitHub repo if it is not already present."""
    repo_url = "https://github.com/facebookresearch/segment-anything/archive/refs/heads/main.zip"
    repo_dir = "segment-anything-main"
    if not os.path.exists(repo_dir):
        st.info("🔽 Downloading Segment Anything repo...")
        response = requests.get(repo_url)
        if response.status_code == 200:
            with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
                zip_ref.extractall(".")
            st.success("✅ Segment Anything repo downloaded and extracted!")
        else:
            st.error(f"❌ Failed to download repo: {response.status_code}")
    else:
        st.info("✅ Segment Anything repo already exists.")

# Call it at app start
download_sam_repo()

import sys
sys.path.append(os.path.abspath("segment-anything-main"))
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
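
# Note (alternative, not used by this app): the same import works without vendoring the
# zip if the package is installed directly, e.g.
#   pip install git+https://github.com/facebookresearch/segment-anything.git
# in which case the sys.path manipulation above can be dropped.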
##########################

# NOTE: the SAM model used by the app is built inside load_sam() below; loading the
# large vit_h checkpoint here as well would keep two copies in memory.
#sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#sam.to(device="cpu")
import streamlit as st
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
import copy

# ===========================
# Initialize SAM & Diffusion
# ===========================
@st.cache_resource  # cache the heavy SAM model so Streamlit reruns don't reload it
def load_sam():
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.99,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,
    )
    return mask_generator
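
# Illustrative helper (my addition): each dict returned by SamAutomaticMaskGenerator.generate()
# carries the boolean 'segmentation' array plus metadata such as 'area', 'bbox' (XYWH),
# 'predicted_iou' and 'stability_score', which is handy for debugging mask quality.
def summarize_masks(masks):
    for i, m in enumerate(masks):
        print(f"mask {i}: area={m['area']}, bbox={m['bbox']}, "
              f"predicted_iou={m['predicted_iou']:.3f}, stability_score={m['stability_score']:.3f}")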
@st.cache_resource  # cache the pipeline so Streamlit reruns don't reload it
def load_pipeline():
    model_dir = 'stabilityai/stable-diffusion-2-inpainting'
    scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder='scheduler')
    use_cuda = torch.cuda.is_available()
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_dir,
        scheduler=scheduler,
        # fp16 only makes sense on GPU; fall back to fp32 on CPU
        # (recent diffusers use variant="fp16" instead of the old revision="fp16" flag)
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        pipe = pipe.to('cuda')
        # xformers attention requires CUDA (and the xformers package)
        pipe.enable_xformers_memory_efficient_attention()
    return pipe
# ===================
# Helper Functions
# ===================
def create_image_grid(original_image, images, names, rows, columns):
    """Show the original image plus the generated images in a rows x columns grid."""
    images = copy.copy(images)
    names = copy.copy(names)
    images.insert(0, original_image)
    names.insert(0, "Original")
    fig, axes = plt.subplots(rows, columns, figsize=(15, 15))
    axes = np.asarray(axes).reshape(rows, columns)  # also handles single-row grids
    for idx, (img, name) in enumerate(zip(images, names)):
        row, col = divmod(idx, columns)
        axes[row, col].imshow(img)
        axes[row, col].set_title(name)
        axes[row, col].axis('off')
    # hide any unused cells
    for idx in range(len(images), rows * columns):
        row, col = divmod(idx, columns)
        axes[row, col].axis('off')
    plt.tight_layout()
    st.pyplot(fig)
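
# Illustrative call (hypothetical variables): the grid shows the original plus the generated
# images, so rows * columns must be at least len(images) + 1, as in the 2 x 3 grid used below.
# create_image_grid(source_image, generated_images, ["my prompt"] * len(generated_images), 2, 3)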
# ===================
# Streamlit UI
# ===================
st.title("🎨 Segment & Inpaint Anything App")

uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])

if uploaded_file:
    source_image = Image.open(uploaded_file).convert("RGB")
    st.image(source_image, caption="Uploaded Image", use_column_width=True)

    # SAM
    mask_generator = load_sam()
    segmentation_image = np.asarray(source_image)
    masks = mask_generator.generate(segmentation_image)
    if not masks:
        st.warning("No segments detected; try another image or lower the thresholds.")
        st.stop()
    st.write(f"Number of segments detected: {len(masks)}")

    # Show masks
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(source_image)
    for i, mask in enumerate(masks):
        m = mask['segmentation']
        # overlay each mask in a random color and label it with its index
        color = np.random.random((1, 3)).tolist()[0]
        img = np.ones((m.shape[0], m.shape[1], 3))
        for j in range(3):
            img[:, :, j] = color[j]
        ax.imshow(np.dstack((img, m * 0.35)))
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            cnt = contours[0]
            M = cv2.moments(cnt)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
                ax.text(cx, cy, str(i), color='white', fontsize=16, ha='center', va='center', fontweight='bold')
    ax.axis('off')
    st.pyplot(fig)

    # Ask user to choose
    mask_index = st.number_input(f"Choose Mask Index (0 to {len(masks)-1})", min_value=0, max_value=len(masks)-1, value=0)
    inpainting_prompt = st.text_input("Enter your Inpainting Prompt", "a skirt full of text")
    generate_btn = st.button("Generate Inpainting")

    if generate_btn:
        selected_mask = masks[mask_index]['segmentation']
        # PIL cannot build an image from a boolean array promoted to int64, so cast to uint8 first
        stable_diffusion_mask = Image.fromarray((selected_mask * 255).astype(np.uint8)).convert("RGB")

        pipe = load_pipeline()
        generator = torch.Generator(device="cpu").manual_seed(77)

        num_images = 4
        images = []
        for _ in range(num_images):
            image = pipe(
                prompt=inpainting_prompt,
                guidance_scale=7.5,
                num_inference_steps=60,
                generator=generator,
                image=source_image,
                mask_image=stable_diffusion_mask
            ).images[0]
            images.append(image)

        # Show Grid
        create_image_grid(source_image, images, [inpainting_prompt]*num_images, 2, 3)
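
# To run the app locally (assuming this script is saved as app.py):
#   streamlit run app.py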