|
from typing import Dict, List |
|
import os |
|
import sys |
|
import glob |
|
import argparse |
|
import datetime |
|
import shutil |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
from ultralytics import YOLO |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
import xml.etree.ElementTree as ET |
|
from xml.dom import minidom |
|
|
|
|
|
# Model registry: display name -> weight filename hosted on the Hugging Face Hub.
MODEL_OPTIONS = {
    "YOLOv11-Nano": "yolov11n-seg.pt",
    "YOLOv11-Small": "yolov11s-seg.pt",
    "YOLOv11-Medium": "yolov11m-seg.pt",
    "YOLOv11-Large": "yolov11l-seg.pt",
    "YOLOv11-XLarge": "yolov11x-seg.pt"
}

# Cache of instantiated models keyed by display name so repeated calls to
# load_model() neither re-download weights nor rebuild the network.
models: Dict[str, "YOLO"] = {}


def load_model(model_name: str = "YOLOv11-Nano") -> "YOLO":
    """Download (if necessary) and return a cached YOLO segmentation model.

    Args:
        model_name: One of the keys of MODEL_OPTIONS.

    Returns:
        The loaded (and cached) YOLO model.

    Raises:
        ValueError: If ``model_name`` is not a known model.
    """
    if model_name not in MODEL_OPTIONS:
        # BUG FIX: previously an unknown name raised a bare KeyError from the
        # dict subscript; fail early with an actionable message instead.
        raise ValueError(
            f"Unknown model '{model_name}'. Choose one of: {', '.join(MODEL_OPTIONS)}"
        )
    if model_name not in models:
        model_path = hf_hub_download(
            repo_id="wjbmattingly/kraken-yiddish",
            filename=MODEL_OPTIONS[model_name]
        )
        models[model_name] = YOLO(model_path)
    return models[model_name]
|
|
|
def process_image(
    image_path: str,
    model_name: str = "YOLOv11-Medium",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45
) -> tuple:
    """Run segmentation on one image.

    Args:
        image_path: Path to the image file on disk.
        model_name: Key into MODEL_OPTIONS selecting which model to load.
        conf_threshold: Minimum detection confidence.
        iou_threshold: IoU threshold for non-max suppression.

    Returns:
        Tuple of (result, annotated_image, width, height) where ``result`` is
        the first ultralytics result and ``annotated_image`` is a BGR ndarray.

    Raises:
        ValueError: If the image cannot be read.
    """
    bgr_frame = cv2.imread(image_path)
    if bgr_frame is None:
        raise ValueError(f"Cannot read image: {image_path}")

    # Inference runs on RGB; cv2 loads BGR.
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    img_height, img_width = bgr_frame.shape[:2]

    detector = load_model(model_name)
    prediction = detector(
        rgb_frame,
        conf=conf_threshold,
        iou=iou_threshold,
        verbose=False,
        device='cpu'
    )[0]

    overlay = prediction.plot(
        conf=True,
        line_width=None,
        font_size=None,
        boxes=True,
        masks=True,
        probs=True,
        labels=True
    )
    # Convert the rendered overlay back to BGR for cv2-style consumers.
    overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)

    return prediction, overlay, img_width, img_height
