# EEE515-HW3 / app.py
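"""Gradio app for EEE515 HW3.

Segments the foreground with BiRefNet, composites a Gaussian-blurred
background behind it (bokeh), and uses Apple DepthPro depth estimation to
apply a three-band lens blur (sharp / medium / heavy) controlled by two
depth-threshold sliders.
"""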
import gradio as gr
from PIL import Image
import torch
import cv2
import numpy as np
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BiRefNet segmentation model once at startup.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
torch.set_float32_matmul_precision("high")
birefnet.to(device)
birefnet.eval()
if device.type == "cuda":
    # Half precision is only used on GPU; it is unreliable on CPU.
    birefnet.half()
def extract_object(image, t1, t2):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Work on a fixed-size copy so the mask, blur, and depth map all share
    # the same resolution.
    imageResized = image.resize((512, 512))
    image1 = imageResized.copy()
    input_images = transform_image(image1).unsqueeze(0).to(device)
    if device.type == "cuda":
        input_images = input_images.half()

    # Prediction: BiRefNet returns a list of logit maps; the last one is the
    # final segmentation.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image1.size)
    image1.putalpha(mask)
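    # image1 now carries the soft mask in its alpha channel; the hard binary
    # mask derived below drives the bokeh compositing.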
    # Bokeh: blur the whole frame, then composite the sharp subject back on
    # top using the binarized segmentation mask.
    blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
    maskArr = np.array(mask.convert("L"))
    _, maskBinary = cv2.threshold(maskArr, 127, 255, cv2.THRESH_BINARY)
    img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
    maskInv = cv2.bitwise_not(maskBinary)
    maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
    foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))  # sharp subject (BGR)
    background = cv2.bitwise_and(blurredBg, maskInv3)             # blurred surroundings (RGB)
    bokehImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
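    # Per pixel this computes out = subject * mask + blurred * (1 - mask);
    # with a hard 0/255 mask, the bitwise AND/ADD sequence above is an exact
    # (and cheap) implementation of that alpha composite.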
    # Depth estimation with Apple DepthPro. (Loading per call keeps the
    # original structure; hoisting to module level would be faster.)
    imageProcessor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
    model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
    inputs = imageProcessor(images=imageResized, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    post_processed_output = imageProcessor.post_process_depth_estimation(
        outputs, target_sizes=[(imageResized.height, imageResized.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]
    # Normalize the metric depth to [0, 255] so the thresholds below can
    # treat it like a grayscale image (0 = nearest, 255 = farthest).
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth = depth * 255.
    depth = depth.detach().cpu().numpy()
    depthImg = Image.fromarray(depth.astype("uint8"))  # kept for debugging; not returned
    # Map the 0-99 slider values onto the normalized depth range [0, 255];
    # the defaults (33, 66) give cut points near 255/3 and 2*255/3.
    threshold1 = (t1 / 100) * 255
    threshold2 = (t2 / 100) * 255
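    # Example with the default sliders: t1=33 -> threshold1 = 84.15 and
    # t2=66 -> threshold2 = 168.3, so pixels with depth < 84 stay sharp,
    # 84-168 get the light blur, and the rest get the heavy blur.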
    # Precompute blurred versions for each depth band.
    img_foreground = img.copy()  # no blur for the foreground
    img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
    img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    # Create masks for each region (as float arrays for proper blending).
    mask_fg = (depth < threshold1).astype(np.float32)
    mask_mg = ((depth >= threshold1) & (depth < threshold2)).astype(np.float32)
    mask_bg = (depth >= threshold2).astype(np.float32)

    # Expand masks to 3 channels (H, W, 3).
    mask_fg = np.stack([mask_fg] * 3, axis=-1)
    mask_mg = np.stack([mask_mg] * 3, axis=-1)
    mask_bg = np.stack([mask_bg] * 3, axis=-1)

    # Combine the layers using the masks in a vectorized manner; the three
    # masks partition the image, so each pixel comes from exactly one layer.
    final_img = (img_foreground * mask_fg +
                 img_middleground * mask_mg +
                 img_background * mask_bg).astype(np.uint8)

    # Convert BGR back to RGB and return PIL images to match the declared
    # Gradio outputs (bokeh first, depth-based lens blur second).
    final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(bokehImg), Image.fromarray(final_img_rgb)
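
# A minimal usage sketch (hedged): run the pipeline on a local file without
# the Gradio UI. "sample.jpg" is a placeholder filename, not part of the app.
#
#   img = Image.open("sample.jpg").convert("RGB")
#   bokeh, lens_blur = extract_object(img, t1=33, t2=66)
#   bokeh.save("bokeh.png")
#   lens_blur.save("lens_blur.png")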
# Create a Gradio interface
title = "Gaussian Blur Background App"
description = (
"Upload an image to apply a realistic background blur effect. "
"The app segments the foreground using RMBG-2.0 and then applies a Gaussian "
"blur (σ=15) to the background, simulating a video conferencing blur effect."
)
iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(minimum=0, maximum=40, step=1, value=33, label="Middleground"),
        gr.Slider(minimum=40, maximum=99, step=1, value=66, label="Background"),
    ],
    outputs=[
        gr.Image(type="pil", label="Bokeh Image"),
        gr.Image(type="pil", label="Lens Blur Image"),
    ],
    title=title,
    description=description,
    allow_flagging="never",
)
iface.queue(default_concurrency_limit=1)
iface.launch()