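"""Convert YOLO segmentation results into PAGE XML for import into eScriptorium.

For each input image, detected line masks are written as TextLine polygons
(with a synthetic two-point baseline) inside a TextRegion; the original image
is copied to output/ and the PAGE XML is written to annotations/.

Illustrative invocation (the script filename below is a placeholder; adjust it
to however this file is saved):

    python yolo_to_page_xml.py scans/ --batch --model YOLOv11-Medium --conf 0.25 --iou 0.45
"""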
from typing import Dict
import os
import sys
import glob
import argparse
import datetime
import shutil
import numpy as np
import cv2
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# XML generation imports
import xml.etree.ElementTree as ET
from xml.dom import minidom
# Define models
MODEL_OPTIONS = {
    "YOLOv11-Nano": "yolov11n-seg.pt",
    "YOLOv11-Small": "yolov11s-seg.pt",
    "YOLOv11-Medium": "yolov11m-seg.pt",
    "YOLOv11-Large": "yolov11l-seg.pt",
    "YOLOv11-XLarge": "yolov11x-seg.pt"
}
# Dictionary to store loaded models
models: Dict[str, YOLO] = {}
# Load specified model or default to Nano
def load_model(model_name: str = "YOLOv11-Nano") -> YOLO:
    if model_name not in models:
        model_file = MODEL_OPTIONS[model_name]
        model_path = hf_hub_download(
            repo_id="wjbmattingly/kraken-yiddish",
            filename=model_file
        )
        models[model_name] = YOLO(model_path)
    return models[model_name]
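# Note: hf_hub_download caches the downloaded weights in the local Hugging Face
# cache, so each model file is only fetched from the Hub once per machine.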
def process_image(
    image_path: str,
    model_name: str = "YOLOv11-Medium",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45
) -> tuple:
    """Process an image and return detection results and annotated image"""
    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")
    # Convert BGR to RGB for YOLO
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Get image dimensions
    height, width = image.shape[:2]
    # Get the selected model
    model = load_model(model_name)
    # Perform inference with YOLO
    results = model(
        image_rgb,
        conf=conf_threshold,
        iou=iou_threshold,
        verbose=False,
        device='cpu'
    )
    # Get the first result
    result = results[0]
    # Create annotated image for visualization
    annotated_image = result.plot(
        conf=True,
        line_width=None,
        font_size=None,
        boxes=True,
        masks=True,
        probs=True,
        labels=True
    )
    # Convert back to BGR for saving with OpenCV
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
    return result, annotated_image, width, height
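# Illustrative use (hypothetical path), e.g. to eyeball the segmentation before
# generating any XML:
#   result, annotated, w, h = process_image("scans/page_001.jpg", "YOLOv11-Nano")
#   cv2.imwrite("page_001_annotated.jpg", annotated)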
def create_page_xml(
    image_filename: str,
    result,
    width: int,
    height: int
) -> str:
    """Create PAGE XML structure from YOLO results"""
    # Create the root element
    root = ET.Element("PcGts", {
        "xmlns": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15",
        "xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance",
        "xsi:schemaLocation": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"
    })
    # Add metadata
    metadata = ET.SubElement(root, "Metadata")
    ET.SubElement(metadata, "Creator").text = "escriptorium"
    # Use a future date like in the example
    future_date = (datetime.datetime.now() + datetime.timedelta(days=365)).isoformat()
    ET.SubElement(metadata, "Created").text = future_date
    ET.SubElement(metadata, "LastChange").text = future_date
    # Add page element with original image filename
    page = ET.SubElement(root, "Page", {
        "imageFilename": os.path.basename(image_filename),
        "imageWidth": str(width),
        "imageHeight": str(height)
    })
    # Process each detected mask/contour as a separate TextLine
    has_valid_masks = False
    if hasattr(result, 'masks') and result.masks is not None:
        masks = result.masks.xy
        # Create main text region for the right side (assuming right-to-left Hebrew/Yiddish text)
        # Use a unique timestamp for the ID
        timestamp = int(datetime.datetime.now().timestamp())
        main_region_id = f"eSc_textblock_TextRegion_{timestamp}"
        # Get bounding box of all masks to determine the text region
        all_points_x = []
        all_points_y = []
        valid_masks = []
        # First pass: filter all masks and collect valid points
        for mask_points in masks:
            # Filter out NaN values from mask points
            valid_points = [(p[0], p[1]) for p in mask_points if not (np.isnan(p[0]) or np.isnan(p[1]))]
            if valid_points and len(valid_points) >= 3:  # Only proceed if we have enough valid points
                valid_masks.append(valid_points)
                all_points_x.extend([p[0] for p in valid_points])
                all_points_y.extend([p[1] for p in valid_points])
                has_valid_masks = True
        # Calculate the text region coordinates if we have valid points
        if has_valid_masks and all_points_x and all_points_y:
            min_x = max(0, int(min(all_points_x)))
            max_x = min(width, int(max(all_points_x)))
            min_y = max(0, int(min(all_points_y)))
            max_y = min(height, int(max(all_points_y)))
            # Create main text region with calculated bounds
            main_text_region = ET.SubElement(page, "TextRegion", {
                "id": main_region_id,
                "custom": "structure {type:text_zone;}"
            })
            # Add coordinates for the text region (use rectangle format)
            region_points = f"{min_x},{min_y} {max_x},{min_y} {max_x},{max_y} {min_x},{max_y}"
            ET.SubElement(main_text_region, "Coords", {"points": region_points})
            # Process each valid mask
            for i, valid_points in enumerate(valid_masks):
                # Create text line with auto-incrementing ID
                line_id = f"eSc_line_r2l{i+1}" if i > 0 else "eSc_line_line_1610719743362_3154"
                text_line = ET.SubElement(main_text_region, "TextLine", {
                    "id": line_id,
                    "custom": "structure {type:text_line;}"
                })
                # Format mask points for PAGE XML format
                # Convert to int to avoid scientific notation
                points_str = " ".join([f"{int(p[0])},{int(p[1])}" for p in valid_points])
                # Add coordinates to the text line
                ET.SubElement(text_line, "Coords", {
                    "points": points_str
                })
                # Calculate baseline points spanning the entire width of the polygon
                # Sort points by x-value to find the left and right boundaries
                points_by_x = sorted(valid_points, key=lambda p: p[0])
                leftmost_point = points_by_x[0]
                rightmost_point = points_by_x[-1]
                # Sort points by y-value (ascending) to find the bottom area of the line
                sorted_by_y = sorted(valid_points, key=lambda p: p[1])
                # Take points in the bottom third, but ensure we have at least one point
                bottom_third_index = max(0, int(len(sorted_by_y) * 0.67))
                bottom_points = sorted_by_y[bottom_third_index:]
                if not bottom_points:  # Fallback if no bottom points
                    bottom_points = sorted_by_y  # Use all points
                # Find the average y-value of bottom points for a straight baseline
                avg_y = sum(p[1] for p in bottom_points) / len(bottom_points)
                # Create baseline with two points spanning the full width
                left_x = leftmost_point[0]
                right_x = rightmost_point[0]
                # Create baseline string with exactly two points
                baseline_str = f"{int(left_x)},{int(avg_y)} {int(right_x)},{int(avg_y)}"
                # Add baseline
                ET.SubElement(text_line, "Baseline", {
                    "points": baseline_str
                })
                # Add empty text equivalent
                text_equiv = ET.SubElement(text_line, "TextEquiv")
                ET.SubElement(text_equiv, "Unicode")
            # Create a second text region for the left side
            # This is to mimic the structure in the example but with empty content
            left_region = ET.SubElement(page, "TextRegion", {
                "id": "eSc_textblock_r1",
                "custom": "structure {type:text_zone;}"
            })
            # Left region takes up the left side of the page; clamp the x
            # coordinate so it never goes negative when min_x is near the edge
            left_edge = max(0, min_x - 10)
            left_region_points = f"0,0 {left_edge},{min_y} {left_edge},{max_y} 0,{max_y}"
            ET.SubElement(left_region, "Coords", {"points": left_region_points})
    # If no valid masks were found, create a default text region covering the whole page
    if not has_valid_masks:
        print("Warning: No valid masks detected. Creating a default text region.")
        default_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_default_{int(datetime.datetime.now().timestamp())}",
            "custom": "structure {type:text_zone;}"
        })
        default_points = f"0,0 {width},0 {width},{height} 0,{height}"
        ET.SubElement(default_region, "Coords", {"points": default_points})
    # Convert to string with pretty formatting
    xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ")
    return xmlstr
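# Rough shape of the document produced above (illustrative, not literal output):
#
#   <PcGts>
#     <Metadata>
#       <Creator>escriptorium</Creator>
#       <Created>...</Created>
#       <LastChange>...</LastChange>
#     </Metadata>
#     <Page imageFilename="..." imageWidth="..." imageHeight="...">
#       <TextRegion id="eSc_textblock_TextRegion_..." custom="structure {type:text_zone;}">
#         <Coords points="..."/>
#         <TextLine id="..." custom="structure {type:text_line;}">
#           <Coords points="..."/>
#           <Baseline points="..."/>
#           <TextEquiv><Unicode/></TextEquiv>
#         </TextLine>
#       </TextRegion>
#       <TextRegion id="eSc_textblock_r1" custom="structure {type:text_zone;}">  <!-- empty left-side region -->
#         <Coords points="..."/>
#       </TextRegion>
#     </Page>
#   </PcGts>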
def save_results(image_path: str, annotated_image: np.ndarray, xml_content: str):
    """Save the original image to output/ and XML file to annotations/ directory

    Note: the annotated visualization passed in is currently not written to disk.
    """
    # Create output and annotations directories if they don't exist
    output_dir = "output"
    annotations_dir = "annotations"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(annotations_dir, exist_ok=True)
    # Get the base filename without extension
    base_name = os.path.basename(image_path)
    file_name_no_ext = os.path.splitext(base_name)[0]
    # Copy the original image to the output directory, keeping the original
    # extension so the copy matches the imageFilename written into the PAGE XML
    output_image_path = os.path.join(output_dir, base_name)
    # Use shutil.copy to directly copy the file instead of reading/writing
    shutil.copy(image_path, output_image_path)
    # Save the XML file to annotations directory
    output_xml_path = os.path.join(annotations_dir, f"{file_name_no_ext}.xml")
    with open(output_xml_path, "w", encoding="utf-8") as f:
        f.write(xml_content)
    print("Results saved to:")
    print(f"  Image: {output_image_path}")
    print(f"  XML: {output_xml_path}")
def main():
    parser = argparse.ArgumentParser(description="Convert YOLO segmentation to PAGE XML format")
    parser.add_argument("image_path", help="Path to the input image or directory of images")
    parser.add_argument("--model", default="YOLOv11-Medium", choices=MODEL_OPTIONS.keys(),
                        help="Model to use for detection")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold for detection")
    parser.add_argument("--iou", type=float, default=0.45,
                        help="IoU threshold for detection")
    parser.add_argument("--batch", action="store_true",
                        help="Process all images in the directory if image_path is a directory")
    args = parser.parse_args()
    # Check if the path is a directory and batch mode is enabled
    if os.path.isdir(args.image_path) and args.batch:
        # Get all image files in the directory
        image_files = []
        for extension in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension}")))
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension.upper()}")))
        # De-duplicate (case-insensitive filesystems match both patterns) and sort for a stable order
        image_files = sorted(set(image_files))
        if not image_files:
            print(f"No image files found in directory: {args.image_path}")
            sys.exit(1)
        print(f"Found {len(image_files)} images to process")
        # Process each image
        for i, image_path in enumerate(image_files):
            print(f"Processing {i+1}/{len(image_files)}: {os.path.basename(image_path)}")
            try:
                # Process the image
                result, annotated_image, width, height = process_image(
                    image_path,
                    args.model,
                    args.conf,
                    args.iou
                )
                # Create PAGE XML
                xml_content = create_page_xml(image_path, result, width, height)
                # Save results
                save_results(image_path, annotated_image, xml_content)
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                import traceback
                traceback.print_exc()
    else:
        # Process a single image
        try:
            # Process the image
            result, annotated_image, width, height = process_image(
                args.image_path,
                args.model,
                args.conf,
                args.iou
            )
            # Create PAGE XML
            xml_content = create_page_xml(args.image_path, result, width, height)
            # Save results
            save_results(args.image_path, annotated_image, xml_content)
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
            sys.exit(1)
if __name__ == "__main__":
    main()