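"""Gradio demo: medieval manuscript segmentation with a YOLOv11 segmentation model."""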
from typing import Tuple, Dict
import gradio as gr
import supervision as sv
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
# Define models
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
}
# Dictionary to store loaded models
models: Dict[str, YOLO] = {}
# Load all models
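# Weights are fetched from the Hugging Face Hub; hf_hub_download caches them locally.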
for name, model_file in MODEL_OPTIONS.items():
    try:
        model_path = hf_hub_download(
            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
            filename=model_file
        )
        models[name] = YOLO(model_path)
    except Exception as e:
        print(f"Error loading model {name}: {str(e)}")
# Create annotators
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()
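
# Ultralytics returns segmentation masks at the model's inference resolution, which may
# differ from the input image size, so they are resized back before annotation.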
def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
    """Process and resize masks to target shape"""
    if masks is None:
        return None

    processed_masks = []
    h, w = target_shape
    for mask in masks:
        # Resize mask to target dimensions
        resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
        # Threshold to create binary mask
        processed_masks.append(resized_mask > 0.5)
    return np.array(processed_masks)
def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float
) -> np.ndarray:
    try:
        if image is None:
            return None

        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        # Perform inference
        results = model.predict(
            image,
            conf=conf_threshold,
            iou=iou_threshold
        )[0]

        # Convert results to supervision Detections
        boxes = results.boxes.xyxy.cpu().numpy()
        confidence = results.boxes.conf.cpu().numpy()
        class_ids = results.boxes.cls.cpu().numpy().astype(int)
        # Process masks: keep the (N, H, W) layout that process_masks and
        # sv.Detections expect, resized to the original image size
        masks = None
        if results.masks is not None:
            masks = results.masks.data.cpu().numpy()
            masks = process_masks(masks, image.shape[:2])
        # Create Detections object
        detections = sv.Detections(
            xyxy=boxes,
            confidence=confidence,
            class_id=class_ids,
            mask=masks
        )

        # Create labels
        labels = [
            f"{results.names[class_id]} ({conf:.2f})"
            for class_id, conf in zip(class_ids, confidence)
        ]

        # Annotate image
        annotated_image = image.copy()
        if masks is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image,
                detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels
        )
        return annotated_image
    except Exception as e:
        print(f"Error during detection: {str(e)}")
        return image  # Return original image on error
# Create Gradio interface
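# Layout: input image and detection settings on the left, annotated result on the right.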
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type='numpy')
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    value=list(MODEL_OPTIONS.keys())[0],
                    label="Model"
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45
                )
            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type='numpy')
    def process_image(image, model_name, conf_threshold, iou_threshold):
        try:
            if image is None:
                return None, None
            annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
            return image, annotated_image
        except Exception as e:
            print(f"Error in process_image: {str(e)}")
            return image, image  # Fallback to original image

    def clear():
        return None, None
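
    # Wire the buttons: the input image is echoed back unchanged alongside the annotated result.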
    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=[input_image, output_image]
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )
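
# 0.0.0.0:7860 is the default host/port expected by Hugging Face Spaces.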
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True
    )