# Image_Stitching / app.py
# Last commit by basab1142: "Update app.py" (d1aedac, verified)
import ast
import time

import cv2
import matplotlib.pylab as plt
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from estimate_homography import calculate_homography, fit_image_in_target_space
# Module-level defaults for the names the "Stitch Images" handler rebinds further down.
stitched_result = None
stitched_image_rgb = None
# Function to load an image from uploaded file
def load_image(uploaded_file):
    """Decode an uploaded image file into a single-channel (grayscale) array.

    Args:
        uploaded_file: File-like object holding encoded image bytes, e.g. the
            object returned by ``st.file_uploader``. ``getvalue()`` is used when
            available so the result does not depend on the stream position.

    Returns:
        numpy.ndarray: 2-D ``uint8`` grayscale image.

    Raises:
        ValueError: If the file is empty or cannot be decoded as an image.
    """
    # getvalue() returns the whole buffer even if the stream was already read;
    # fall back to read() for plain file objects.
    if hasattr(uploaded_file, "getvalue"):
        data = uploaded_file.getvalue()
    else:
        data = uploaded_file.read()
    if not data:
        raise ValueError("Uploaded file is empty.")
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals failure by returning None rather than raising.
    if img is None:
        raise ValueError("Uploaded file could not be decoded as an image.")
    return img
# Function to compute stereo vision and disparity map
def compute_stereo_vision(img1, img2, num_disparities=16, block_size=15):
    """Rectify an uncalibrated stereo pair and compute its disparity map.

    Rectification is estimated from the images alone:
    1. ORB features are matched between the two images.
    2. A fundamental matrix is fitted with RANSAC.
    3. ``cv2.stereoRectifyUncalibrated`` turns the inlier correspondences into
       a pair of rectifying homographies.

    No camera intrinsics are required.

    Args:
        img1: Left image, 8-bit single-channel (as produced by ``load_image``).
        img2: Right image, 8-bit single-channel.
        num_disparities: ``numDisparities`` passed to ``cv2.StereoBM_create``
            (must be a positive multiple of 16).
        block_size: ``blockSize`` passed to ``cv2.StereoBM_create`` (odd, >= 5).

    Returns:
        tuple: ``(disparity, img1_rectified, img2_rectified)``.
        ``disparity`` is the raw ``StereoBM.compute`` output. Both rectified
        images are warped to ``img1``'s width and height.

    Raises:
        ValueError: If too few features or matches are found, or the
            epipolar geometry cannot be estimated.
    """
    # ORB: binary descriptors, a free alternative to SIFT.
    orb = cv2.ORB_create()
    kp1, des1 = orb.detectAndCompute(img1, None)
    kp2, des2 = orb.detectAndCompute(img2, None)
    # detectAndCompute returns None descriptors when no keypoints are found.
    if des1 is None or des2 is None:
        raise ValueError("No ORB features found in one of the images.")

    # Hamming distance suits ORB's binary descriptors; crossCheck keeps only mutual best matches.
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = sorted(bf.match(des1, des2), key=lambda m: m.distance)
    # cv2.FM_RANSAC needs at least 8 correspondences.
    if len(matches) < 8:
        raise ValueError(f"Need at least 8 feature matches, found {len(matches)}.")

    pts1 = np.float32([kp1[m.queryIdx].pt for m in matches])
    pts2 = np.float32([kp2[m.trainIdx].pt for m in matches])

    F, mask = cv2.findFundamentalMat(pts1, pts2, cv2.FM_RANSAC)
    if F is None or F.shape != (3, 3):
        raise ValueError("Could not estimate the fundamental matrix.")
    # Only RANSAC inliers feed the rectification estimate.
    inliers = mask.ravel() == 1
    pts1, pts2 = pts1[inliers], pts2[inliers]

    # Uncalibrated rectification: needs only F and the correspondences. The
    # previous pipeline assumed K = I, so its "essential matrix" was just F,
    # and it discarded the stereoRectify output.
    h, w = img1.shape[:2]
    ok, H1, H2 = cv2.stereoRectifyUncalibrated(pts1, pts2, F, (w, h))
    if not ok:
        raise ValueError("Stereo rectification failed.")

    # Warp both images to img1's size so StereoBM receives equal-sized inputs.
    img1_rectified = cv2.warpPerspective(img1, H1, (w, h))
    img2_rectified = cv2.warpPerspective(img2, H2, (w, h))

    stereo = cv2.StereoBM_create(numDisparities=num_disparities, blockSize=block_size)
    disparity = stereo.compute(img1_rectified, img2_rectified)
    return disparity, img1_rectified, img2_rectified