|
|
|
def create_page_xml(
    image_filename: str,
    result,
    width: int,
    height: int
) -> str:
    """Build a PAGE XML document (2019-07-15 schema) from YOLO segmentation results.

    Args:
        image_filename: Path of the source image; only its basename is stored
            in the Page element.
        result: A single ultralytics result; segmentation polygons are read
            from ``result.masks.xy`` when present.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        Pretty-printed PAGE XML as a string. If no valid masks are found, a
        single full-page default TextRegion is emitted instead.
    """
    root = ET.Element("PcGts", {
        "xmlns": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15",
        "xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance",
        "xsi:schemaLocation": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"
    })

    metadata = ET.SubElement(root, "Metadata")
    ET.SubElement(metadata, "Creator").text = "escriptorium"

    # NOTE(review): timestamps are deliberately set one year in the future —
    # presumably an eScriptorium import workaround; confirm before changing.
    future_date = (datetime.datetime.now() + datetime.timedelta(days=365)).isoformat()
    ET.SubElement(metadata, "Created").text = future_date
    ET.SubElement(metadata, "LastChange").text = future_date

    page = ET.SubElement(root, "Page", {
        "imageFilename": os.path.basename(image_filename),
        "imageWidth": str(width),
        "imageHeight": str(height)
    })

    has_valid_masks = False

    if hasattr(result, 'masks') and result.masks is not None:
        masks = result.masks.xy

        timestamp = int(datetime.datetime.now().timestamp())
        main_region_id = f"eSc_textblock_TextRegion_{timestamp}"

        all_points_x = []
        all_points_y = []
        valid_masks = []

        # Drop NaN vertices; keep only polygons with at least 3 points.
        for mask_points in masks:
            valid_points = [(p[0], p[1]) for p in mask_points
                            if not (np.isnan(p[0]) or np.isnan(p[1]))]
            if len(valid_points) >= 3:
                valid_masks.append(valid_points)
                all_points_x.extend(p[0] for p in valid_points)
                all_points_y.extend(p[1] for p in valid_points)
                has_valid_masks = True

        if has_valid_masks and all_points_x and all_points_y:
            # Bounding box of all detected lines, clamped to the image.
            min_x = max(0, int(min(all_points_x)))
            max_x = min(width, int(max(all_points_x)))
            min_y = max(0, int(min(all_points_y)))
            max_y = min(height, int(max(all_points_y)))

            # Main text zone containing one TextLine per valid polygon.
            main_text_region = ET.SubElement(page, "TextRegion", {
                "id": main_region_id,
                "custom": "structure {type:text_zone;}"
            })
            region_points = f"{min_x},{min_y} {max_x},{min_y} {max_x},{max_y} {min_x},{max_y}"
            ET.SubElement(main_text_region, "Coords", {"points": region_points})

            for i, valid_points in enumerate(valid_masks):
                # NOTE(review): the first line keeps a hard-coded legacy id;
                # presumably expected by a downstream consumer — verify.
                line_id = f"eSc_line_r2l{i+1}" if i > 0 else "eSc_line_line_1610719743362_3154"
                text_line = ET.SubElement(main_text_region, "TextLine", {
                    "id": line_id,
                    "custom": "structure {type:text_line;}"
                })

                points_str = " ".join(f"{int(p[0])},{int(p[1])}" for p in valid_points)
                ET.SubElement(text_line, "Coords", {"points": points_str})

                # Baseline: horizontal segment spanning the polygon's x-extent
                # at the average y of the bottom third of its vertices.
                points_by_x = sorted(valid_points, key=lambda p: p[0])
                left_x = points_by_x[0][0]
                right_x = points_by_x[-1][0]

                sorted_by_y = sorted(valid_points, key=lambda p: p[1])
                bottom_third_index = max(0, int(len(sorted_by_y) * 0.67))
                bottom_points = sorted_by_y[bottom_third_index:] or sorted_by_y
                avg_y = sum(p[1] for p in bottom_points) / len(bottom_points)

                baseline_str = f"{int(left_x)},{int(avg_y)} {int(right_x)},{int(avg_y)}"
                ET.SubElement(text_line, "Baseline", {"points": baseline_str})

                text_equiv = ET.SubElement(text_line, "TextEquiv")
                ET.SubElement(text_equiv, "Unicode")

            # Margin zone to the left of the main text block.
            # BUG FIX: clamp the right edge to 0 — when the text started
            # within 10 px of the image's left edge, min_x - 10 produced
            # negative (invalid) PAGE XML coordinates.
            left_edge = max(0, min_x - 10)
            left_region = ET.SubElement(page, "TextRegion", {
                "id": "eSc_textblock_r1",
                "custom": "structure {type:text_zone;}"
            })
            left_region_points = f"0,0 {left_edge},{min_y} {left_edge},{max_y} 0,{max_y}"
            ET.SubElement(left_region, "Coords", {"points": left_region_points})

    if not has_valid_masks:
        # Fallback: a single region covering the whole page.
        print("Warning: No valid masks detected. Creating a default text region.")
        default_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_default_{int(datetime.datetime.now().timestamp())}",
            "custom": "structure {type:text_zone;}"
        })
        default_points = f"0,0 {width},0 {width},{height} 0,{height}"
        ET.SubElement(default_region, "Coords", {"points": default_points})

    xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ")

    return xmlstr
