from typing import Dict, List
import os
import sys
import glob
import argparse
import datetime
import shutil
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# XML generation imports
import xml.etree.ElementTree as ET
from xml.dom import minidom

# Define models
MODEL_OPTIONS = {
    "YOLOv11-Nano": "yolov11n-seg.pt",
    "YOLOv11-Small": "yolov11s-seg.pt",
    "YOLOv11-Medium": "yolov11m-seg.pt",
    "YOLOv11-Large": "yolov11l-seg.pt",
    "YOLOv11-XLarge": "yolov11x-seg.pt"
}

# Dictionary to store loaded models
models: Dict[str, YOLO] = {}


# Load the specified model (default: Nano), caching it for reuse
def load_model(model_name: str = "YOLOv11-Nano") -> YOLO:
    if model_name not in models:
        model_file = MODEL_OPTIONS[model_name]
        model_path = hf_hub_download(
            repo_id="wjbmattingly/kraken-yiddish",
            filename=model_file
        )
        models[model_name] = YOLO(model_path)
    return models[model_name]


def process_image(
    image_path: str,
    model_name: str = "YOLOv11-Medium",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45
) -> tuple:
    """Process an image and return detection results and annotated image"""
    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")

    # Convert BGR to RGB for YOLO
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Get image dimensions
    height, width = image.shape[:2]

    # Get the selected model
    model = load_model(model_name)

    # Perform inference with YOLO
    results = model(
        image_rgb,
        conf=conf_threshold,
        iou=iou_threshold,
        verbose=False,
        device='cpu'
    )

    # Get the first result
    result = results[0]

    # Create annotated image for visualization
    annotated_image = result.plot(
        conf=True,
        line_width=None,
        font_size=None,
        boxes=True,
        masks=True,
        probs=True,
        labels=True
    )

    # Convert back to BGR for saving with OpenCV
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)

    return result, annotated_image, width, height


def create_page_xml(
    image_filename: str,
    result,
    width: int,
    height: int
) -> str:
    """Create PAGE XML structure from YOLO results"""
    # Create the root element
    root = ET.Element("PcGts", {
        "xmlns": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15",
        "xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance",
        "xsi:schemaLocation": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"
    })

    # Add metadata
    metadata = ET.SubElement(root, "Metadata")
    ET.SubElement(metadata, "Creator").text = "escriptorium"
    # Use a future date like in the example
    future_date = (datetime.datetime.now() + datetime.timedelta(days=365)).isoformat()
    ET.SubElement(metadata, "Created").text = future_date
    ET.SubElement(metadata, "LastChange").text = future_date

    # Add page element with original image filename
    page = ET.SubElement(root, "Page", {
        "imageFilename": os.path.basename(image_filename),
        "imageWidth": str(width),
        "imageHeight": str(height)
    })

    # Process each detected mask/contour as a separate TextLine inside one TextRegion
    has_valid_masks = False
    if hasattr(result, 'masks') and result.masks is not None:
        masks = result.masks.xy

        # Create main text region for the right side (assuming right-to-left Hebrew/Yiddish text)
        # Use a unique timestamp for the ID
        timestamp = int(datetime.datetime.now().timestamp())
        main_region_id = f"eSc_textblock_TextRegion_{timestamp}"

        # Get bounding box of all masks to determine the text region
        all_points_x = []
        all_points_y = []
        valid_masks = []

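        # Note: each entry of result.masks.xy is an (N, 2) float array of polygon
        # vertices for one detected instance, scaled to original-image pixel
        # coordinates; any stray NaN vertices are filtered out in the pass below.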
        # First pass: filter all masks and collect valid points
        for mask_points in masks:
            # Filter out NaN values from mask points
            valid_points = [(p[0], p[1]) for p in mask_points if not (np.isnan(p[0]) or np.isnan(p[1]))]

            if valid_points and len(valid_points) >= 3:  # Only proceed if we have enough valid points
                valid_masks.append(valid_points)
                all_points_x.extend([p[0] for p in valid_points])
                all_points_y.extend([p[1] for p in valid_points])
                has_valid_masks = True

        # Calculate the text region coordinates if we have valid points
        if has_valid_masks and all_points_x and all_points_y:
            min_x = max(0, int(min(all_points_x)))
            max_x = min(width, int(max(all_points_x)))
            min_y = max(0, int(min(all_points_y)))
            max_y = min(height, int(max(all_points_y)))

            # Create main text region with calculated bounds
            main_text_region = ET.SubElement(page, "TextRegion", {
                "id": main_region_id,
                "custom": "structure {type:text_zone;}"
            })

            # Add coordinates for the text region (use rectangle format)
            region_points = f"{min_x},{min_y} {max_x},{min_y} {max_x},{max_y} {min_x},{max_y}"
            ET.SubElement(main_text_region, "Coords", {"points": region_points})

            # Process each valid mask as a TextLine
            for i, valid_points in enumerate(valid_masks):
                # Create text line with auto-incrementing ID
                line_id = f"eSc_line_r2l{i+1}" if i > 0 else "eSc_line_line_1610719743362_3154"
                text_line = ET.SubElement(main_text_region, "TextLine", {
                    "id": line_id,
                    "custom": "structure {type:text_line;}"
                })

                # Format mask points for PAGE XML (convert to int to avoid scientific notation)
                points_str = " ".join([f"{int(p[0])},{int(p[1])}" for p in valid_points])

                # Add coordinates to the text line
                ET.SubElement(text_line, "Coords", {"points": points_str})

                # Calculate baseline points spanning the entire width of the polygon:
                # sort points by x-value to find the left and right boundaries
                points_by_x = sorted(valid_points, key=lambda p: p[0])
                leftmost_point = points_by_x[0]
                rightmost_point = points_by_x[-1]

                # Sort points by y-value (ascending) to find the bottom area of the line
                sorted_by_y = sorted(valid_points, key=lambda p: p[1])

                # Take points in the bottom third, but ensure we have at least one point
                bottom_third_index = max(0, int(len(sorted_by_y) * 0.67))
                bottom_points = sorted_by_y[bottom_third_index:]

                if not bottom_points:
                    # Fallback if no bottom points: use all points
                    bottom_points = sorted_by_y

                # Find the average y-value of bottom points for a straight baseline
                avg_y = sum(p[1] for p in bottom_points) / len(bottom_points)

                # Create baseline with exactly two points spanning the full width
                left_x = leftmost_point[0]
                right_x = rightmost_point[0]
                baseline_str = f"{int(left_x)},{int(avg_y)} {int(right_x)},{int(avg_y)}"

                # Add baseline
                ET.SubElement(text_line, "Baseline", {"points": baseline_str})

                # Add empty text equivalent
                text_equiv = ET.SubElement(text_line, "TextEquiv")
                ET.SubElement(text_equiv, "Unicode")

            # Create a second text region for the left side
            # This mimics the structure in the example but with empty content
            left_region = ET.SubElement(page, "TextRegion", {
                "id": "eSc_textblock_r1",
                "custom": "structure {type:text_zone;}"
            })

            # Left region takes up the left side of the page; clamp so the x-coordinate
            # cannot go negative when the main region starts near the left edge
            left_edge_x = max(0, min_x - 10)
            left_region_points = f"0,0 {left_edge_x},{min_y} {left_edge_x},{max_y} 0,{max_y}"
            ET.SubElement(left_region, "Coords", {"points": left_region_points})

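    # For reference, a successful run yields roughly this structure (illustrative only;
    # ids and points come from the detections; the fallback branch below instead emits
    # a single full-page TextRegion):
    #   <TextRegion id="eSc_textblock_TextRegion_..."><Coords points="..."/>
    #     <TextLine id="eSc_line_..."><Coords points="..."/><Baseline points="..."/>
    #       <TextEquiv><Unicode/></TextEquiv>
    #     </TextLine>
    #   </TextRegion>
    #   <TextRegion id="eSc_textblock_r1"><Coords points="..."/></TextRegion>
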
    # If no valid masks were found, create a default text region covering the whole page
    if not has_valid_masks:
        print("Warning: No valid masks detected. Creating a default text region.")
        default_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_default_{int(datetime.datetime.now().timestamp())}",
            "custom": "structure {type:text_zone;}"
        })
        default_points = f"0,0 {width},0 {width},{height} 0,{height}"
        ET.SubElement(default_region, "Coords", {"points": default_points})

    # Convert to string with pretty formatting
    xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(indent="  ")

    return xmlstr


def save_results(image_path: str, annotated_image: np.ndarray, xml_content: str):
    """Save the original image to the output/ directory and the XML file to annotations/"""
    # Create output and annotations directories if they don't exist
    output_dir = "output"
    annotations_dir = "annotations"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(annotations_dir, exist_ok=True)

    # Get the base filename without extension
    base_name = os.path.basename(image_path)
    file_name_no_ext = os.path.splitext(base_name)[0]

    # Copy the original image to the output directory, keeping its original name so it
    # matches the imageFilename recorded in the PAGE XML
    output_image_path = os.path.join(output_dir, base_name)
    # Use shutil.copy to directly copy the file instead of reading/writing
    shutil.copy(image_path, output_image_path)

    # Save the XML file to the annotations directory
    output_xml_path = os.path.join(annotations_dir, f"{file_name_no_ext}.xml")
    with open(output_xml_path, "w", encoding="utf-8") as f:
        f.write(xml_content)

    print("Results saved to:")
    print(f"  Image: {output_image_path}")
    print(f"  XML: {output_xml_path}")


def main():
    parser = argparse.ArgumentParser(description="Convert YOLO segmentation to PAGE XML format")
    parser.add_argument("image_path", help="Path to the input image or directory of images")
    parser.add_argument("--model", default="YOLOv11-Medium", choices=MODEL_OPTIONS.keys(),
                        help="Model to use for detection")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold for detection")
    parser.add_argument("--iou", type=float, default=0.45, help="IoU threshold for detection")
    parser.add_argument("--batch", action="store_true",
                        help="Process all images in the directory if image_path is a directory")
    args = parser.parse_args()

    # Check if the path is a directory and batch mode is enabled
    if os.path.isdir(args.image_path) and args.batch:
        # Get all image files in the directory
        image_files = []
        for extension in ['.jpg', '.jpeg', '.png', '.tif', '.tiff']:
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension}")))
            image_files.extend(glob.glob(os.path.join(args.image_path, f"*{extension.upper()}")))

        if not image_files:
            print(f"No image files found in directory: {args.image_path}")
            sys.exit(1)

        print(f"Found {len(image_files)} images to process")

        # Process each image; keep going if one fails
        for i, image_path in enumerate(image_files):
            print(f"Processing {i+1}/{len(image_files)}: {os.path.basename(image_path)}")
            try:
                # Process the image
                result, annotated_image, width, height = process_image(
                    image_path, args.model, args.conf, args.iou
                )

                # Create PAGE XML
                xml_content = create_page_xml(image_path, result, width, height)

                # Save results
                save_results(image_path, annotated_image, xml_content)
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                import traceback
                traceback.print_exc()
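    # Single-image mode: same pipeline as the batch branch above, but exit with a
    # non-zero status if processing fails so callers can detect the error.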
    else:
        # Process a single image
        try:
            # Process the image
            result, annotated_image, width, height = process_image(
                args.image_path, args.model, args.conf, args.iou
            )

            # Create PAGE XML
            xml_content = create_page_xml(args.image_path, result, width, height)

            # Save results
            save_results(args.image_path, annotated_image, xml_content)
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
            sys.exit(1)


if __name__ == "__main__":
    main()
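
# Example invocations (a sketch; "yolo_to_pagexml.py" is a placeholder for this script's
# filename, and the image paths are illustrative):
#   python yolo_to_pagexml.py page_001.jpg --model YOLOv11-Small --conf 0.3
#   python yolo_to_pagexml.py scans/ --batch
#
# The same pipeline can also be driven programmatically:
#   result, annotated, w, h = process_image("page_001.jpg", "YOLOv11-Medium")
#   xml_content = create_page_xml("page_001.jpg", result, w, h)
#   save_results("page_001.jpg", annotated, xml_content)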