import os
import requests
from tqdm import tqdm


def download_file(url, output_path):
    """Stream a file to disk with a tqdm progress bar."""
    # Skip files that already exist so Streamlit reruns don't
    # re-download multi-gigabyte checkpoints.
    if os.path.exists(output_path):
        print(f"✅ Already present: {output_path}")
        return
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    t = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(output_path, 'wb') as f:
        for data in response.iter_content(block_size):
            t.update(len(data))
            f.write(data)
    t.close()
    if total_size != 0 and t.n != total_size:
        print("ERROR: Something went wrong in download")
    else:
        print(f"✅ Downloaded: {output_path}")


# Example: Download SAM checkpoint
sam_checkpoint_url = "https://huggingface.co/camenduru/segment_anything/resolve/main/sam_vit_h_4b8939.pth"
download_file(sam_checkpoint_url, "sam_vit_h_4b8939.pth")

# Example: Download Stable Diffusion pipeline metadata.
# Note: load_pipeline() below fetches the full pipeline from the Hugging
# Face Hub cache anyway, so this single file is not strictly required.
sd_checkpoint_url = "https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/model_index.json"
download_file(sd_checkpoint_url, "model_index.json")

###########################

import zipfile
import io
import streamlit as st


def download_sam_repo():
    """Fetch and unpack the Segment Anything repo if it is not already here."""
    repo_url = "https://github.com/facebookresearch/segment-anything/archive/refs/heads/main.zip"
    repo_dir = "segment-anything-main"
    if not os.path.exists(repo_dir):
        st.info("🔽 Downloading Segment Anything repo...")
        response = requests.get(repo_url)
        if response.status_code == 200:
            with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
                zip_ref.extractall(".")
            st.success("✅ Segment Anything repo downloaded and extracted!")
        else:
            st.error(f"❌ Failed to download repo: {response.status_code}")
    else:
        st.info("✅ Segment Anything repo already exists.")


# Call it at app start
download_sam_repo()

import sys
sys.path.append(os.path.abspath("segment-anything-main"))
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

##########################

# SAM itself is instantiated inside load_sam() below, behind
# st.cache_resource, so the ~2.4 GB checkpoint is loaded once per
# session instead of on every Streamlit rerun.

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
import copy

# ===========================
# Initialize SAM & Diffusion
# ===========================
@st.cache_resource
def load_sam():
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.99,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # drop tiny regions (area in pixels)
    )
    return mask_generator


@st.cache_resource
def load_pipeline():
    model_dir = 'stabilityai/stable-diffusion-2-inpainting'
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 weights only make sense on GPU; CPU inference needs fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder='scheduler')
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_dir,
        scheduler=scheduler,
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)
    if device == "cuda":
        try:
            # Optional speed/memory win; requires the xformers package.
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
    return pipe


# ===================
# Helper Functions
# ===================
def create_image_grid(original_image, images, names, rows, columns):
    """Render the original image plus generated variants in a grid."""
    images = copy.copy(images)
    names = copy.copy(names)
    images.insert(0, original_image)
    names.insert(0, "Original")
    fig, axes = plt.subplots(rows, columns, figsize=(15, 15))
    for idx, (img, name) in enumerate(zip(images, names)):
        row, col = divmod(idx, columns)
        axes[row, col].imshow(img)
        axes[row, col].set_title(name)
        axes[row, col].axis('off')
    # Hide any unused cells in the grid.
    for idx in range(len(images), rows * columns):
        row, col = divmod(idx, columns)
        axes[row, col].axis('off')
    plt.tight_layout()
    st.pyplot(fig)
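
# -----------------------------------------------------------------
# Sketch (illustrative only, not called by the app): SAM returns each
# mask as a boolean numpy array, while the Stable Diffusion inpainting
# pipeline expects a PIL image in which white marks the region to
# repaint. This hypothetical helper spells out the conversion the UI
# code below performs inline.
# -----------------------------------------------------------------
def bool_mask_to_pil(bool_mask: np.ndarray) -> Image.Image:
    # PIL cannot build an image directly from a bool (or int64) array,
    # so cast to uint8 first: 255 = inpaint, 0 = keep.
    return Image.fromarray((bool_mask * 255).astype(np.uint8)).convert("RGB")
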
# ===================
# Streamlit UI
# ===================
st.title("🎨 Segment & Inpaint Anything App")

uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])

if uploaded_file:
    source_image = Image.open(uploaded_file).convert("RGB")
    st.image(source_image, caption="Uploaded Image", use_column_width=True)

    # Run SAM automatic mask generation on the uploaded image.
    mask_generator = load_sam()
    segmentation_image = np.asarray(source_image)
    masks = mask_generator.generate(segmentation_image)
    st.write(f"Number of segments detected: {len(masks)}")

    # Overlay every mask in a random color and label it with its index
    # at the centroid of its first contour.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(source_image)
    for i, mask in enumerate(masks):
        m = mask['segmentation']
        color = np.random.random((1, 3)).tolist()[0]
        img = np.ones((m.shape[0], m.shape[1], 3))
        for j in range(3):
            img[:, :, j] = color[j]
        ax.imshow(np.dstack((img, m * 0.35)))  # RGBA overlay, alpha 0.35
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            cnt = contours[0]
            M = cv2.moments(cnt)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
                ax.text(cx, cy, str(i), color='white', fontsize=16,
                        ha='center', va='center', fontweight='bold')
    ax.axis('off')
    st.pyplot(fig)

    # Let the user pick a mask and a prompt.
    mask_index = st.number_input(f"Choose Mask Index (0 to {len(masks)-1})",
                                 min_value=0, max_value=len(masks)-1, value=0)
    inpainting_prompt = st.text_input("Enter your Inpainting Prompt", "a skirt full of text")
    generate_btn = st.button("Generate Inpainting")

    if generate_btn:
        selected_mask = masks[mask_index]['segmentation']
        # The boolean mask must be cast to uint8 before PIL can read it
        # (white = inpaint, black = keep).
        stable_diffusion_mask = Image.fromarray((selected_mask * 255).astype(np.uint8)).convert("RGB")
        pipe = load_pipeline()
        generator = torch.Generator(device="cpu").manual_seed(77)
        num_images = 4
        images = []
        for _ in range(num_images):
            image = pipe(
                prompt=inpainting_prompt,
                guidance_scale=7.5,
                num_inference_steps=60,
                generator=generator,
                image=source_image,
                mask_image=stable_diffusion_mask
            ).images[0]
            images.append(image)

        # Show the original next to the four generated variants.
        create_image_grid(source_image, images, [inpainting_prompt] * num_images, 2, 3)
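
# -----------------------------------------------------------------
# Usage (a possible setup; the package list is an assumption based on
# the imports above, and the filename app.py is hypothetical):
#
#   pip install streamlit torch diffusers transformers accelerate \
#       opencv-python matplotlib pillow tqdm requests
#   streamlit run app.py
#
# The vit_h checkpoint is roughly 2.4 GB, and automatic mask generation
# plus four 60-step inpainting passes is slow on CPU; a CUDA GPU is
# strongly recommended.
# -----------------------------------------------------------------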