Gabriel committed on
Commit
37ebc76
·
verified ·
1 Parent(s): eba24ec

Create visualizer.py

Browse files
Files changed (1) hide show
  1. visualizer.py +257 -0
visualizer.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xml.etree.ElementTree as ET
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import tempfile
4
+ import os
5
+ from typing import Optional, Tuple, List
6
+ import math
7
+
8
+
9
def htrflow_visualizer(image_path: str, htr_document_path: str) -> Optional[str]:
    """
    Visualize HTR results by overlaying text regions and polygons on the original image.

    The XML flavour is chosen from the root element: a tag containing "alto"
    (case-insensitive) or an un-namespaced ``TextBlock`` descendant selects the
    ALTO renderer; a tag containing "PAGE" or "PcGts" selects the PAGE
    renderer. Otherwise a document with any ``points`` attribute is treated as
    PAGE, and ALTO is the final fallback.

    Args:
        image_path (str): Path to the original document image file
        htr_document_path (str): Path to the HTR XML file (ALTO or PAGE format)

    Returns:
        Optional[str]: Path to the generated visualization image, or None if failed.
            The PNG is written into a freshly created temporary directory that
            this function does not remove.
    """
    import logging

    if not image_path or not htr_document_path:
        return None

    try:
        # Load the pixels inside a context manager so the underlying file
        # handle is released even when a later step fails.
        with Image.open(image_path) as source_image:
            if source_image.mode in ("RGB", "RGBA"):
                image = source_image.copy()
            else:
                # The renderers draw with named colours and an RGBA tuple;
                # single-band / palette images reject those inks, so draw on
                # an RGB copy instead.
                image = source_image.convert("RGB")

        draw = ImageDraw.Draw(image)
        root = ET.parse(htr_document_path).getroot()

        root_tag = root.tag
        if "alto" in root_tag.lower() or root.find(".//TextBlock") is not None:
            render = _visualize_alto_xml
        elif "PAGE" in root_tag or "PcGts" in root_tag:
            render = _visualize_page_xml
        elif root.find(".//*[@points]") is not None:
            render = _visualize_page_xml
        else:
            render = _visualize_alto_xml
        render(draw, root, image.size)

        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, "htr_visualization.png")
        image.save(output_path)
        return output_path

    except Exception:
        # Best-effort API: callers rely on None to signal failure, but the
        # reason is logged instead of being discarded.
        logging.getLogger(__name__).exception(
            "Failed to visualize %r with %r", image_path, htr_document_path
        )
        return None
48
+
49
+
50
def _parse_points(points_str: str) -> List[Tuple[int, int]]:
    """
    Parse a whitespace-separated list of ``x,y`` pairs into integer points.

    Each coordinate is read as a float and truncated to int. Tokens that do
    not contain a comma, do not split into exactly two numbers, or hold a
    value that cannot become an int (NaN, infinity) are skipped.

    Args:
        points_str (str): Text such as ``"10,20 30.5,40"``; may be empty.

    Returns:
        List[Tuple[int, int]]: The parsed points in input order.
    """
    if not points_str:
        return []

    points = []
    for coord in points_str.split():
        if "," not in coord:
            continue
        try:
            x, y = coord.split(",")
            points.append((int(float(x)), int(float(y))))
        except (ValueError, OverflowError):
            # ValueError: wrong arity, non-numeric text or NaN.
            # OverflowError: int() of an infinite float.
            continue
    return points
63
+
64
+
65
def _calculate_polygon_area(points: List[Tuple[int, int]]) -> float:
    """
    Return the unsigned area of a polygon using the shoelace formula.

    Args:
        points (List[Tuple[int, int]]): Vertices in order; the polygon is
            closed implicitly from the last vertex back to the first.

    Returns:
        float: The area, or 0.0 when fewer than three vertices are given.
    """
    if len(points) < 3:
        return 0.0

    # Pair every vertex with its successor, wrapping around to the first.
    successors = points[1:] + points[:1]
    twice_signed_area = sum(
        p[0] * q[1] - q[0] * p[1] for p, q in zip(points, successors)
    )
    return abs(twice_signed_area) / 2