|
|
|
def save_results(image_path: str, annotated_image: np.ndarray, xml_content: str):
    """Save the original image to output/ and the XML file to annotations/.

    Args:
        image_path: Path to the source image that was processed.
        annotated_image: Annotated detection image. Currently unused — kept
            for interface compatibility with existing callers.
        xml_content: PAGE XML document to write.
    """
    output_dir = "output"
    annotations_dir = "annotations"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(annotations_dir, exist_ok=True)

    base_name = os.path.basename(image_path)
    file_name_no_ext = os.path.splitext(base_name)[0]

    # BUG FIX: keep the original file extension. The previous code always
    # wrote "<name>.jpg" while copying the bytes unchanged, so PNG/TIFF
    # inputs were mislabeled and no longer matched the imageFilename stored
    # in the PAGE XML.
    output_image_path = os.path.join(output_dir, base_name)
    shutil.copy(image_path, output_image_path)

    output_xml_path = os.path.join(annotations_dir, f"{file_name_no_ext}.xml")
    with open(output_xml_path, "w", encoding="utf-8") as f:
        f.write(xml_content)

    print("Results saved to:")
    print(f"  Image: {output_image_path}")
    print(f"  XML: {output_xml_path}")
|
|
|
def _run_single(image_path: str, model_name: str, conf: float, iou: float) -> None:
    """Run the full pipeline on one image: detect, build PAGE XML, save."""
    result, annotated_image, width, height = process_image(
        image_path,
        model_name,
        conf,
        iou
    )
    xml_content = create_page_xml(image_path, result, width, height)
    save_results(image_path, annotated_image, xml_content)


def main():
    """CLI entry point: parse arguments and process one image or a directory."""
    parser = argparse.ArgumentParser(description="Convert YOLO segmentation to PAGE XML format")
    parser.add_argument("image_path", help="Path to the input image or directory of images")
    parser.add_argument("--model", default="YOLOv11-Medium", choices=MODEL_OPTIONS.keys(),
                        help="Model to use for detection")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold for detection")
    parser.add_argument("--iou", type=float, default=0.45,
                        help="IoU threshold for detection")
    parser.add_argument("--batch", action="store_true",
                        help="Process all images in the directory if image_path is a directory")

    args = parser.parse_args()

    if os.path.isdir(args.image_path) and args.batch:
        image_files = []
        for extension in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension}")))
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension.upper()}")))

        # BUG FIX: on case-insensitive filesystems (macOS/Windows) the lower-
        # and upper-case globs match the same files, so each image was
        # processed twice. Dedupe, and sort for a deterministic order.
        image_files = sorted(set(image_files))

        if not image_files:
            print(f"No image files found in directory: {args.image_path}")
            sys.exit(1)

        print(f"Found {len(image_files)} images to process")

        for i, image_path in enumerate(image_files):
            print(f"Processing {i+1}/{len(image_files)}: {os.path.basename(image_path)}")
            try:
                _run_single(image_path, args.model, args.conf, args.iou)
            except Exception as e:
                # Keep going: one bad image shouldn't abort the whole batch.
                print(f"Error processing {image_path}: {e}")
                import traceback
                traceback.print_exc()
    else:
        try:
            _run_single(args.image_path, args.model, args.conf, args.iou)
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
            sys.exit(1)


if __name__ == "__main__":
    main()