# kraken-yiddish / yolo2xml.py
# wjbmattingly's picture
# Create yolo2xml.py
# bd0e32e verified
from typing import Dict, List
import os
import sys
import glob
import argparse
import datetime
import shutil
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# XML generation imports
import xml.etree.ElementTree as ET
from xml.dom import minidom
# Selectable segmentation checkpoints, ordered from smallest to largest.
# Each pair is (name accepted by --model, weights filename passed to
# hf_hub_download by load_model).
MODEL_OPTIONS = dict((
    ("YOLOv11-Nano", "yolov11n-seg.pt"),
    ("YOLOv11-Small", "yolov11s-seg.pt"),
    ("YOLOv11-Medium", "yolov11m-seg.pt"),
    ("YOLOv11-Large", "yolov11l-seg.pt"),
    ("YOLOv11-XLarge", "yolov11x-seg.pt"),
))
# Cache of already-instantiated models, keyed by their MODEL_OPTIONS name.
# load_model() fills it on first use so a checkpoint is downloaded and
# constructed at most once per process.
models: Dict[str, YOLO] = {}
def load_model(model_name: str = "YOLOv11-Nano") -> YOLO:
    """Return the YOLO model registered under *model_name*, loading it on first use.

    The weights file named in ``MODEL_OPTIONS`` is fetched with
    ``hf_hub_download`` from the ``wjbmattingly/kraken-yiddish`` repository,
    wrapped in a ``YOLO`` instance and stored in the module-level ``models``
    cache; later calls with the same name return the cached instance.

    Args:
        model_name: A key of ``MODEL_OPTIONS``. Defaults to ``"YOLOv11-Nano"``.

    Returns:
        The cached or newly created ``YOLO`` instance.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_OPTIONS``.
    """
    cached = models.get(model_name)
    if cached is not None:
        return cached

    weights_path = hf_hub_download(
        repo_id="wjbmattingly/kraken-yiddish",
        filename=MODEL_OPTIONS[model_name],
    )
    loaded = YOLO(weights_path)
    models[model_name] = loaded
    return loaded
def process_image(
    image_path: str,
    model_name: str = "YOLOv11-Medium",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45
) -> tuple:
    """Run YOLO segmentation on one image file.

    Args:
        image_path: Path of the image to read with OpenCV.
        model_name: Key of ``MODEL_OPTIONS`` selecting the checkpoint.
        conf_threshold: Confidence threshold passed to the model as ``conf``.
        iou_threshold: IoU threshold passed to the model as ``iou``.

    Returns:
        A 4-tuple ``(result, annotated_image, width, height)``: the first
        result object returned by the model, the image rendered by
        ``result.plot`` (BGR channel order, suitable for ``cv2.imwrite``),
        and the pixel width and height of the input image.

    Raises:
        ValueError: If OpenCV cannot read *image_path*.
    """
    # cv2.imread returns None (rather than raising) for unreadable files.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")

    height, width = image.shape[:2]

    model = load_model(model_name)

    # Ultralytics documents numpy inputs as HWC arrays in BGR order (the
    # layout cv2.imread produces) and converts to RGB internally. Passing an
    # RGB array, as this function used to, made the model see red and blue
    # swapped, so the OpenCV image is handed over unchanged.
    results = model(
        image,
        conf=conf_threshold,
        iou=iou_threshold,
        verbose=False,
        device='cpu'
    )

    # A single input image yields a single result.
    result = results[0]

    # Results.plot() is documented to return a BGR array, so no channel
    # conversion is needed before saving it with OpenCV.
    annotated_image = result.plot(
        conf=True,
        line_width=None,
        font_size=None,
        boxes=True,
        masks=True,
        probs=True,
        labels=True
    )

    return result, annotated_image, width, height
def _finite_points(mask_points) -> List[tuple]:
    """Return the (x, y) pairs of *mask_points* whose coordinates are both finite."""
    return [
        (point[0], point[1])
        for point in mask_points
        if np.isfinite(point[0]) and np.isfinite(point[1])
    ]


def _baseline_points(polygon) -> str:
    """Return a horizontal two-point baseline (PAGE ``points`` string) for *polygon*.

    The baseline spans the polygon's full x-extent and sits at the mean y of
    the lowest third of its points (largest y values, since y grows downward
    in image coordinates).
    """
    xs = [point[0] for point in polygon]
    ys = sorted(point[1] for point in polygon)
    # Fall back to every point should the slice ever come out empty.
    bottom_ys = ys[int(len(ys) * 0.67):] or ys
    avg_y = sum(bottom_ys) / len(bottom_ys)
    return f"{int(min(xs))},{int(avg_y)} {int(max(xs))},{int(avg_y)}"


