EEE515-HW3 / app.py
Ash2505's picture
Update app.py
e7b1b3f verified
raw
history blame
6.27 kB
import cv2
import numpy as np
from PIL import Image, ImageFilter
import torch
import gradio as gr
from torchvision import transforms
from transformers import (
AutoModelForImageSegmentation,
DepthProImageProcessorFast,
DepthProForDepthEstimation,
)
# Set device: prefer CUDA when PyTorch can see a GPU, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Load Segmentation Model (RMBG-2.0 by briaai)
# -----------------------------
# RMBG-2.0 ships its own modelling code on the Hugging Face Hub, which is why
# trust_remote_code is required to instantiate it.
seg_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
# "high" lets float32 matmuls use faster reduced-precision internals (e.g. TF32).
torch.set_float32_matmul_precision("high")
seg_model.to(device).eval()

# Segmentation input geometry and the standard ImageNet mean/std normalisation.
seg_image_size = (1024, 1024)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
seg_transform = transforms.Compose(
    [
        transforms.Resize(seg_image_size),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# -----------------------------
# Load Depth Estimation Model (DepthPro by Apple)
# -----------------------------
_DEPTH_CHECKPOINT = "apple/DepthPro-hf"
depth_processor = DepthProImageProcessorFast.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = DepthProForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT)
depth_model.to(device).eval()
# -----------------------------
# Define the Segmentation-Based Blur Effect
# -----------------------------
def segmentation_blur_effect(
    input_image: Image.Image,
    blur_sigma: float = 15,
    mask_threshold: int = 127,
):
    """Blur the background of an image using the RMBG-2.0 foreground mask.

    The image is resized to ``seg_image_size`` only for inference. The
    predicted mask is mapped back to the input's resolution, so the output
    keeps the original size and aspect ratio. The previous version returned a
    stretched 1024x1024 image because it resized the mask to the
    already-1024x1024 inference size.

    Args:
        input_image: Source image in any PIL mode; it is converted to RGB.
        blur_sigma: Gaussian sigma, in pixels of the input image, used for the
            blurred background.
        mask_threshold: 0-255 cut-off. Mask pixels strictly above it are
            treated as foreground and kept sharp.

    Returns:
        ``(blurred_image, mask)``: the composited RGB image and the soft
        8-bit foreground mask (mode "L"), both at the input image's size.
    """
    # Normalize needs exactly 3 channels; RGBA/L inputs would otherwise fail.
    rgb_image = input_image.convert("RGB")

    # seg_transform resizes to seg_image_size and applies mean/std normalisation.
    input_tensor = seg_transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns a sequence of predictions; the last one is used,
        # following the RMBG-2.0 model card example.
        preds = seg_model(input_tensor)[-1].sigmoid().cpu()

    # Probability map at model resolution -> 8-bit PIL mask at the input's size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(rgb_image.size)
    mask_np = np.array(mask.convert("L"))

    # Blurring is per-channel, so working directly in RGB needs no BGR round-trip.
    image_np = np.array(rgb_image)
    blurred_np = cv2.GaussianBlur(image_np, (0, 0), sigmaX=blur_sigma, sigmaY=blur_sigma)

    # Same rule as cv2.THRESH_BINARY: strictly above the threshold is foreground.
    is_foreground = (mask_np > mask_threshold)[..., None]
    composited = np.where(is_foreground, image_np, blurred_np)
    return Image.fromarray(composited), mask
def _to_unit_threshold(value: float) -> float:
    """Map a depth threshold onto the [0, 1] normalised-depth scale.

    Values in [0, 1] (the Gradio slider scale) are returned unchanged. Larger
    values are treated as 8-bit depth levels (0-255) and divided by 255.
    """
    value = float(value)
    return value / 255.0 if value > 1.0 else value


def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_threshold: float = 170):
    """Apply a three-layer, depth-driven "lens blur" using Apple DepthPro.

    The post-processed DepthPro depth is a depth map (metric depth per the
    Transformers docs), so smaller values are nearer. It is min-max normalised
    to [0, 1] and split into three layers:

    * foreground:   depth <  fg            -> kept sharp
    * middleground: fg <= depth < mg       -> Gaussian blur, sigma 7
    * background:   depth >= mg            -> Gaussian blur, sigma 15

    Threshold scale:
        Thresholds are fractions of the normalised depth range; the UI sliders
        use 0-1. Values above 1 are interpreted on the 0-255 scale of the depth
        image, so the defaults 85/170 mean roughly 1/3 and 2/3. Previously the
        defaults were compared against [0, 1] data and blurred nothing.

    Overlap guard:
        If ``mg_threshold`` is below ``fg_threshold``, it is raised to
        ``fg_threshold``. This keeps the layers disjoint; overlapping layers
        were summed and overflowed uint8.

    Args:
        input_image: Source image in any PIL mode; it is converted to RGB.
        fg_threshold: Upper bound (exclusive) of the sharp foreground layer.
        mg_threshold: Upper bound (exclusive) of the middleground layer.

    Returns:
        ``(depth_image, lens_blur_image, fg_mask, mg_mask, bg_mask)``. All are
        PIL images at the input's size. The depth image and the three masks
        are 8-bit mode "L"; the masks are 255 inside their layer and 0
        elsewhere.
    """
    rgb_image = input_image.convert("RGB")
    fg_level = _to_unit_threshold(fg_threshold)
    mg_level = max(_to_unit_threshold(mg_threshold), fg_level)

    inputs = depth_processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        post_processed_output = depth_processor.post_process_depth_estimation(
            outputs, target_sizes=[(rgb_image.height, rgb_image.width)]
        )
    depth = post_processed_output[0]["predicted_depth"].float().cpu().numpy()

    # Min-max normalise; a perfectly flat depth map would otherwise give 0/0 = NaN.
    depth_min = float(depth.min())
    depth_range = float(depth.max()) - depth_min
    if depth_range > 0:
        depth_unit = (depth - depth_min) / depth_range
    else:
        depth_unit = np.zeros_like(depth)
    depth_map = (depth_unit * 255.0).astype(np.uint8)
    depth_img = Image.fromarray(depth_map)
    # Threshold on the same quantised levels that the displayed depth image shows.
    depth_levels = depth_map.astype(np.float32) / 255.0

    img = np.array(rgb_image)
    # Blurring is per-channel, so channel order (RGB vs BGR) does not matter here.
    img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
    img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    mask_fg = depth_levels < fg_level
    mask_bg = depth_levels >= mg_level
    mask_mg = ~(mask_fg | mask_bg)  # fg <= depth < mg, disjoint since mg >= fg

    final_img = np.where(
        mask_fg[..., None],
        img,
        np.where(mask_mg[..., None], img_middleground, img_background),
    )
    lens_blur_img = Image.fromarray(final_img)

    mask_fg_img, mask_mg_img, mask_bg_img = [
        Image.fromarray(np.where(mask, 255, 0).astype(np.uint8))
        for mask in (mask_fg, mask_mg, mask_bg)
    ]
    return depth_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img
