# fashion_app / app0.py
# Author: drkareemkamal
# Commit 90d6ba5 (verified): Rename app.py to app0.py
import requests
from tqdm import tqdm
def download_file(url, output_path, block_size=1024, timeout=60):
    """Stream ``url`` to ``output_path`` while showing a tqdm progress bar.

    The payload is first written to ``<output_path>.part``. It is moved to
    ``output_path`` only when the received byte count matches the server's
    ``Content-Length`` header (or when no length was sent). An interrupted or
    truncated download therefore never leaves a half-written file under the
    final name; on a size mismatch the ``.part`` file is left for inspection.

    Args:
        url: HTTP(S) URL to fetch.
        output_path: Destination file path.
        block_size: Streaming chunk size in bytes (default 1024 = 1 KiB).
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
    """
    import os

    part_path = f"{output_path}.part"
    # The timeout keeps a stalled connection from hanging the app forever.
    # Using the response as a context manager releases the connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this, an HTML error page (e.g. a 404) would be saved as the file.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with tqdm(total=total_size, unit='iB', unit_scale=True) as t, open(part_path, 'wb') as f:
            for data in response.iter_content(block_size):
                t.update(len(data))
                f.write(data)
            received = t.n
    if total_size != 0 and received != total_size:
        print("ERROR: Something went wrong in download")
    else:
        # Atomic rename: readers see either no file or the complete file.
        os.replace(part_path, output_path)
        print(f"✅ Downloaded: {output_path}")
# Fetch the files the app needs at startup. Streamlit re-executes this whole
# script on every widget interaction, so files already on disk are skipped.
# Without this, the ~2.4 GB SAM checkpoint would be downloaded again on each rerun.
import os

# Example: Download SAM checkpoint
sam_checkpoint_url = "https://huggingface.co/camenduru/segment_anything/resolve/main/sam_vit_h_4b8939.pth"
if not os.path.exists("sam_vit_h_4b8939.pth"):
    download_file(sam_checkpoint_url, "sam_vit_h_4b8939.pth")
# Example: Download Stable Diffusion checkpoint.
# NOTE: this fetches only the pipeline's model_index.json; load_pipeline()
# loads the actual model by its Hub id.
sd_checkpoint_url = "https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/model_index.json"
if not os.path.exists("model_index.json"):
    download_file(sd_checkpoint_url, "model_index.json")
###########################
import os
import requests
import zipfile
import io
import streamlit as st
def download_sam_repo(timeout=60):
    """Download and unpack the Segment Anything source archive if it is missing.

    The GitHub ``main`` branch zip is extracted into the current working
    directory. Its top-level folder is expected to be
    ``segment-anything-main``, which is also the directory whose presence
    skips the download. Progress and failures are reported through Streamlit
    status messages rather than exceptions.

    Args:
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Returns:
        True if the repo directory already existed or was extracted
        successfully, False if downloading or extracting failed.
    """
    repo_url = "https://github.com/facebookresearch/segment-anything/archive/refs/heads/main.zip"
    repo_dir = "segment-anything-main"
    if os.path.exists(repo_dir):
        st.info("✅ Segment Anything repo already exists.")
        return True

    st.info("🔽 Downloading Segment Anything repo...")
    try:
        # Without a timeout, a stalled connection would block app start-up forever.
        response = requests.get(repo_url, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"❌ Failed to download repo: {err}")
        return False
    if response.status_code != 200:
        st.error(f"❌ Failed to download repo: {response.status_code}")
        return False

    try:
        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            zip_ref.extractall(".")
    except zipfile.BadZipFile as err:
        st.error(f"❌ Failed to extract repo: {err}")
        return False
    st.success("✅ Segment Anything repo downloaded and extracted!")
    return True
# Call it at app start
download_sam_repo()
import sys
# The segment_anything package is imported straight from the unpacked source
# tree, so that tree must be on the import path. Streamlit re-executes this
# script on every interaction. The membership check stops sys.path from
# gaining a duplicate entry on each rerun.
_sam_repo_path = os.path.abspath("segment-anything-main")
if _sam_repo_path not in sys.path:
    sys.path.append(_sam_repo_path)
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
##########################
# The SAM model itself is built and cached inside load_sam(). Building it
# eagerly here would re-read the ViT-H checkpoint on every rerun for an
# object nothing uses.
import streamlit as st
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
import copy
# ===========================
# Initialize SAM & Diffusion
# ===========================
@st.cache_resource
def load_sam():
    """Return a cached automatic mask generator backed by SAM ViT-H on the CPU.

    ``st.cache_resource`` memoizes the result, so the checkpoint
    ``sam_vit_h_4b8939.pth`` is read once and the same generator object is
    handed back on later reruns.
    """
    # Automatic-mask-generation settings, forwarded verbatim to the generator.
    generator_options = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.99,
        "stability_score_thresh": 0.92,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        "min_mask_region_area": 100,
    }
    backbone = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    backbone.to(device="cpu")
    return SamAutomaticMaskGenerator(model=backbone, **generator_options)
