File size: 13,741 Bytes
bd0e32e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
from typing import Dict, List
import os
import sys
import glob
import argparse
import datetime
import shutil
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# XML generation imports
import xml.etree.ElementTree as ET
from xml.dom import minidom

# Segmentation checkpoints selectable via --model / load_model(): maps the
# display name to the weights filename in the Hugging Face repo, e.g.
# "YOLOv11-Medium" -> "yolov11m-seg.pt". Insertion order (Nano..XLarge) is kept.
MODEL_OPTIONS = {
    f"YOLOv11-{size_label}": f"yolov11{size_code}-seg.pt"
    for size_label, size_code in (
        ("Nano", "n"),
        ("Small", "s"),
        ("Medium", "m"),
        ("Large", "l"),
        ("XLarge", "x"),
    )
}

# In-process cache of instantiated models, keyed by MODEL_OPTIONS display name,
# so repeated load_model() calls for the same name reuse one YOLO instance.
models: Dict[str, YOLO] = {}

# Load specified model or default to Nano
def load_model(model_name: str = "YOLOv11-Nano") -> YOLO:
    """Return the YOLO model for ``model_name``, fetching its weights on first use.

    The checkpoint file named by ``MODEL_OPTIONS[model_name]`` is obtained with
    ``hf_hub_download`` from the ``wjbmattingly/kraken-yiddish`` repo, wrapped in
    ``YOLO`` and memoised in ``models``.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    if model_name in models:
        return models[model_name]

    checkpoint_path = hf_hub_download(
        repo_id="wjbmattingly/kraken-yiddish",
        filename=MODEL_OPTIONS[model_name],
    )
    models[model_name] = YOLO(checkpoint_path)
    return models[model_name]

def process_image(
    image_path: str,
    model_name: str = "YOLOv11-Medium",
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
    device: str = "cpu",
) -> tuple:
    """Run YOLO segmentation on one image file.

    Args:
        image_path: path of the image to read with OpenCV.
        model_name: key of MODEL_OPTIONS selecting the checkpoint.
        conf_threshold: minimum detection confidence passed to the model.
        iou_threshold: NMS IoU threshold passed to the model.
        device: inference device passed to Ultralytics (default "cpu").

    Returns:
        ``(result, annotated_image, width, height)``: the first Ultralytics
        result, its visualization as a BGR numpy array, and the image size
        in pixels.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    # cv2.imread returns a BGR uint8 array (or None on failure).
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")

    height, width = image.shape[:2]

    model = load_model(model_name)

    # Ultralytics expects numpy inputs in BGR order (the cv2.imread layout) and
    # performs its own BGR->RGB conversion during preprocessing. Converting to
    # RGB here would make the model see red and blue swapped.
    results = model(
        image,
        conf=conf_threshold,
        iou=iou_threshold,
        verbose=False,
        device=device,
    )
    result = results[0]

    # Results.plot() draws on the original (BGR) input and returns a BGR array,
    # which is already the layout OpenCV writers expect; no conversion needed.
    annotated_image = result.plot(
        conf=True,
        line_width=None,
        font_size=None,
        boxes=True,
        masks=True,
        probs=True,
        labels=True,
    )

    return result, annotated_image, width, height

def _valid_mask_polygons(masks) -> List[List[tuple]]:
    """Drop NaN vertices from each mask polygon; keep polygons with >= 3 points.

    Args:
        masks: iterable of per-mask point sequences whose points are indexable
            as ``p[0]`` (x) and ``p[1]`` (y), e.g. ``result.masks.xy``.

    Returns:
        One list of ``(x, y)`` tuples per surviving mask, in input order.
    """
    polygons = []
    for mask_points in masks:
        points = [
            (p[0], p[1])
            for p in mask_points
            if not (np.isnan(p[0]) or np.isnan(p[1]))
        ]
        # Fewer than three vertices cannot enclose an area.
        if len(points) >= 3:
            polygons.append(points)
    return polygons


def _straight_baseline(points: List[tuple]) -> str:
    """Return a two-point horizontal PAGE baseline for a line polygon.

    The baseline spans the polygon's full x-extent at the mean y of its bottom
    ~third of vertices (image y grows downward, so the largest y is the bottom).
    ``points`` must be non-empty.
    """
    xs = [p[0] for p in points]
    by_y = sorted(points, key=lambda p: p[1])
    # int(len * 0.67) < len for non-empty input, so the slice is never empty;
    # the `or` fallback is purely defensive.
    bottom = by_y[int(len(by_y) * 0.67):] or by_y
    avg_y = sum(p[1] for p in bottom) / len(bottom)
    return f"{int(min(xs))},{int(avg_y)} {int(max(xs))},{int(avg_y)}"