def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
    """Run both blur pipelines on one image for the Gradio interface.

    Args:
        input_image: Image supplied by the UI.
        fg_threshold: Foreground depth threshold forwarded to ``lens_blur_effect``.
        mg_threshold: Middleground depth threshold forwarded to ``lens_blur_effect``.

    Returns:
        ``(segmentation_blur, depth_map, lens_blur)``, in the same order as the
        interface's output components. The segmentation mask and the
        per-layer depth masks are computed but not displayed.
    """
    seg_blurred, _seg_mask = segmentation_blur_effect(input_image)
    depth_map_img, lens_blur_img, _fg_mask, _mg_mask, _bg_mask = lens_blur_effect(
        input_image, fg_threshold, mg_threshold
    )
    return seg_blurred, depth_map_img, lens_blur_img
# Example images with depth thresholds that suit them. Thresholds are on the
# normalised 0-1 depth scale used by the Gradio sliders.
PRESETS = {
    "Preset 1": {
        "image_url": "https://i.ibb.co/fznz2b2b/hw3-q2.jpg",
        "fg_threshold": 0.33,
        "mg_threshold": 0.66,
    },
    "Preset 2": {
        "image_url": "https://i.ibb.co/HLZGW7qH/q26.jpg",
        "fg_threshold": 0.2,
        "mg_threshold": 0.66,
    },
}


def update_preset(preset: str):
    """Download a preset example image and return it with its thresholds.

    The previous version called ``requests`` and ``BytesIO`` without importing
    either, so every call raised NameError. This version uses the standard
    library with a network timeout.

    NOTE(review): nothing in the visible code wires this function into the
    Gradio interface; it is currently unused.

    Args:
        preset: A key of ``PRESETS``, i.e. "Preset 1" or "Preset 2".

    Returns:
        ``(image, fg_threshold, mg_threshold)`` with the image converted to RGB.

    Raises:
        KeyError: If ``preset`` is not a known preset name.
        urllib.error.URLError: If the image cannot be downloaded; a timeout
            may surface as this or as ``TimeoutError``.
    """
    from io import BytesIO
    from urllib.request import urlopen

    preset_info = PRESETS[preset]
    with urlopen(preset_info["image_url"], timeout=30) as response:
        image_bytes = response.read()
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    return image, preset_info["fg_threshold"], preset_info["mg_threshold"]
title = (
    "Blur Effects on Segmentation-Based Gaussian Blur & Depth-Based Lens Blur "
    "with Adjustable Depth Thresholds"
)

# UI inputs: the image plus the two depth thresholds on the normalised 0-1 depth scale.
_input_components = [
    gr.Image(type="pil", label="Input Image", value="https://i.ibb.co/fznz2b2b/hw3-q2.jpg"),
    gr.Slider(minimum=0, maximum=1, step=0.01, value=0.33, label="Foreground Depth Threshold"),
    gr.Slider(minimum=0, maximum=1, step=0.01, value=0.66, label="Middleground Depth Threshold"),
]

# UI outputs, listed in the order process_image returns them.
_output_components = [
    gr.Image(type="pil", label="Segmentation-Based Blur"),
    gr.Image(type="pil", label="Depth Map"),
    gr.Image(type="pil", label="Depth-Based Lens Blur"),
]

demo = gr.Interface(
    fn=process_image,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()