def run_point_est(world_pts, img_pts, img):
    """Estimate a homography from point correspondences and display the warped image.

    Renders three things via Streamlit:
    - the input image with ``img_pts`` marked;
    - a cross-check of the project's homography estimate against
      ``cv2.findHomography``;
    - ``img`` warped into the world-coordinate frame.

    Args:
        world_pts: N x 2 world-plane coordinates (list or array).
        img_pts: N x 2 pixel coordinates in ``img``, matching ``world_pts`` row by row.
        img: Image shown with matplotlib, so colour images are expected in RGB order.

    Raises:
        ValueError: If the point sets are not both N x 2 with the same N, or N < 4.
        numpy.linalg.LinAlgError: If the estimated homography is singular.
    """
    # float64 so OpenCV and the homography maths get floating-point input
    # (integer coordinate lists would otherwise become int64 arrays).
    img_pts = np.asarray(img_pts, dtype=np.float64)
    world_pts = np.asarray(world_pts, dtype=np.float64)
    if img_pts.ndim != 2 or img_pts.shape[1] != 2 or img_pts.shape != world_pts.shape:
        raise ValueError("img_pts and world_pts must both be N x 2 arrays with the same N.")
    if len(img_pts) < 4:
        raise ValueError("At least four point correspondences are required.")

    # Plot the original image with marked points
    st.write("Original Image with Points")
    fig = plt.figure()
    plt.imshow(img)
    plt.scatter(img_pts[:, 0], img_pts[:, 1], color='red')
    plt.axis("off")
    plt.title("Original Image with img points marked in red")
    st.pyplot(fig)
    plt.close(fig)

    H = calculate_homography(img_pts, world_pts)  # img_pts = H * world_pts

    #### Cross check: re-project the world points through H ####
    world_h = np.hstack((world_pts, np.ones((len(world_pts), 1))))
    x = H @ world_h.T
    x = x / x[-1, :]
    st.write("Given Image Points:", img_pts)
    st.write("Calculated Image Points:", x.T)
    st.write("Homography Matrix (OpenCV):", cv2.findHomography(world_pts, img_pts)[0])
    st.write("Calculated Homography Matrix:", H)
    #####################

    # Map the four image corners into world coordinates to size the output canvas.
    h, w = img.shape[:2]
    corners_img = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float64)
    # Exactly one homogeneous 1 per corner. The old code reused the N x 1
    # column built for the picked points, which broke whenever N != 4.
    corners_h = np.hstack((corners_img, np.ones((4, 1))))
    world_crd_corners = np.linalg.inv(H) @ corners_h.T
    world_crd_corners = world_crd_corners / world_crd_corners[-1, :]  # Normalize
    min_crd = world_crd_corners.min(axis=1)
    max_crd = world_crd_corners.max(axis=1)
    offset = min_crd.astype(np.int64)
    offset[2] = 0  # no offset on the homogeneous coordinate
    extent = np.ceil(max_crd - min_crd)
    width_world = int(extent[0]) + 1
    height_world = int(extent[1]) + 1
    world_img = np.zeros((height_world, width_world, 3), dtype=np.uint8)
    mask = np.ones((height_world, width_world))
    out = fit_image_in_target_space(img, world_img, mask, H, offset)

    st.write("Corrected Image")
    fig = plt.figure()
    plt.imshow(out)
    plt.axis("off")
    plt.title("Corrected Image with Point Point Correspondence")
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; release them on every rerun.
    plt.close(fig)
# Function to stitch images
def stitch_images(images):
    """Stitch overlapping images into a panorama with OpenCV's high-level Stitcher.

    Args:
        images: Sequence of BGR ``uint8`` images.

    Returns:
        tuple: ``(panorama, status)``. ``panorama`` is the stitched BGR image,
        or ``None`` when ``status`` is not ``cv2.Stitcher_OK``.
    """
    # Probe for the factory instead of parsing cv2.__version__: a '4'-prefix
    # check sent OpenCV 5.x to the legacy createStitcher of old 3.x builds.
    create_stitcher = getattr(cv2, "Stitcher_create", None) or cv2.createStitcher
    stitcher = create_stitcher()
    status, stitched_image = stitcher.stitch(images)
    if status == cv2.Stitcher_OK:
        return stitched_image, status
    return None, status