def create_page_xml(
    image_filename: str,
    result,
    width: int,
    height: int
) -> str:
    """Build a PAGE XML (2019-07-15 schema) document from a YOLO segmentation result.

    Each valid mask polygon becomes one TextLine (Coords, Baseline, empty
    TextEquiv) inside a single TextRegion whose Coords are the bounding
    rectangle of all masks, clamped to the image. A second, empty TextRegion
    ("eSc_textblock_r1") is placed left of that rectangle. If no usable mask
    exists, one TextRegion covering the whole page is emitted instead.

    Args:
        image_filename: path of the source image; only its basename is written.
        result: YOLO result object; ``result.masks.xy`` is read when
            ``result.masks`` exists and is not None.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        The pretty-printed XML document as a string.
    """
    root = ET.Element("PcGts", {
        "xmlns": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15",
        "xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance",
        "xsi:schemaLocation": "http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"
    })

    metadata = ET.SubElement(root, "Metadata")
    ET.SubElement(metadata, "Creator").text = "escriptorium"

    # NOTE(review): Created/LastChange are deliberately set one year in the
    # future to mirror a reference export; confirm this is still wanted.
    future_date = (datetime.datetime.now() + datetime.timedelta(days=365)).isoformat()
    ET.SubElement(metadata, "Created").text = future_date
    ET.SubElement(metadata, "LastChange").text = future_date

    page = ET.SubElement(root, "Page", {
        "imageFilename": os.path.basename(image_filename),
        "imageWidth": str(width),
        "imageHeight": str(height)
    })

    masks = result.masks.xy if getattr(result, "masks", None) is not None else []
    valid_masks = _valid_mask_polygons(masks)

    if valid_masks:
        all_x = [p[0] for polygon in valid_masks for p in polygon]
        all_y = [p[1] for polygon in valid_masks for p in polygon]
        min_x = max(0, int(min(all_x)))
        max_x = min(width, int(max(all_x)))
        min_y = max(0, int(min(all_y)))
        max_y = min(height, int(max(all_y)))

        # Timestamp makes the region ID unique per run.
        main_region_id = f"eSc_textblock_TextRegion_{int(datetime.datetime.now().timestamp())}"
        main_text_region = ET.SubElement(page, "TextRegion", {
            "id": main_region_id,
            "custom": "structure {type:text_zone;}"
        })
        region_points = f"{min_x},{min_y} {max_x},{min_y} {max_x},{max_y} {min_x},{max_y}"
        ET.SubElement(main_text_region, "Coords", {"points": region_points})

        for i, points in enumerate(valid_masks):
            # NOTE(review): the first line reuses a fixed ID copied from a
            # reference export; later lines are numbered eSc_line_r2l2, r2l3, ...
            line_id = f"eSc_line_r2l{i+1}" if i > 0 else "eSc_line_line_1610719743362_3154"
            text_line = ET.SubElement(main_text_region, "TextLine", {
                "id": line_id,
                "custom": "structure {type:text_line;}"
            })
            # Integer coordinates avoid float/scientific notation in the XML.
            ET.SubElement(text_line, "Coords", {
                "points": " ".join(f"{int(x)},{int(y)}" for x, y in points)
            })
            ET.SubElement(text_line, "Baseline", {"points": _straight_baseline(points)})
            text_equiv = ET.SubElement(text_line, "TextEquiv")
            ET.SubElement(text_equiv, "Unicode")

        # Empty region left of the detected text, mimicking the reference layout.
        left_region = ET.SubElement(page, "TextRegion", {
            "id": "eSc_textblock_r1",
            "custom": "structure {type:text_zone;}"
        })
        # Clamp at 0: when the text starts within 10 px of the left edge,
        # min_x - 10 would be negative, which PAGE coordinates do not allow.
        left_edge = max(0, min_x - 10)
        left_region_points = f"0,0 {left_edge},{min_y} {left_edge},{max_y} 0,{max_y}"
        ET.SubElement(left_region, "Coords", {"points": left_region_points})
    else:
        print("Warning: No valid masks detected. Creating a default text region.")
        default_region = ET.SubElement(page, "TextRegion", {
            "id": f"eSc_textblock_default_{int(datetime.datetime.now().timestamp())}",
            "custom": "structure {type:text_zone;}"
        })
        default_points = f"0,0 {width},0 {width},{height} 0,{height}"
        ET.SubElement(default_region, "Coords", {"points": default_points})

    return minidom.parseString(ET.tostring(root)).toprettyxml(indent="  ")