@st.cache_resource
def load_pipeline():
    """Return a cached Stable Diffusion 2 inpainting pipeline with an Euler scheduler.

    The device is chosen at load time: CUDA with float16 weights when a GPU
    is available, otherwise CPU with float32 weights. Half precision is not
    usable for CPU inference, so the previous float16-on-CPU setup could
    not run. xformers memory-efficient attention is enabled only on CUDA.
    If it cannot be enabled (xformers missing or unusable), a warning is
    shown and the default attention is kept.
    """
    model_dir = 'stabilityai/stable-diffusion-2-inpainting'
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder='scheduler')
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_dir,
        scheduler=scheduler,
        torch_dtype=dtype,
        # NOTE(review): newer diffusers releases deprecate revision="fp16" in
        # favour of variant="fp16"; confirm against the installed version.
        revision="fp16",
    )
    pipe = pipe.to(device)
    if use_cuda:
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError) as err:
            st.warning(f"xformers attention unavailable, using default attention: {err}")
    return pipe
# ===================
# Helper Functions
# ===================
def create_image_grid(original_image, images, names, rows, columns):
    """Render ``original_image`` followed by ``images`` as a titled grid in Streamlit.

    The first cell shows ``original_image`` titled "Original". The following
    cells pair ``images`` with ``names`` by position; if the lists differ in
    length, the extras in the longer one are ignored. Unused cells are
    blanked. The caller's lists are not modified.

    Args:
        original_image: Image drawn in the first cell.
        images: Images drawn after the original.
        names: Titles for ``images``.
        rows: Requested number of grid rows. It is increased when the grid
            would otherwise be too small for every image.
        columns: Number of grid columns.
    """
    import math

    cells = list(zip([original_image, *images], ["Original", *names]))
    # Grow the grid rather than failing with IndexError when it is too small.
    rows = max(rows, math.ceil(len(cells) / columns))
    # squeeze=False keeps ``axes`` 2-D even for one row or one column,
    # so ``axes[row, col]`` indexing always works.
    fig, axes = plt.subplots(rows, columns, figsize=(15, 15), squeeze=False)
    for idx in range(rows * columns):
        row, col = divmod(idx, columns)
        ax = axes[row, col]
        if idx < len(cells):
            img, name = cells[idx]
            ax.imshow(img)
            ax.set_title(name)
        ax.axis('off')
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction. Closing the figure
    # keeps pyplot from accumulating open figures across reruns.
    plt.close(fig)
# ===================
# Streamlit UI
# ===================
import hashlib


def _segment_image(image_array, image_bytes):
    """Return SAM masks for ``image_array``, reusing this session's last result.

    Streamlit re-executes the script on every widget change, such as typing a
    prompt or clicking the button. Keying the stored masks on a hash of the
    uploaded bytes means SAM runs only when a different image is uploaded.
    """
    image_key = hashlib.sha256(image_bytes).hexdigest()
    if st.session_state.get("sam_masks_key") != image_key:
        st.session_state["sam_masks"] = load_sam().generate(image_array)
        st.session_state["sam_masks_key"] = image_key
    return st.session_state["sam_masks"]


def _plot_masks(image, masks):
    """Overlay every mask on ``image`` in a translucent colour and label it with its index."""
    # Fixed seed: each mask keeps the same colour across reruns.
    rng = np.random.default_rng(0)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for i, mask in enumerate(masks):
        m = mask['segmentation']
        overlay = np.empty((m.shape[0], m.shape[1], 4))
        overlay[:, :, :3] = rng.random(3)
        overlay[:, :, 3] = m * 0.35
        ax.imshow(overlay)
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            # Put the label on the largest piece of the mask, not an arbitrary fragment.
            cnt = max(contours, key=cv2.contourArea)
            M = cv2.moments(cnt)
            if M["m00"] != 0:
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
                ax.text(cx, cy, str(i), color='white', fontsize=16, ha='center', va='center', fontweight='bold')
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)  # avoid accumulating open figures across reruns


st.title("🎨 Segment & Inpaint Anything App")
uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
if uploaded_file:
    source_image = Image.open(uploaded_file).convert("RGB")
    st.image(source_image, caption="Uploaded Image", use_column_width=True)
    # SAM
    segmentation_image = np.asarray(source_image)
    with st.spinner("Segmenting image..."):
        masks = _segment_image(segmentation_image, uploaded_file.getvalue())
    st.write(f"Number of segments detected: {len(masks)}")
    if not masks:
        # number_input below would raise with max_value=-1.
        st.warning("No segments were detected in this image; try another image.")
        st.stop()
    # Show masks
    _plot_masks(source_image, masks)
    # Ask user to choose
    mask_index = st.number_input(f"Choose Mask Index (0 to {len(masks)-1})", min_value=0, max_value=len(masks)-1, value=0)
    inpainting_prompt = st.text_input("Enter your Inpainting Prompt", "a skirt full of text")
    generate_btn = st.button("Generate Inpainting")
    if generate_btn:
        selected_mask = masks[int(mask_index)]['segmentation']
        # A boolean mask times 255 becomes a default-int (typically int64)
        # array, which Image.fromarray cannot map to a mode. uint8 gives an
        # 8-bit grayscale image instead.
        stable_diffusion_mask = Image.fromarray(selected_mask.astype(np.uint8) * 255).convert("RGB")
        pipe = load_pipeline()
        # One seeded generator shared by all calls: the results are
        # reproducible, yet each of the images differs.
        generator = torch.Generator(device="cpu").manual_seed(77)
        num_images = 4
        images = []
        with st.spinner("Generating inpainted images..."):
            for _ in range(num_images):
                image = pipe(
                    prompt=inpainting_prompt,
                    guidance_scale=7.5,
                    num_inference_steps=60,
                    generator=generator,
                    image=source_image,
                    mask_image=stable_diffusion_mask
                ).images[0]
                images.append(image)
        # Show Grid
        create_image_grid(source_image, images, [inpainting_prompt]*num_images, 2, 3)