htrflow_mcp / visualizer.py
Gabriel's picture
Update visualizer.py
af0cb92 verified
import xml.etree.ElementTree as ET
from PIL import Image, ImageDraw, ImageFont
import tempfile
import os
from typing import Optional, Tuple, List
import math
def htrflow_visualizer(image_path: str, htr_document_path: str, server_name: str = "https://gabriel-htrflow-mcp.hf.space") -> Optional[str]:
"""
Visualize HTR results by overlaying text regions and polygons on the original image.
Args:
image_path (str): Path to the original document image file
htr_document_path (str): Path to the HTR XML file (ALTO or PAGE format)
Returns:
str: File path to the generated visualization imagegenerated visualization image for direct download via gr.File (server_name/gradio_api/file=/tmp/gradio/{temp_folder}/{file_name})
e.g : https://gabriel-htrflow-mcp.hf.space/gradio_api/file=/tmp/gradio/34d5c1a8b7d5445469c4f7c638c490e0e3046b3008a0182f89c688b1b42d139b/htr_visualization.png
"""
try:
if not image_path or not htr_document_path:
return None
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
tree = ET.parse(htr_document_path)
root = tree.getroot()
if "alto" in root.tag.lower() or root.find(".//TextBlock") is not None:
_visualize_alto_xml(draw, root, image.size)
elif "PAGE" in root.tag or "PcGts" in root.tag:
_visualize_page_xml(draw, root, image.size)
else:
if root.find(".//*[@points]") is not None:
_visualize_page_xml(draw, root, image.size)
else:
_visualize_alto_xml(draw, root, image.size)
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, "htr_visualization.png")
image.save(output_path)
return output_path
except Exception:
return None
def _parse_points(points_str: str) -> List[Tuple[int, int]]:
if not points_str:
return []
points = []
for coord in points_str.strip().split():
if "," in coord:
try:
x, y = coord.split(",")
points.append((int(float(x)), int(float(y))))
except ValueError:
continue
return points
def _calculate_polygon_area(points: List[Tuple[int, int]]) -> float:
if len(points) < 3:
return 0
area = 0
n = len(points)
for i in range(n):
j = (i + 1) % n
area += points[i][0] * points[j][1]
area -= points[j][0] * points[i][1]
return abs(area) / 2
def _get_dynamic_font_size(
polygons: List[List[Tuple[int, int]]], image_size: Tuple[int, int]
) -> int:
if not polygons:
return 16
total_area = 0
valid_count = 0
for points in polygons:
area = _calculate_polygon_area(points)
if area > 0:
total_area += area
valid_count += 1
if valid_count == 0:
return 16
avg_area = total_area / valid_count
font_size = int(math.sqrt(avg_area) * 0.2)
return max(12, min(72, font_size))
def _get_font(size: int) -> Optional[ImageFont.FreeTypeFont]:
try:
font_paths = [
"/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
"/System/Library/Fonts/Helvetica.ttc",
"C:\\Windows\\Fonts\\arial.ttf",
]
for font_path in font_paths:
if os.path.exists(font_path):
return ImageFont.truetype(font_path, size)
return ImageFont.load_default()
except:
return ImageFont.load_default()
def _get_namespace(root: ET.Element) -> Optional[str]:
if "}" in root.tag:
return root.tag.split("}")[0] + "}"
return None
def _visualize_page_xml(
draw: ImageDraw.Draw, root: ET.Element, image_size: Tuple[int, int]
):
text_lines = []
for elem in root.iter():
if elem.tag.endswith("TextLine"):
text_lines.append(elem)
line_data = []
all_polygons = []
for text_line in text_lines:
coords_elem = None
for child in text_line:
if child.tag.endswith("Coords"):
coords_elem = child
break
if coords_elem is not None:
points_str = coords_elem.get("points", "")
points = _parse_points(points_str)
if len(points) >= 3:
text_content = ""
confidence = None
for te in text_line.iter():
if te.tag.endswith("Unicode") and te.text:
text_content = te.text.strip()
break
for te in text_line.iter():
if te.tag.endswith("TextEquiv"):
conf_str = te.get("conf")
if conf_str:
try:
confidence = float(conf_str)
except:
pass
break
display_text = text_content
if confidence is not None:
display_text = f"{text_content} ({confidence:.3f})"
line_data.append((points, display_text))
all_polygons.append(points)
font_size = _get_dynamic_font_size(all_polygons, image_size)
font = _get_font(font_size)
for i, (points, text) in enumerate(line_data):
color = "red" if i % 2 == 0 else "blue"
draw.polygon(points, outline=color, width=2)
if text:
centroid_x = sum(p[0] for p in points) // len(points)
centroid_y = sum(p[1] for p in points) // len(points)
if font != ImageFont.load_default():
bbox = draw.textbbox(
(centroid_x, centroid_y), text, font=font, anchor="mm"
)
bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
draw.rectangle(bbox, fill=(255, 255, 255, 200), outline="black")
draw.text(
(centroid_x, centroid_y), text, fill="black", font=font, anchor="mm"
)
else:
draw.text((centroid_x, centroid_y), text, fill="black")
def _visualize_alto_xml(
draw: ImageDraw.Draw, root: ET.Element, image_size: Tuple[int, int]
):
namespace = _get_namespace(root)
text_lines = []
for elem in root.iter():
if elem.tag.endswith("TextLine"):
text_lines.append(elem)
line_data = []
all_polygons = []
for text_line in text_lines:
points = []
for shape in text_line.iter():
if shape.tag.endswith("Shape"):
for polygon in shape.iter():
if polygon.tag.endswith("Polygon"):
points_str = polygon.get("POINTS", "")
points = _parse_points(points_str)
break
break
if len(points) >= 3:
text_content = ""
confidence = None
for string_elem in text_line.iter():
if string_elem.tag.endswith("String"):
text_content = string_elem.get("CONTENT", "")
wc_str = string_elem.get("WC")
if wc_str:
try:
confidence = float(wc_str)
except:
pass
break
display_text = text_content
if confidence is not None:
display_text = f"{text_content} ({confidence:.3f})"
line_data.append((points, display_text))
all_polygons.append(points)
font_size = _get_dynamic_font_size(all_polygons, image_size)
font = _get_font(font_size)
for i, (points, text) in enumerate(line_data):
color = "red" if i % 2 == 0 else "blue"
draw.polygon(points, outline=color, width=2)
if text:
centroid_x = sum(p[0] for p in points) // len(points)
centroid_y = sum(p[1] for p in points) // len(points)
if font != ImageFont.load_default():
bbox = draw.textbbox(
(centroid_x, centroid_y), text, font=font, anchor="mm"
)
bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
draw.rectangle(bbox, fill=(255, 255, 255, 200), outline="black")
draw.text(
(centroid_x, centroid_y), text, fill="black", font=font, anchor="mm"
)
else:
draw.text((centroid_x, centroid_y), text, fill="black")