import warnings

warnings.simplefilter("ignore", UserWarning)

from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging
from typing import List, Optional, Tuple, Dict
from huggingface_hub import hf_hub_download
import gradio as gr
import supervision as sv
import os
import spaces
import torch

# Define models
MODEL_OPTIONS = {
    "YOLOv11-Nano": "medieval-yolov11n.pt",
    "YOLOv11-Small": "medieval-yolov11s.pt",
    "YOLOv11-Medium": "medieval-yolov11m.pt",
    "YOLOv11-Large": "medieval-yolov11l.pt",
    "YOLOv11-XLarge": "medieval-yolov11x.pt",
}

# Dictionary to store loaded models
models: Dict[str, YOLO] = {}

# Download and load all model variants up front
for name, model_file in MODEL_OPTIONS.items():
    model_path = hf_hub_download(
        repo_id="biglam/medieval-manuscript-yolov11",
        filename=model_file,
    )
    models[name] = YOLO(model_path)

# Configure logging: silence PyTorch Lightning output produced by PyLaia
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# Scratch directory where line images are written for PyLaia
images = Path(mkdtemp())
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "LTR"
NUM_WORKERS = multiprocessing.cpu_count()

# Regex patterns for parsing PyLaia's decoded stdout
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # line-level confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
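# Illustrative example of a decoded line this regex parses (the uuid and
# values below are made up):
#   "0b9a6c1e-2f4d-4a8b-9c3d-1e2f3a4b5c6d 0.87 in principio erat uerbum"
# LINE_PREDICTION.match(line).groups() -> (image_id, "0.87", "in principio erat uerbum")
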
# Create annotators
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
BOX_ANNOTATOR = sv.BoxAnnotator()

# Select device
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves the image's aspect ratio at `height`."""
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
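# e.g. a 1000x500 image at the default height of 128 gives get_width(...) == 256.0
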
def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours using the Douglas-Peucker algorithm.

    Args:
        polygons: List of polygon contours.
        approx_level: Approximation level as a fraction of the contour
            perimeter (0-1); higher values mean stronger simplification.

    Returns:
        List of simplified polygons (None for contours that are degenerate
        before or after simplification).
    """
    result = []
    for polygon in polygons:
        if len(polygon) < 4:
            result.append(None)
            continue

        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < 4:
            result.append(None)
            continue

        result.append(approx.squeeze())
    return result
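# Sketch of standalone usage (this helper is not called elsewhere in this app;
# `mask` is assumed to be a binary uint8 image such as a segmentation output):
#   contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
#   simplified = simplify_polygons(list(contours))
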
def predict_text(input_img):
    """PyLaia text recognition function."""
    model_dir = "catmus-medieval"
    temperature = 2.0
    batch_size = 1

    weights_path = f"{model_dir}/weights.ckpt"
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )
    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize to the fixed line height expected by the model
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(f"{images}/{image_id}.jpg")
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        # decode() prints predictions to stdout; capture them in a temp file
        with redirect_stdout(open(pred_stdout.name, mode="w")):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()

        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        return text, float(score)
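# Sketch of standalone usage (assumes the catmus-medieval checkpoint files
# referenced above exist on disk):
#   line_img = Image.open("line.jpg").convert("L")
#   text, confidence = predict_text(line_img)
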
def detect_and_recognize(image, model_name, conf_threshold, iou_threshold):
    if image is None:
        return None, ""

    # Use the model preloaded at startup instead of re-downloading per call
    model = models[model_name]

    # Perform inference
    results = model.predict(
        image,
        conf=conf_threshold,
        iou=iou_threshold,
        classes=0,
        device=device,
    )[0]

    # Convert results to supervision Detections
    boxes = results.boxes.xyxy.cpu().numpy()
    confidence = results.boxes.conf.cpu().numpy()
    class_ids = results.boxes.cls.cpu().numpy().astype(int)

    # Sort detections top-to-bottom by the y-coordinate of each box
    sorted_indices = np.argsort(boxes[:, 1])
    boxes = boxes[sorted_indices]
    confidence = confidence[sorted_indices]
    class_ids = class_ids[sorted_indices]

    # Create Detections object
    detections = sv.Detections(
        xyxy=boxes,
        confidence=confidence,
        class_id=class_ids,
    )

    # Create labels, numbered in reading order
    labels = [
        f"Line {i+1} ({conf:.2f})"
        for i, conf in enumerate(confidence)
    ]

    # Annotate image
    annotated_image = image.copy()
    annotated_image = BOX_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)

    # Create text summary
    text_summary = "\n".join(f"Line {i+1}: Confidence {conf:.2f}" for i, conf in enumerate(confidence))

    return annotated_image, text_summary
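# Sketch of chaining the two stages (not wired into the UI below; assumes
# `image` is an RGB numpy array and reuses `boxes` from detect_and_recognize):
#   for x1, y1, x2, y2 in boxes.astype(int):
#       line_crop = Image.fromarray(image[y1:y2, x1:x2])
#       text, conf = predict_text(line_crop)
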
def gradio_reset():
    return None, None, ""

if __name__ == "__main__":
    print(f"Using device: {device}")

    with gr.Blocks() as demo:
        gr.Markdown("# Medieval Manuscript Text Detection")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="numpy",
                )
                with gr.Accordion("Detection Settings", open=True):
                    model_selector = gr.Dropdown(
                        choices=list(MODEL_OPTIONS.keys()),
                        value=list(MODEL_OPTIONS.keys())[0],
                        label="Model",
                        info="Select YOLO model variant",
                    )
                    with gr.Row():
                        conf_threshold = gr.Slider(
                            label="Confidence Threshold",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=0.25,
                        )
                        iou_threshold = gr.Slider(
                            label="IoU Threshold",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=0.45,
                        )
                with gr.Row():
                    clear_btn = gr.Button("Clear")
                    detect_btn = gr.Button("Detect", variant="primary")

            with gr.Column():
                output_image = gr.Image(
                    label="Detection Result",
                    type="numpy",
                )
                text_output = gr.Textbox(
                    label="Detection Summary",
                    lines=10,
                )

        # Connect buttons to functions
        detect_btn.click(
            detect_and_recognize,
            inputs=[input_image, model_selector, conf_threshold, iou_threshold],
            outputs=[output_image, text_output],
        )
        clear_btn.click(
            gradio_reset,
            inputs=None,
            outputs=[input_image, output_image, text_output],
        )

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)