def create_page_xml(
    image_filename: str,
    result,
    width: int,
    height: int
) -> str:
    """Build a PAGE XML document (as a pretty-printed string) from a YOLO result.

    Every mask polygon in ``result.masks.xy`` with at least three finite
    points becomes a ``TextLine`` (polygon ``Coords``, a straight two-point
    ``Baseline`` and an empty ``TextEquiv``) inside one ``TextRegion`` whose
    rectangle is the bounding box of all lines, clamped to the image. An
    empty placeholder region is added to the left of that box when there is
    room for it. If there are no usable masks, a single region covering the
    whole page is written instead and a warning is printed.

    Args:
        image_filename: Path of the source image; only its base name is
            written to the ``imageFilename`` attribute.
        result: Object exposing ``masks`` (or lacking it / ``None``), where
            ``masks.xy`` is an iterable of point sequences indexable as
            ``point[0]`` (x) and ``point[1]`` (y).
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The XML document as a string produced by ``minidom`` pretty-printing.
    """
    root = ET.Element("PcGts", {
        "xmlns": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15",
        "xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance",
        "xsi:schemaLocation": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"
    })

    metadata = ET.SubElement(root, "Metadata")
    ET.SubElement(metadata, "Creator").text = "escriptorium"
    # Use a future date like in the example this output is modelled on.
    future_date = (datetime.datetime.now() + datetime.timedelta(days=365)).isoformat()
    ET.SubElement(metadata, "Created").text = future_date
    ET.SubElement(metadata, "LastChange").text = future_date

    page = ET.SubElement(root, "Page", {
        "imageFilename": os.path.basename(image_filename),
        "imageWidth": str(width),
        "imageHeight": str(height)
    })

    # Keep only polygons that still have >= 3 points once non-finite
    # coordinates are dropped (int() cannot convert NaN or infinity).
    masks = getattr(result, "masks", None)
    polygons = []
    if masks is not None:
        polygons = [
            points
            for points in map(_finite_points, masks.xy)
            if len(points) >= 3
        ]

    if not polygons:
        print("Warning: No valid masks detected. Creating a default text region.")
        default_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_default_{int(datetime.datetime.now().timestamp())}",
            "custom": "structure {type:text_zone;}"
        })
        default_points = f"0,0 {width},0 {width},{height} 0,{height}"
        ET.SubElement(default_region, "Coords", {"points": default_points})
    else:
        # Bounding box of every line polygon, clamped to the image.
        all_x = [point[0] for points in polygons for point in points]
        all_y = [point[1] for points in polygons for point in points]
        min_x = max(0, int(min(all_x)))
        max_x = min(width, int(max(all_x)))
        min_y = max(0, int(min(all_y)))
        max_y = min(height, int(max(all_y)))

        # Timestamp keeps the region id unique between runs.
        timestamp = int(datetime.datetime.now().timestamp())
        main_text_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_TextRegion_{timestamp}",
            "custom": "structure {type:text_zone;}"
        })
        region_points = f"{min_x},{min_y} {max_x},{min_y} {max_x},{max_y} {min_x},{max_y}"
        ET.SubElement(main_text_region, "Coords", {"points": region_points})

        for i, points in enumerate(polygons):
            # The first line keeps the fixed id used by the original script;
            # later lines are numbered from 2.
            line_id = f"eSc_line_r2l{i+1}" if i > 0 else "eSc_line_line_1610719743362_3154"
            text_line = ET.SubElement(main_text_region, "TextLine", {
                "id": line_id,
                "custom": "structure {type:text_line;}"
            })
            # int() avoids decimals / scientific notation in the output.
            points_str = " ".join(f"{int(p[0])},{int(p[1])}" for p in points)
            ET.SubElement(text_line, "Coords", {"points": points_str})
            ET.SubElement(text_line, "Baseline", {"points": _baseline_points(points)})
            # Empty transcription placeholder.
            text_equiv = ET.SubElement(text_line, "TextEquiv")
            ET.SubElement(text_equiv, "Unicode")

        # Empty placeholder region to the left of the text block, mimicking
        # the example layout. It is only emitted when it fits: with the text
        # block within 10 px of the left edge, "min_x - 10" would be zero or
        # negative and produce degenerate or invalid (negative) coordinates.
        left_edge = min_x - 10
        if left_edge > 0:
            left_region = ET.SubElement(page, "TextRegion", {
                "id": "eSc_textblock_r1",
                "custom": "structure {type:text_zone;}"
            })
            left_region_points = f"0,0 {left_edge},{min_y} {left_edge},{max_y} 0,{max_y}"
            ET.SubElement(left_region, "Coords", {"points": left_region_points})

    # Round-trip through minidom for indented output.
    return minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ")
