import warnings
warnings.simplefilter("ignore", UserWarning)
from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import logging
from typing import List, Optional, Tuple, Dict
from huggingface_hub import hf_hub_download
import gradio as gr
import supervision as sv
import spaces
import torch
# Define models
MODEL_OPTIONS = {
"YOLOv11-Nano": "medieval-yolov11n.pt",
"YOLOv11-Small": "medieval-yolov11s.pt",
"YOLOv11-Medium": "medieval-yolov11m.pt",
"YOLOv11-Large": "medieval-yolov11l.pt",
"YOLOv11-XLarge": "medieval-yolov11x.pt"
}
# Dictionary to store loaded models
models: Dict[str, YOLO] = {}
# Load all models
for name, model_file in MODEL_OPTIONS.items():
model_path = hf_hub_download(
repo_id="biglam/medieval-manuscript-yolov11",
filename=model_file
)
models[name] = YOLO(model_path)
# Configure logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# Scratch directory for line images handed to the PyLaia decoder
images = Path(mkdtemp())
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "LTR"
NUM_WORKERS = multiprocessing.cpu_count()
# Regex for parsing PyLaia decode output lines: "<image_id> <confidence> <text>"
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # line-level confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
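# A minimal parsing sketch; the sample output line below is hypothetical but
# follows the "<image_id> <confidence> <text>" shape the regex expects:
#
#   line = "3f2504e0-4f89-41d3-9a0c-0305e82c3301 0.87 In principio erat uerbum"
#   image_id, score, text = LINE_PREDICTION.match(line).groups()
#   # -> ("3f2504e0-4f89-41d3-9a0c-0305e82c3301", "0.87", "In principio erat uerbum")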
# Create annotators
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
BOX_ANNOTATOR = sv.BoxAnnotator()
# Select device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves the image's aspect ratio at the given height."""
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours using the Douglas-Peucker algorithm.

    Args:
        polygons: List of polygon contours
        approx_level: Approximation tolerance as a fraction of the contour
            perimeter (0-1); higher values mean more aggressive simplification

    Returns:
        List of simplified polygons (or None for polygons with fewer than
        four points before or after simplification)
    """
result = []
for polygon in polygons:
if len(polygon) < 4:
result.append(None)
continue
perimeter = cv2.arcLength(polygon, True)
approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
if len(approx) < 4:
result.append(None)
continue
result.append(approx.squeeze())
return result
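# A minimal usage sketch for simplify_polygons, assuming a binary line mask
# is available (the helper below is hypothetical; cv2.findContours is the
# standard way to extract the contours the function expects):
def _example_simplify_from_mask(mask: np.ndarray) -> List[Optional[np.ndarray]]:
    # mask: uint8 array with line pixels set to 255 and background set to 0
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return simplify_polygons(list(contours))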
def predict_text(input_img):
    """Run PyLaia CTC decoding on a single line image.

    Returns a (text, confidence) tuple for the recognized line.
    """
    model_dir = 'catmus-medieval'
    temperature = 2.0
    batch_size = 1
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
if use_language_model:
language_model_params.update({
"language_model_path": f"{model_dir}/language_model.binary",
"lexicon_path": f"{model_dir}/lexicon.txt",
"tokens_path": f"{model_dir}/tokens.txt",
})
common_args = CommonArgs(
checkpoint="weights.ckpt",
train_path=f"{model_dir}",
experiment_dirname="",
)
data_args = DataArgs(batch_size=batch_size, color_mode="L")
trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
decode_args = DecodeArgs(
include_img_ids=True,
join_string="",
convert_spaces=True,
print_line_confidence_scores=True,
print_word_confidence_scores=False,
temperature=temperature,
use_language_model=use_language_model,
**language_model_params,
)
with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
image_id = uuid4()
input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
input_img.save(f"{images}/{image_id}.jpg")
Path(img_list.name).write_text("\n".join([str(image_id)]))
        with open(pred_stdout.name, mode="w") as pred_file, redirect_stdout(pred_file):
decode(
syms=str(syms_path),
img_list=img_list.name,
img_dirs=[str(images)],
common=common_args,
data=data_args,
trainer=trainer_args,
decode=decode_args,
num_workers=1,
)
sys.stdout.flush()
predictions = Path(pred_stdout.name).read_text().strip().splitlines()
_, score, text = LINE_PREDICTION.match(predictions[0]).groups()
return text, float(score)
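# A hedged sketch of running predict_text on one detected line, assuming
# `image` is the page as an RGB numpy array and `box` is an (x1, y1, x2, y2)
# detection from the YOLO model below (the helper name is hypothetical):
def _example_recognize_line(image: np.ndarray, box: np.ndarray) -> Tuple[str, float]:
    x1, y1, x2, y2 = box.astype(int)
    line_crop = Image.fromarray(image[y1:y2, x1:x2])
    return predict_text(line_crop)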
@spaces.GPU
def detect_and_recognize(image, model_name, conf_threshold, iou_threshold):
if image is None:
return None, ""
    # Use the model preloaded at startup instead of reloading it per request
    model = models[model_name]
# Perform inference
results = model.predict(
image,
conf=conf_threshold,
iou=iou_threshold,
classes=0,
device=device
)[0]
# Convert results to supervision Detections
boxes = results.boxes.xyxy.cpu().numpy()
confidence = results.boxes.conf.cpu().numpy()
class_ids = results.boxes.cls.cpu().numpy().astype(int)
    # Sort detections top-to-bottom by y-coordinate, keeping all arrays aligned
    sorted_indices = np.argsort(boxes[:, 1])
    boxes = boxes[sorted_indices]
    confidence = confidence[sorted_indices]
    class_ids = class_ids[sorted_indices]
# Create Detections object
detections = sv.Detections(
xyxy=boxes,
confidence=confidence,
class_id=class_ids
)
# Create labels
labels = [
f"Line {i+1} ({conf:.2f})"
for i, conf in enumerate(confidence)
]
# Annotate image
annotated_image = image.copy()
annotated_image = BOX_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)
    # Summarize detections per line (detection only; see the transcription sketch below)
    text_summary = "\n".join([f"Line {i+1}: Confidence {conf:.2f}" for i, conf in enumerate(confidence)])
return annotated_image, text_summary
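# A minimal end-to-end sketch chaining detection with predict_text, reusing
# the hypothetical _example_recognize_line helper above; it is illustrative
# and not wired into the Gradio app below:
def _example_detect_and_transcribe(image, model_name, conf_threshold, iou_threshold):
    model = models[model_name]
    results = model.predict(image, conf=conf_threshold, iou=iou_threshold,
                            classes=0, device=device)[0]
    boxes = results.boxes.xyxy.cpu().numpy()
    boxes = boxes[np.argsort(boxes[:, 1])]  # read lines top to bottom
    return [_example_recognize_line(image, box) for box in boxes]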
def gradio_reset():
return None, None, ""
if __name__ == "__main__":
print(f"Using device: {device}")
with gr.Blocks() as demo:
gr.Markdown("# Medieval Manuscript Text Detection")
with gr.Row():
with gr.Column():
input_image = gr.Image(
label="Input Image",
type="numpy"
)
with gr.Accordion("Detection Settings", open=True):
model_selector = gr.Dropdown(
choices=list(MODEL_OPTIONS.keys()),
value=list(MODEL_OPTIONS.keys())[0],
label="Model",
info="Select YOLO model variant"
)
with gr.Row():
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.25,
)
iou_threshold = gr.Slider(
label="IoU Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.45,
)
with gr.Row():
clear_btn = gr.Button("Clear")
detect_btn = gr.Button("Detect", variant="primary")
with gr.Column():
output_image = gr.Image(
label="Detection Result",
type="numpy"
)
text_output = gr.Textbox(
label="Detection Summary",
lines=10
)
# Connect buttons to functions
detect_btn.click(
detect_and_recognize,
inputs=[input_image, model_selector, conf_threshold, iou_threshold],
outputs=[output_image, text_output]
)
clear_btn.click(
gradio_reset,
inputs=None,
outputs=[input_image, output_image, text_output]
)
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
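# Usage note: assuming this file is saved as app.py (as on a Hugging Face
# Space), running `python app.py` serves the UI at http://0.0.0.0:7860.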