def save_results(
    image_path: str,
    annotated_image: np.ndarray,
    xml_content: str,
    output_dir: str = "output",
    annotations_dir: str = "annotations",
):
    """Copy the source image into ``output_dir`` and write its PAGE XML to ``annotations_dir``.

    Args:
        image_path: path of the original image; it is copied byte-for-byte.
        annotated_image: currently unused (the visualization is not written);
            kept so existing call sites keep working.
        xml_content: PAGE XML text, written as ``<stem>.xml`` in UTF-8.
        output_dir: directory receiving the image copy (created if missing).
        annotations_dir: directory receiving the XML (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(annotations_dir, exist_ok=True)

    base_name = os.path.basename(image_path)
    file_name_no_ext = os.path.splitext(base_name)[0]

    # Keep the original file name and extension. The bytes are copied
    # unchanged, so renaming a PNG/TIFF to ".jpg" would mislabel its format.
    # The PAGE XML also records the input basename as imageFilename.
    output_image_path = os.path.join(output_dir, base_name)
    try:
        shutil.copy(image_path, output_image_path)
    except shutil.SameFileError:
        # The input already lives in output_dir; nothing to copy.
        pass

    output_xml_path = os.path.join(annotations_dir, f"{file_name_no_ext}.xml")
    with open(output_xml_path, "w", encoding="utf-8") as f:
        f.write(xml_content)

    print("Results saved to:")
    print(f"  Image: {output_image_path}")
    print(f"  XML: {output_xml_path}")

def _find_images(directory: str) -> List[str]:
    """Return the image files directly inside ``directory``, sorted, each listed once.

    Extensions are matched case-insensitively. (Globbing ``*.jpg`` and ``*.JPG``
    separately returns every file twice on case-insensitive filesystems.)
    """
    extensions = {'.jpg', '.jpeg', '.png', '.tif', '.tiff'}
    return sorted(
        path
        for path in glob.glob(os.path.join(directory, "*"))
        if os.path.splitext(path)[1].lower() in extensions and os.path.isfile(path)
    )


def _convert_image(image_path: str, model_name: str, conf: float, iou: float) -> None:
    """Run detection on one image, build its PAGE XML and save both outputs."""
    result, annotated_image, width, height = process_image(image_path, model_name, conf, iou)
    xml_content = create_page_xml(image_path, result, width, height)
    save_results(image_path, annotated_image, xml_content)


def main():
    """CLI entry point: convert one image, or with --batch every image in a directory."""
    import traceback

    parser = argparse.ArgumentParser(description="Convert YOLO segmentation to PAGE XML format")
    parser.add_argument("image_path", help="Path to the input image or directory of images")
    parser.add_argument("--model", default="YOLOv11-Medium", choices=list(MODEL_OPTIONS),
                        help="Model to use for detection")
    parser.add_argument("--conf", type=float, default=0.25,
                        help="Confidence threshold for detection")
    parser.add_argument("--iou", type=float, default=0.45,
                        help="IoU threshold for detection")
    parser.add_argument("--batch", action="store_true",
                        help="Process all images in the directory if image_path is a directory")

    args = parser.parse_args()

    if os.path.isdir(args.image_path) and args.batch:
        image_files = _find_images(args.image_path)

        if not image_files:
            print(f"No image files found in directory: {args.image_path}")
            sys.exit(1)

        print(f"Found {len(image_files)} images to process")

        for i, image_path in enumerate(image_files, start=1):
            print(f"Processing {i}/{len(image_files)}: {os.path.basename(image_path)}")
            try:
                _convert_image(image_path, args.model, args.conf, args.iou)
            except Exception as e:
                # Report and continue: one bad page should not abort the batch.
                print(f"Error processing {image_path}: {e}")
                traceback.print_exc()
    else:
        if os.path.isdir(args.image_path):
            # Without --batch a directory would otherwise fail inside
            # cv2.imread with a misleading "Cannot read image" error.
            print(f"{args.image_path} is a directory; use --batch to process the images in it")
            sys.exit(1)
        try:
            _convert_image(args.image_path, args.model, args.conf, args.iou)
        except Exception as e:
            print(f"Error: {e}")
            traceback.print_exc()
            sys.exit(1)

if __name__ == "__main__":
    main()