def save_results(image_path: str, annotated_image: np.ndarray, xml_content: str):
    """Copy the source image into ``output/`` and write its XML into ``annotations/``.

    Both directories are relative to the current working directory and are
    created if missing. The image keeps its original file name, extension
    included, so it matches the ``imageFilename`` recorded in the PAGE XML
    and is never stored under a misleading ``.jpg`` extension. The XML is
    written as ``<stem>.xml``.

    Args:
        image_path: Path of the original image file to copy.
        annotated_image: Rendered detection image; accepted for interface
            compatibility but not written anywhere.
        xml_content: XML document text to save (encoded as UTF-8).
    """
    output_dir = "output"
    annotations_dir = "annotations"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(annotations_dir, exist_ok=True)

    base_name = os.path.basename(image_path)
    file_name_no_ext = os.path.splitext(base_name)[0]

    # Copy the file byte-for-byte under its own name; no re-encoding happens,
    # so renaming e.g. a PNG to ".jpg" would only mislabel it.
    output_image_path = os.path.join(output_dir, base_name)
    already_in_place = (
        os.path.exists(output_image_path)
        and os.path.samefile(image_path, output_image_path)
    )
    # shutil.copy raises SameFileError when source and destination coincide
    # (input already inside output/), so skip the copy in that case.
    if not already_in_place:
        shutil.copy(image_path, output_image_path)

    output_xml_path = os.path.join(annotations_dir, f"{file_name_no_ext}.xml")
    with open(output_xml_path, "w", encoding="utf-8") as f:
        f.write(xml_content)

    print(f"Results saved to:")
    print(f" Image: {output_image_path}")
    print(f" XML: {output_xml_path}")
def _collect_image_files(directory: str) -> List[str]:
    """Return the sorted paths of the image files directly inside *directory*.

    Extensions are matched case-insensitively, so each file is listed exactly
    once on any filesystem. Hidden files (leading dot) are skipped, as they
    were by the ``*.ext`` glob patterns this replaces.
    """
    extensions = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not name.startswith(".")
        and name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(directory, name))
    )


def _convert_image(image_path: str, args: argparse.Namespace) -> None:
    """Detect lines in one image and write its PAGE XML and image copy."""
    result, annotated_image, width, height = process_image(
        image_path,
        args.model,
        args.conf,
        args.iou
    )
    xml_content = create_page_xml(image_path, result, width, height)
    save_results(image_path, annotated_image, xml_content)


def main():
    """Command-line entry point: convert one image, or a directory with --batch.

    In batch mode a failure on one image is reported and the remaining images
    are still processed. In single-image mode a failure prints the error and
    traceback and exits with status 1.
    """
    import traceback

    parser = argparse.ArgumentParser(description="Convert YOLO segmentation to PAGE XML format")
    parser.add_argument("image_path", help="Path to the input image or directory of images")
    parser.add_argument("--model", default="YOLOv11-Medium", choices=MODEL_OPTIONS.keys(),
                        help="Model to use for detection")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold for detection")
    parser.add_argument("--iou", type=float, default=0.45,
                        help="IoU threshold for detection")
    parser.add_argument("--batch", action="store_true",
                        help="Process all images in the directory if image_path is a directory")
    args = parser.parse_args()

    if os.path.isdir(args.image_path):
        if not args.batch:
            # Without this check the directory would be handed to the image
            # reader and fail with a misleading "Cannot read image" error.
            print(f"Error: {args.image_path} is a directory; use --batch to process the images in it")
            sys.exit(1)

        image_files = _collect_image_files(args.image_path)
        if not image_files:
            print(f"No image files found in directory: {args.image_path}")
            sys.exit(1)
        print(f"Found {len(image_files)} images to process")

        for i, image_path in enumerate(image_files):
            print(f"Processing {i+1}/{len(image_files)}: {os.path.basename(image_path)}")
            try:
                _convert_image(image_path, args)
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f"Error processing {image_path}: {e}")
                traceback.print_exc()
    else:
        try:
            _convert_image(args.image_path, args)
        except Exception as e:
            print(f"Error: {e}")
            traceback.print_exc()
            sys.exit(1)


if __name__ == "__main__":
    main()