76
+
77
+
78
def _get_dynamic_font_size(
    polygons: List[List[Tuple[int, int]]], image_size: Tuple[int, int]
) -> int:
    """
    Derive a label font size from the average area of the given polygons.

    The size is 20% of the square root of the mean area of all polygons with
    a positive area, clamped to the range 12..72. ``image_size`` is accepted
    but not used by the computation.

    Returns:
        int: The font size; 16 when there are no polygons or none has a
            positive area.
    """
    fallback_size = 16
    if not polygons:
        return fallback_size

    positive_areas = [
        area for area in map(_calculate_polygon_area, polygons) if area > 0
    ]
    if not positive_areas:
        return fallback_size

    mean_area = sum(positive_areas) / len(positive_areas)
    scaled = int(math.sqrt(mean_area) * 0.2)

    if scaled < 12:
        return 12
    if scaled > 72:
        return 72
    return scaled
100
+
101
+
102
def _get_font(size: int) -> Optional[ImageFont.FreeTypeFont]:
    """
    Load a TrueType font at the requested size, trying known system locations.

    Candidate paths for Linux, macOS and Windows are tried in order. A
    candidate that exists but cannot be loaded is skipped so the remaining
    candidates still get a chance.

    Args:
        size (int): Requested font size passed to ``ImageFont.truetype``.

    Returns:
        The first font that loads, otherwise ``ImageFont.load_default()``.
    """
    font_paths = (
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/System/Library/Fonts/Helvetica.ttc",
        "C:\\Windows\\Fonts\\arial.ttf",
    )

    for font_path in font_paths:
        if not os.path.exists(font_path):
            continue
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError, ValueError):
            # Unreadable/corrupt file, FreeType support missing, or a size
            # the loader rejects: try the next candidate rather than
            # abandoning the search.
            continue

    return ImageFont.load_default()
118
+
119
+
120
def _get_namespace(root: ET.Element) -> Optional[str]:
    """
    Return the ``{uri}`` prefix of the element's tag, braces included.

    Returns:
        Optional[str]: Everything up to and including the first ``}`` of
            ``root.tag``, or None when the tag contains no ``}``.
    """
    before_brace, brace, _ = root.tag.partition("}")
    if not brace:
        return None
    return before_brace + "}"
124
+
125
+
126
def _visualize_page_xml(
    draw: ImageDraw.Draw, root: ET.Element, image_size: Tuple[int, int]
):
    """
    Draw every PAGE XML ``TextLine`` polygon and its transcription onto ``draw``.

    For each ``TextLine`` with a direct ``Coords`` child holding at least three
    parseable points, the polygon outline is drawn (alternating red and blue)
    and the line text, followed by the confidence when present, is written at
    the average of the polygon's vertices.

    Text and confidence are taken from the ``TextEquiv`` that is a direct child
    of the line when there is one, so word-level results nested inside the
    line do not shadow the line transcription. Without a usable line-level
    value, the first matching descendant is used instead.

    Args:
        draw: Pillow drawing context for the target image.
        root (ET.Element): Root of the parsed PAGE XML document.
        image_size (Tuple[int, int]): Image size, forwarded to the font-size
            helper.
    """
    line_data = []
    all_polygons = []

    for text_line in root.iter():
        if not text_line.tag.endswith("TextLine"):
            continue

        coords_elem = next(
            (child for child in text_line if child.tag.endswith("Coords")), None
        )
        if coords_elem is None:
            continue

        points = _parse_points(coords_elem.get("points", ""))
        if len(points) < 3:
            continue

        # Line-level TextEquiv, if any (a direct child of the TextLine).
        line_equiv = next(
            (child for child in text_line if child.tag.endswith("TextEquiv")), None
        )

        text_content = ""
        if line_equiv is not None:
            for node in line_equiv.iter():
                if node.tag.endswith("Unicode") and node.text:
                    text_content = node.text.strip()
                    break
        if not text_content:
            # No usable line-level text: fall back to the first Unicode
            # element anywhere below the line.
            for node in text_line.iter():
                if node.tag.endswith("Unicode") and node.text:
                    text_content = node.text.strip()
                    break

        conf_source = line_equiv
        if conf_source is None:
            conf_source = next(
                (
                    node
                    for node in text_line.iter()
                    if node.tag.endswith("TextEquiv")
                ),
                None,
            )

        confidence = None
        if conf_source is not None:
            conf_str = conf_source.get("conf")
            if conf_str:
                try:
                    confidence = float(conf_str)
                except ValueError:
                    # Malformed confidence: show the text without it.
                    confidence = None

        display_text = text_content
        if confidence is not None:
            display_text = f"{text_content} ({confidence:.3f})"

        line_data.append((points, display_text))
        all_polygons.append(points)

    font_size = _get_dynamic_font_size(all_polygons, image_size)
    font = _get_font(font_size)
    # Anchored text and textbbox need a TrueType font. Comparing against a
    # fresh ImageFont.load_default() object can never detect the fallback
    # font, so test the type instead.
    has_truetype_font = isinstance(font, ImageFont.FreeTypeFont)

    for i, (points, text) in enumerate(line_data):
        color = "red" if i % 2 == 0 else "blue"
        draw.polygon(points, outline=color, width=2)

        if not text:
            continue

        centroid_x = sum(p[0] for p in points) // len(points)
        centroid_y = sum(p[1] for p in points) // len(points)

        if has_truetype_font:
            bbox = draw.textbbox(
                (centroid_x, centroid_y), text, font=font, anchor="mm"
            )
            # Pad the label background by two pixels on every side.
            bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
            draw.rectangle(bbox, fill=(255, 255, 255, 200), outline="black")
            draw.text(
                (centroid_x, centroid_y), text, fill="black", font=font, anchor="mm"
            )
        else:
            draw.text((centroid_x, centroid_y), text, fill="black")