# Function to match features
def match_features(images, max_matches=50):
    """Draw the best SIFT matches between the first two images.

    Args:
        images: Sequence of BGR images. Only ``images[0]`` and ``images[1]`` are used.
        max_matches: Number of lowest-distance matches to draw (default 50).

    Returns:
        tuple: ``(matched_image, error)``. On success: the side-by-side
        visualization and ``None``. On failure: ``None`` and a user-facing
        message.
    """
    if len(images) < 2:
        return None, "At least two images are required for feature matching."
    gray1 = cv2.cvtColor(images[0], cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(images[1], cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    keypoints1, descriptors1 = sift.detectAndCompute(gray1, None)
    keypoints2, descriptors2 = sift.detectAndCompute(gray2, None)
    # detectAndCompute yields None descriptors for featureless images;
    # BFMatcher.match would raise on them.
    if descriptors1 is None or descriptors2 is None:
        return None, "No SIFT features were found in one of the first two images."
    # L2 distance suits SIFT's float descriptors; crossCheck keeps mutual best matches only.
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = sorted(bf.match(descriptors1, descriptors2), key=lambda m: m.distance)
    matched_image = cv2.drawMatches(
        images[0], keypoints1, images[1], keypoints2, matches[:max_matches], None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    return matched_image, None
# Function to cartoonify an image
def cartoonify_image(image):
    """Return a cartoon-style rendering of a BGR image.

    - Outlines: an adaptive threshold of the median-blurred grayscale image.
    - Colour: a strong bilateral filter of the input.

    The outline mask is applied to the colour image, so pixels the threshold
    marks as 0 become black.
    """
    # Mask is 255 where a pixel exceeds its 9x9 local mean minus 10, else 0.
    edge_mask = cv2.adaptiveThreshold(
        cv2.medianBlur(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 7),
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        9,
        10,
    )
    smoothed = cv2.bilateralFilter(image, 9, 250, 250)
    return cv2.bitwise_and(smoothed, smoothed, mask=edge_mask)
# Streamlit layout and UI
st.set_page_config(page_title="Image Stitching and Feature Matching", layout="wide")
st.title("Image Stitching and Feature Matching Application")

# Session state survives reruns: camera captures (BGR arrays) and the last panorama.
if "captured_images" not in st.session_state:
    st.session_state["captured_images"] = []
if "stitched_image" not in st.session_state:
    st.session_state["stitched_image"] = None

# Sidebar for displaying captured images, each with its own delete button.
st.sidebar.header("Captured Images")
with st.sidebar:
    for i, img in enumerate(st.session_state["captured_images"]):
        img_thumbnail = cv2.resize(img, (100, 100))
        st.image(cv2.cvtColor(img_thumbnail, cv2.COLOR_BGR2RGB), caption=f"Image {i+1}", use_container_width=False)
        if st.button(f"Delete Image {i+1}", key=f"delete_{i}"):
            st.session_state["captured_images"].pop(i)
            # Rerun immediately so the thumbnails, delete keys and image count
            # all reflect the shorter list. Emptying the placeholder used to
            # blank the whole sidebar until the next interaction.
            st.rerun()
# Capture the image from camera input
st.header("Upload or Capture Images")
uploaded_files = st.file_uploader("Upload images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
captured_image = st.camera_input("Take a picture using your camera")


def _file_to_bgr(image_file):
    """Decode an uploaded or captured image file into a 3-channel BGR ``uint8`` array.

    ``convert("RGB")`` yields three channels for any PIL mode. Without it,
    grayscale ("L") and palette ("P") images decode to 2-D arrays that
    ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects.
    """
    return cv2.cvtColor(np.array(Image.open(image_file).convert("RGB")), cv2.COLOR_RGB2BGR)


if st.button("Add Captured Image"):
    if captured_image:
        st.session_state["captured_images"].append(_file_to_bgr(captured_image))
        st.success(f"Captured image {len(st.session_state['captured_images'])} added!")
    else:
        st.warning("Take a picture with the camera first.")

# Combine uploaded and captured images (all BGR).
images = [_file_to_bgr(uploaded) for uploaded in (uploaded_files or [])]
images.extend(st.session_state["captured_images"])
st.write(f"Total images: {len(images)}")
def show_loading_bar(placeholder):
    """Write a "processing" notice into *placeholder*, then pause for two seconds.

    Clearing the placeholder afterwards is the caller's job.
    """
    placeholder.write("Processing images... Please wait.")
    time.sleep(2)


# Single-slot container reused by the action buttons below for the processing notice.
loading_placeholder = st.empty()
# "Stitch Images": build a panorama from all collected images and keep it in session state.
if st.button("Stitch Images"):
    if len(images) < 2:
        st.error("Please provide at least two images for stitching.")
    else:
        show_loading_bar(loading_placeholder)
        stitched_result, status = stitch_images(images)
        loading_placeholder.empty()
        if stitched_result is not None:
            # Stored in RGB order. It is rendered once, by the persistent
            # section below; showing it here as well duplicated the image.
            stitched_image_rgb = cv2.cvtColor(stitched_result, cv2.COLOR_BGR2RGB)
            st.session_state["stitched_image"] = stitched_image_rgb
            st.success("Stitching completed successfully!")
        else:
            st.error(f"Stitching failed with status: {status}.")

# Always display the stitched image if one has been produced in this session.
if st.session_state.get("stitched_image") is not None:
    st.header("Stitched Image")
    st.image(st.session_state["stitched_image"], caption="Stitched Image", use_container_width=True)
# "Show Matching Features": visualise SIFT matches between the first two images.
if st.button("Show Matching Features"):
    if len(images) >= 2:
        show_loading_bar(loading_placeholder)
        matched_image, error = match_features(images)
        loading_placeholder.empty()
        if matched_image is None:
            st.error(error)
        else:
            matched_image_rgb = cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB)
            st.image(matched_image_rgb, caption="Feature Matching Visualization", use_container_width=True)
            st.success("Feature matching completed successfully!")
    else:
        st.error("Please provide at least two images for feature matching.")
# Homography: the user clicks image points on the panorama and types matching world points.
if st.session_state["stitched_image"] is not None:
    st.header("Homography Transformation on Stitched Image")
    st.write("### Select Points on Stitched Image")
    stitched_image = st.session_state["stitched_image"]
    # The session stores the panorama in RGB order, which is what PIL expects.
    # The former BGR->RGB conversion swapped red and blue on the canvas.
    image = Image.fromarray(stitched_image)
    canvas_result = st_canvas(
        fill_color="rgba(255, 0, 0, 0.3)",
        stroke_width=3,
        background_image=image,
        update_streamlit=True,
        drawing_mode="point",
        height=image.height,
        width=image.width,
        key="canvas",
    )

    # Point-mode clicks arrive as circle objects; use each circle's centre.
    img_pts = []
    if canvas_result.json_data is not None:
        for obj in canvas_result.json_data["objects"]:
            if obj["type"] == "circle":
                x = obj["left"] + obj["width"] / 2
                y = obj["top"] + obj["height"] / 2
                img_pts.append([int(x), int(y)])

    if img_pts:
        st.write("### Selected Image Points")
        st.write(img_pts)
        st.write("### Enter Corresponding World Points")
        world_pts_text = st.text_area(
            "Enter world points as a list of tuples (e.g., [(0, 0), (300, 0), (0, 400), (300, 400)])",
            value="[(0, 0), (300, 0), (0, 400), (300, 400)]",
        )
        if st.button("Run Homography Transformation"):
            try:
                # literal_eval accepts only Python literals, so code typed into
                # the box is never executed (unlike eval).
                world_pts = ast.literal_eval(world_pts_text)
                if len(world_pts) != len(img_pts):
                    st.error("The number of world points must match the number of image points.")
                elif len(img_pts) < 4:
                    st.error("At least four point correspondences are required for a homography.")
                else:
                    run_point_est(world_pts, img_pts, stitched_image)
            except Exception as e:  # UI boundary: report any parse/estimation failure to the user
                st.error(f"Error: {e}")
# The key is always present (initialised to None), so test the value. The old
# membership check showed the button before any stitching, and clicking it
# crashed in cvtColor(None).
if st.session_state.get("stitched_image") is not None:
    st.header("Cartoonify & Do Homography on Your Stitched Image")
    if st.button("Cartoonify Stitched Image"):
        # Stored panorama is RGB; cartoonify_image works on BGR.
        cartoon = cartoonify_image(cv2.cvtColor(st.session_state["stitched_image"], cv2.COLOR_RGB2BGR))
        st.image(cv2.cvtColor(cartoon, cv2.COLOR_BGR2RGB), caption="Cartoonified Image", use_container_width=True)
        st.success("Cartoonification completed successfully!")
# Upload images: independent stereo section working on a separate left/right pair.
st.subheader("Upload Left and Right Images")
left_image_file = st.file_uploader("Choose the Left Image", type=["jpg", "png", "jpeg"])
right_image_file = st.file_uploader("Choose the Right Image", type=["jpg", "png", "jpeg"])

# Check if both images are uploaded
if left_image_file and right_image_file:
    try:
        img1 = load_image(left_image_file)
        img2 = load_image(right_image_file)
        # Guard against load_image passing through cv2.imdecode's None failure value.
        if img1 is None or img2 is None:
            raise ValueError("One of the uploaded files could not be decoded as an image.")
        st.image(img1, caption="Left Image", use_container_width=True)
        st.image(img2, caption="Right Image", use_container_width=True)
        disparity, img1_rectified, img2_rectified = compute_stereo_vision(img1, img2)
    except (ValueError, cv2.error) as err:
        # Bad uploads, too few features or degenerate geometry: report them
        # instead of crashing the whole app.
        st.error(f"Stereo computation failed: {err}")
    else:
        # Show the disparity map
        st.subheader("Disparity Map")
        fig, ax = plt.subplots()
        shown = ax.imshow(disparity, cmap='gray')
        ax.set_title("Disparity Map")
        fig.colorbar(shown, ax=ax)
        st.pyplot(fig)
        # pyplot keeps figures alive until closed; release this one on every rerun.
        plt.close(fig)
# # Optionally: Display an anaglyph or combined view of the images
# anaglyph = cv2.merge([img1_rectified, np.zeros_like(img1_rectified), img2_rectified])
# st.subheader("Anaglyph Stereo View")
# st.image(anaglyph, caption="Anaglyph Stereo View", use_container_width =True)
# if "img_pts" not in st.session_state:
# st.session_state["img_pts"] = []
# if "world_pts" not in st.session_state:
# st.session_state["world_pts"] = []
# if "homography_ready" not in st.session_state:
# st.session_state["homography_ready"] = False
# if st.button('Homography Transformation'):
# if st.session_state["stitched_image"] is not None:
# st.write("### Select Points on Stitched Image")
# stitched_image = st.session_state["stitched_image"]
# image = Image.fromarray(cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB))
# # Display canvas for selecting points
# canvas_result = st_canvas(
# fill_color="rgba(255, 0, 0, 0.3)",
# stroke_width=3,
# background_image=image,
# update_streamlit=True,
# drawing_mode="point",
# height=image.height,
# width=image.width,
# key="canvas",
# )
# # Collect selected points
# if canvas_result.json_data is not None:
# img_pts_temp = []
# for obj in canvas_result.json_data["objects"]:
# if obj["type"] == "circle":
# x = obj["left"] + obj["width"] / 2
# y = obj["top"] + obj["height"] / 2
# img_pts_temp.append([int(x), int(y)])
# # Only update points if there are new ones
# if img_pts_temp:
# st.session_state["img_pts"] = img_pts_temp
# # Display the selected points
# if st.session_state["img_pts"]:
# st.write("### Selected Image Points")
# st.write(st.session_state["img_pts"])
# # Input world points
# world_pts_input = st.text_area(
# "Enter world points as a list of tuples (e.g., [(0, 0), (300, 0), (0, 400), (300, 400)])",
# value="[(0, 0), (300, 0), (0, 400), (300, 400)]",
# )
# if st.button("Confirm Points and Run Homography"):
# try:
# st.session_state["world_pts"] = eval(world_pts_input)
# if len(st.session_state["world_pts"]) != len(st.session_state["img_pts"]):
# st.error("The number of world points must match the number of image points.")
# else:
# st.session_state["homography_ready"] = True
# st.success("Points confirmed! Ready for homography transformation.")
# except Exception as e:
# st.error(f"Error parsing world points: {e}")
# # Perform homography transformation
# if st.session_state.get("homography_ready"):
# st.write("### Running Homography Transformation...")
# try:
# run_point_est(
# st.session_state["world_pts"],
# st.session_state["img_pts"],
# st.session_state["stitched_image"],
# )
# st.session_state["homography_ready"] = False # Reset the flag after execution
# except Exception as e:
# st.error(f"Error during homography transformation: {e}")