192
+
193
+
194
def _visualize_alto_xml(
    draw: ImageDraw.Draw, root: ET.Element, image_size: Tuple[int, int]
):
    """
    Draw every ALTO ``TextLine`` polygon and its transcription onto ``draw``.

    A line is rendered when the first ``Shape`` below it contains a ``Polygon``
    whose ``POINTS`` attribute yields at least three parseable points. The
    outline alternates between red and blue, and the label is written at the
    average of the polygon's vertices.

    The label joins the ``CONTENT`` of every ``String`` in the line with
    spaces, so multi-word lines are shown in full. The confidence shown is the
    mean of the parseable ``WC`` values. A line with a single ``String`` shows
    exactly that string's content and ``WC``.

    Args:
        draw: Pillow drawing context for the target image.
        root (ET.Element): Root of the parsed ALTO document.
        image_size (Tuple[int, int]): Image size, forwarded to the font-size
            helper.
    """
    line_data = []
    all_polygons = []

    for text_line in root.iter():
        if not text_line.tag.endswith("TextLine"):
            continue

        # Only the first Shape below the line is considered.
        shape = next(
            (node for node in text_line.iter() if node.tag.endswith("Shape")), None
        )
        polygon = None
        if shape is not None:
            polygon = next(
                (node for node in shape.iter() if node.tag.endswith("Polygon")), None
            )
        points = _parse_points(polygon.get("POINTS", "")) if polygon is not None else []
        if len(points) < 3:
            continue

        words = []
        confidences = []
        for string_elem in text_line.iter():
            if not string_elem.tag.endswith("String"):
                continue
            content = string_elem.get("CONTENT", "")
            if content:
                words.append(content)
            wc_str = string_elem.get("WC")
            if wc_str:
                try:
                    confidences.append(float(wc_str))
                except ValueError:
                    # Malformed WC: leave this word out of the average.
                    pass

        text_content = " ".join(words)
        confidence = sum(confidences) / len(confidences) if confidences else None

        display_text = text_content
        if confidence is not None:
            display_text = f"{text_content} ({confidence:.3f})"

        line_data.append((points, display_text))
        all_polygons.append(points)

    font_size = _get_dynamic_font_size(all_polygons, image_size)
    font = _get_font(font_size)
    # Anchored text and textbbox need a TrueType font. Comparing against a
    # fresh ImageFont.load_default() object can never detect the fallback
    # font, so test the type instead.
    has_truetype_font = isinstance(font, ImageFont.FreeTypeFont)

    for i, (points, text) in enumerate(line_data):
        color = "red" if i % 2 == 0 else "blue"
        draw.polygon(points, outline=color, width=2)

        if not text:
            continue

        centroid_x = sum(p[0] for p in points) // len(points)
        centroid_y = sum(p[1] for p in points) // len(points)

        if has_truetype_font:
            bbox = draw.textbbox(
                (centroid_x, centroid_y), text, font=font, anchor="mm"
            )
            # Pad the label background by two pixels on every side.
            bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
            draw.rectangle(bbox, fill=(255, 255, 255, 200), outline="black")
            draw.text(
                (centroid_x, centroid_y), text, fill="black", font=font, anchor="mm"
            )
        else:
            draw.text((centroid_x, centroid_y), text, fill="black")