import gradio as gr
import yaml
import json
import base64
import tempfile
import os
from typing import Dict, List, Optional, Literal
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
import io
import spaces
from htrflow.volume.volume import Collection
from htrflow.pipeline.pipeline import Pipeline
def _segmentation_step(model_id: str, batch_size: int) -> Dict:
    """Build a YOLO segmentation pipeline step for the given model weights."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _recognition_step(model_id: str) -> Dict:
    """Build a TrOCR text-recognition pipeline step for the given model weights."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model_id},
            "generation_settings": {"batch_size": 16},
        },
    }


# Model identifiers shared by several templates.
_REGION_MODEL = "Riksarkivet/yolov9-regions-1"
_LINE_MODEL = "Riksarkivet/yolov9-lines-within-regions-1"
_TROCR_ENGLISH = "microsoft/trocr-base-handwritten"
_TROCR_SWEDISH = "Riksarkivet/trocr-base-handwritten-hist-swe-2"

# Processing templates keyed by document type:
# "letter"  = single page, line segmentation only, simple top-to-bottom order;
# "spread"  = two-page opening, region segmentation first, marginalia-aware order.
PIPELINE_CONFIGS = {
    "letter_english": {
        "steps": [
            _segmentation_step(_LINE_MODEL, 8),
            _recognition_step(_TROCR_ENGLISH),
            {"step": "OrderLines"},
        ]
    },
    "letter_swedish": {
        "steps": [
            _segmentation_step(_LINE_MODEL, 8),
            _recognition_step(_TROCR_SWEDISH),
            {"step": "OrderLines"},
        ]
    },
    "spread_english": {
        "steps": [
            _segmentation_step(_REGION_MODEL, 4),
            _segmentation_step(_LINE_MODEL, 8),
            _recognition_step(_TROCR_ENGLISH),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    },
    "spread_swedish": {
        "steps": [
            _segmentation_step(_REGION_MODEL, 4),
            _segmentation_step(_LINE_MODEL, 8),
            _recognition_step(_TROCR_SWEDISH),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    },
}
@spaces.GPU
def process_htr(
    image: Image.Image,
    document_type: Literal[
        "letter_english", "letter_swedish", "spread_english", "spread_swedish"
    ] = "spread_swedish",
    confidence_threshold: float = 0.8,
    custom_settings: Optional[str] = None,
) -> Dict:
    """
    Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Supports templates for different document types (letters vs spreads) and
    languages (English vs Swedish). Uses HTRflow's modular pipeline system with
    configurable segmentation and text recognition models.

    Args:
        image (Image.Image): PIL Image object to process
        document_type (str): Type of document processing template to use
        confidence_threshold (float): Minimum confidence threshold for text recognition
        custom_settings (str, optional): JSON string with custom pipeline settings;
            when provided it completely replaces the template for document_type

    Returns:
        dict: Processing results including extracted text, metadata, and processing state
    """
    try:
        if image is None:
            return {"success": False, "error": "No image provided", "results": None}
        # Resolve the pipeline config up front so bad input fails before any file I/O.
        if custom_settings:
            try:
                config = json.loads(custom_settings)
            except json.JSONDecodeError:
                return {
                    "success": False,
                    "error": "Invalid JSON in custom_settings parameter",
                    "results": None,
                }
        elif document_type in PIPELINE_CONFIGS:
            config = PIPELINE_CONFIGS[document_type]
        else:
            # Previously an unknown type surfaced as a KeyError wrapped in the
            # generic "HTR processing failed" message; fail explicitly instead.
            return {
                "success": False,
                "error": f"Unknown document_type: {document_type!r}",
                "results": None,
            }
        # HTRflow consumes file paths, so persist the image to a temporary PNG.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            image.save(temp_file.name, "PNG")
            temp_image_path = temp_file.name
        try:
            collection = Collection([temp_image_path])
            pipeline = Pipeline.from_config(config)
            processed_collection = pipeline.run(collection)
            results = extract_processing_results(
                processed_collection, confidence_threshold
            )
            # Embed the source image (base64 PNG) in the returned state so the
            # visualization tool can re-render without a second upload.
            img_buffer = io.BytesIO()
            image.save(img_buffer, format="PNG")
            image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
            processing_state = {
                "collection": serialize_collection(processed_collection),
                "config": config,
                "image_base64": image_base64,
                "image_size": image.size,
                "document_type": document_type,
                "confidence_threshold": confidence_threshold,
                "timestamp": datetime.now().isoformat(),
            }
            return {
                "success": True,
                "results": results,
                "processing_state": json.dumps(processing_state),
                "metadata": {
                    "total_lines": len(results.get("text_lines", [])),
                    "average_confidence": calculate_average_confidence(results),
                    "document_type": document_type,
                    "image_dimensions": image.size,
                },
            }
        finally:
            # Always remove the temporary PNG, even when the pipeline raises.
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)
    except Exception as e:
        # Broad catch is deliberate: this is an MCP tool boundary and must
        # return a structured error payload rather than raise.
        return {
            "success": False,
            "error": f"HTR processing failed: {str(e)}",
            "results": None,
        }
def visualize_results(
    processing_state: str,
    visualization_type: Literal[
        "overlay", "confidence_heatmap", "text_regions"
    ] = "overlay",
    show_confidence: bool = True,
    highlight_low_confidence: bool = True,
    image: Optional[Image.Image] = None,
) -> Dict:
    """
    Generate interactive visualizations of HTR processing results.

    Creates visual representations of text recognition results including bounding box
    overlays, confidence heatmaps, and region segmentation displays. Supports multiple
    visualization modes for different analysis needs.

    Args:
        processing_state (str): JSON string containing HTR processing results and metadata
        visualization_type (str): Type of visualization to generate
        show_confidence (bool): Whether to display confidence scores on visualization
        highlight_low_confidence (bool): Whether to highlight low-confidence regions
        image (Image.Image, optional): PIL Image object to use instead of state image

    Returns:
        dict: Visualization data including base64-encoded images and metadata
    """
    try:
        state = json.loads(processing_state)
        collection = deserialize_collection(state["collection"])
        confidence_threshold = state["confidence_threshold"]
        # Prefer a directly supplied image; otherwise rebuild it from the
        # base64 PNG stored in the processing state.
        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))
        if visualization_type == "overlay":
            viz_image = create_text_overlay_visualization(
                original_image, collection, show_confidence, highlight_low_confidence
            )
        elif visualization_type == "confidence_heatmap":
            viz_image = create_confidence_heatmap(
                original_image, collection, confidence_threshold
            )
        elif visualization_type == "text_regions":
            viz_image = create_region_visualization(original_image, collection)
        else:
            # Previously fell through to an UnboundLocalError on viz_image,
            # masked by the broad except below; fail with a clear message.
            return {
                "success": False,
                "error": f"Unknown visualization_type: {visualization_type!r}",
                "visualization": None,
            }
        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
        viz_metadata = generate_visualization_metadata(collection, visualization_type)
        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": viz_metadata,
            "interactive_elements": extract_interactive_elements(collection),
        }
    except Exception as e:
        # MCP tool boundary: return a structured error instead of raising.
        return {
            "success": False,
            "error": f"Visualization generation failed: {str(e)}",
            "visualization": None,
        }
def export_results(
    processing_state: str,
    output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"],
    include_metadata: bool = True,
    confidence_filter: float = 0.0,
) -> Dict:
    """
    Export HTR results to multiple formats including plain text, structured JSON, ALTO XML, and PAGE XML.

    Supports HTRflow's native export functionality with configurable output formats and
    filtering options. Maintains document structure and metadata across all export formats.

    Args:
        processing_state (str): JSON string containing HTR processing results
        output_formats (List[str]): List of output formats to generate.
            NOTE(review): the mutable default is safe here because the list is
            never mutated, and it keeps the advertised MCP schema default.
        include_metadata (bool): Whether to include processing metadata in exports
        confidence_filter (float): Minimum confidence threshold for included text

    Returns:
        dict: Export results with content for each requested format
    """
    try:
        state = json.loads(processing_state)
        collection = deserialize_collection(state["collection"])
        # Dispatch table replaces the previous if/elif chain; all exporters
        # share the (collection, confidence_filter, include_metadata) signature.
        exporters = {
            "txt": export_plain_text,
            "json": export_structured_json,
            "alto": export_alto_xml,
            "page": export_page_xml,
        }
        exports = {}
        for format_type in output_formats:
            exporter = exporters.get(format_type)
            if exporter is not None:  # unknown formats are silently skipped, as before
                exports[format_type] = exporter(
                    collection, confidence_filter, include_metadata
                )
        export_stats = calculate_export_statistics(collection, confidence_filter)
        return {
            "success": True,
            "exports": exports,
            "statistics": export_stats,
            "export_metadata": {
                "formats_generated": output_formats,
                "confidence_filter": confidence_filter,
                "include_metadata": include_metadata,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        # MCP tool boundary: return a structured error instead of raising.
        return {
            "success": False,
            "error": f"Export generation failed: {str(e)}",
            "exports": None,
        }
# Helper Functions
def extract_processing_results(
    collection: "Collection", confidence_threshold: float
) -> Dict:
    """Extract structured results from a processed HTRflow Collection.

    Walks every node of every page and keeps nodes whose text is non-empty and
    whose confidence meets the threshold. Nodes without a confidence attribute
    are treated as confidence 1.0, consistent with serialize_collection and the
    export helpers (previously such nodes were silently dropped).

    Args:
        collection: Processed collection; only needs ``pages`` and
            ``page.traverse()`` (duck-typed, annotation is lazy on purpose).
        confidence_threshold: Minimum confidence for a line to be included.

    Returns:
        Dict with ``extracted_text`` (newline-joined), ``text_lines`` (dicts
        with text/confidence/bbox/node_id), ``regions`` (unused placeholder,
        kept for backward compatibility), and ``confidence_scores``.
    """
    results: Dict = {
        "extracted_text": "",
        "text_lines": [],
        "regions": [],
        "confidence_scores": [],
    }
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            if not text:
                continue
            # Default to full confidence when a node carries no score, matching
            # the rest of the file's getattr(node, "confidence", 1.0) pattern.
            confidence = getattr(node, "confidence", 1.0)
            if confidence < confidence_threshold:
                continue
            results["text_lines"].append(
                {
                    "text": text,
                    "confidence": confidence,
                    "bbox": getattr(node, "bbox", None),
                    "node_id": getattr(node, "id", None),
                }
            )
            results["extracted_text"] += text + "\n"
            results["confidence_scores"].append(confidence)
    return results
def serialize_collection(collection: Collection) -> str:
    """Flatten an HTRflow Collection into a JSON string for state storage.

    Captures, per node: text, confidence (default 1.0), bbox, node id, and the
    concrete node class name; per page: image path and dimensions. The result
    is the input to deserialize_collection.
    """
    pages = []
    for page in collection.pages:
        nodes = [
            {
                "text": getattr(node, "text", ""),
                "confidence": getattr(node, "confidence", 1.0),
                "bbox": getattr(node, "bbox", None),
                "node_id": getattr(node, "id", None),
                "node_type": type(node).__name__,
            }
            for node in page.traverse()
        ]
        pages.append(
            {
                "nodes": nodes,
                "image_path": getattr(page, "image_path", None),
                "dimensions": getattr(page, "dimensions", None),
            }
        )
    return json.dumps({"pages": pages, "metadata": getattr(collection, "metadata", {})})
def deserialize_collection(serialized_data: str):
    """Rebuild a lightweight Collection stand-in from serialized JSON state.

    The returned object mimics the small surface the visualization/export
    helpers use: ``.pages``, and per page ``.traverse()`` yielding nodes with
    ``text``, ``confidence``, ``bbox`` and ``id`` attributes.
    """

    class MockNode:
        # Per-line record: only the fields downstream helpers read.
        def __init__(self, node_data):
            self.text = node_data.get("text", "")
            self.confidence = node_data.get("confidence", 1.0)
            self.bbox = node_data.get("bbox")
            self.id = node_data.get("node_id")

    class MockPage:
        # Flat page: traverse() simply returns the stored node list.
        def __init__(self, page_data):
            self.nodes = [MockNode(nd) for nd in page_data.get("nodes", [])]

        def traverse(self):
            return self.nodes

    class MockCollection:
        def __init__(self, data):
            self.pages = [MockPage(pd) for pd in data.get("pages", [])]

    return MockCollection(json.loads(serialized_data))
def calculate_average_confidence(results: Dict) -> float:
    """Return the mean of ``results['confidence_scores']``, or 0.0 when empty."""
    scores = results.get("confidence_scores", [])
    return sum(scores) / len(scores) if scores else 0.0
def create_text_overlay_visualization(
    image, collection, show_confidence, highlight_low_confidence
):
    """Return a copy of *image* with recognized-line boxes (and optional scores) drawn.

    Args:
        image: PIL image to annotate (copied, never mutated).
        collection: Processed collection; nodes need truthy ``bbox`` and ``text``.
        show_confidence: Draw the numeric confidence above each box.
        highlight_low_confidence: Color boxes with confidence < 0.7 orange.

    Returns:
        A new PIL image with the overlay drawn.
    """
    viz_image = image.copy()
    draw = ImageDraw.Draw(viz_image)
    bbox_color = (0, 255, 0)  # Green for normal confidence
    low_conf_color = (255, 165, 0)  # Orange for low confidence
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:  # font not installed on this host (was a bare except)
        font = ImageFont.load_default()
    for page in collection.pages:
        for node in page.traverse():
            if not (getattr(node, "bbox", None) and getattr(node, "text", None)):
                continue
            bbox = node.bbox
            confidence = getattr(node, "confidence", 1.0)
            color = (
                low_conf_color
                if highlight_low_confidence and confidence < 0.7
                else bbox_color
            )
            draw.rectangle(bbox, outline=color, width=2)
            if show_confidence:
                # Label sits just above the box; may clip at the top edge.
                draw.text(
                    (bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font
                )
    return viz_image
def create_confidence_heatmap(image, collection, confidence_threshold):
    """Return *image* with translucent per-line confidence colors composited on.

    Color bands: red < 0.5 <= yellow < 0.8 <= green, each at alpha 100.
    All rectangles are drawn on a single RGBA overlay which is composited
    once — the previous version allocated a full-size overlay and ran
    Image.alpha_composite over the whole image once per node.

    Args:
        image: PIL image to annotate (not mutated).
        collection: Processed collection; nodes need ``bbox`` and ``confidence``.
        confidence_threshold: Unused; kept for API symmetry with the other
            visualization helpers (callers pass it positionally).

    Returns:
        A new RGB PIL image.
    """
    base = image.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    for page in collection.pages:
        for node in page.traverse():
            if hasattr(node, "bbox") and hasattr(node, "confidence") and node.bbox:
                confidence = node.confidence
                if confidence < 0.5:
                    color = (255, 0, 0, 100)  # Red with transparency
                elif confidence < 0.8:
                    color = (255, 255, 0, 100)  # Yellow with transparency
                else:
                    color = (0, 255, 0, 100)  # Green with transparency
                overlay_draw.rectangle(node.bbox, fill=color)
    return Image.alpha_composite(base, overlay).convert("RGB")
def create_region_visualization(image, collection):
    """Return a copy of *image* with every detected box outlined.

    Colors rotate through a fixed 4-color palette in traversal order so
    adjacent regions are visually distinguishable.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    boxes = (
        node.bbox
        for page in collection.pages
        for node in page.traverse()
        if hasattr(node, "bbox") and node.bbox
    )
    for index, bbox in enumerate(boxes):
        draw.rectangle(bbox, outline=palette[index % len(palette)], width=3)
    return annotated
def generate_visualization_metadata(collection, visualization_type):
    """Summarize element counts and confidence statistics for a visualization.

    Counts nodes with non-empty text; confidence stats cover only those text
    nodes that carry a ``confidence`` attribute.
    """
    element_count = 0
    scores = []
    for page in collection.pages:
        for node in page.traverse():
            if getattr(node, "text", None):
                element_count += 1
                if hasattr(node, "confidence"):
                    scores.append(node.confidence)
    if scores:
        stats = {
            "min": min(scores),
            "max": max(scores),
            "avg": sum(scores) / len(scores),
        }
    else:
        stats = {"min": 0, "max": 0, "avg": 0}
    return {
        "total_elements": element_count,
        "visualization_type": visualization_type,
        "confidence_stats": stats,
    }
def extract_interactive_elements(collection):
    """Collect clickable line descriptors (bbox + text + confidence + id).

    Only nodes with both a truthy ``bbox`` and truthy ``text`` are included,
    in traversal order.
    """
    return [
        {
            "bbox": node.bbox,
            "text": node.text,
            "confidence": getattr(node, "confidence", 1.0),
            "node_id": getattr(node, "id", None),
        }
        for page in collection.pages
        for node in page.traverse()
        if getattr(node, "bbox", None) and getattr(node, "text", None)
    ]
def export_plain_text(
    collection, confidence_filter: float, include_metadata: bool
) -> str:
    """Render recognized lines as newline-joined plain text.

    Lines below *confidence_filter* are dropped; a ``#``-comment header with
    filter and timestamp is prepended when *include_metadata* is set.
    """
    lines = []
    if include_metadata:
        lines += [
            "# HTR Export Results",
            f"# Confidence Filter: {confidence_filter}",
            f"# Export Time: {datetime.now().isoformat()}",
            "",
        ]
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            if text and getattr(node, "confidence", 1.0) >= confidence_filter:
                lines.append(text)
    return "\n".join(lines)
def export_structured_json(
    collection, confidence_filter: float, include_metadata: bool
) -> str:
    """Serialize recognized text as an indented JSON document tree.

    Structure: ``{"document": {"pages": [{"page_id", "regions": [...]}]}}``
    plus an optional top-level ``metadata`` block. Regions below
    *confidence_filter* are excluded.
    """
    document = {"document": {"pages": []}}
    if include_metadata:
        document["metadata"] = {
            "confidence_filter": confidence_filter,
            "export_time": datetime.now().isoformat(),
            "total_pages": len(collection.pages),
        }
    for page_id, page in enumerate(collection.pages):
        regions = [
            {
                "text": node.text,
                "confidence": getattr(node, "confidence", 1.0),
                "bbox": getattr(node, "bbox", None),
                "node_id": getattr(node, "id", None),
            }
            for node in page.traverse()
            if getattr(node, "text", None)
            and getattr(node, "confidence", 1.0) >= confidence_filter
        ]
        document["document"]["pages"].append({"page_id": page_id, "regions": regions})
    return json.dumps(document, indent=2, ensure_ascii=False)
def export_alto_xml(
    collection, confidence_filter: float, include_metadata: bool
) -> str:
    """Export recognition results as simplified ALTO XML.

    Emits one ``TextLine`` with a single ``String`` element per recognized
    line; ``WC`` carries the word confidence. The previous implementation had
    its tag literals stripped (it appended whitespace-only strings and never
    interpolated ``bbox``/``node.text``), producing no usable XML.

    Args:
        collection: Processed collection (duck-typed pages/traverse).
        confidence_filter: Lines below this confidence are excluded.
        include_metadata: Emit a ``Description`` block with the source name.

    Returns:
        The ALTO document as a string.
    """
    from xml.sax.saxutils import quoteattr  # local: only needed by this exporter

    xml_lines = ['<?xml version="1.0" encoding="UTF-8"?>']
    xml_lines.append('<alto xmlns="http://www.loc.gov/standards/alto/ns-v4#">')
    if include_metadata:
        xml_lines.append("  <Description>")
        xml_lines.append("    <sourceImageInformation>")
        xml_lines.append("      <fileName>htr_processed_image</fileName>")
        xml_lines.append("    </sourceImageInformation>")
        xml_lines.append("  </Description>")
    xml_lines.append("  <Layout>")
    xml_lines.append("    <Page>")
    xml_lines.append("      <PrintSpace>")
    line_id = 0
    for page in collection.pages:
        for node in page.traverse():
            if getattr(node, "text", None):
                confidence = getattr(node, "confidence", 1.0)
                if confidence >= confidence_filter:
                    # assumes bbox is [x0, y0, x1, y1] — consistent with how
                    # draw.rectangle consumes it elsewhere in this file.
                    bbox = getattr(node, "bbox", None) or [0, 0, 100, 20]
                    line_id += 1
                    xml_lines.append(
                        f'        <TextLine ID="line_{line_id}" HPOS="{bbox[0]}" '
                        f'VPOS="{bbox[1]}" WIDTH="{bbox[2] - bbox[0]}" '
                        f'HEIGHT="{bbox[3] - bbox[1]}">'
                    )
                    xml_lines.append(
                        f"          <String CONTENT={quoteattr(node.text)} "
                        f'WC="{confidence:.4f}"/>'
                    )
                    xml_lines.append("        </TextLine>")
    xml_lines.append("      </PrintSpace>")
    xml_lines.append("    </Page>")
    xml_lines.append("  </Layout>")
    xml_lines.append("</alto>")
    return "\n".join(xml_lines)
def export_page_xml(
    collection, confidence_filter: float, include_metadata: bool
) -> str:
    """Export recognition results as simplified PAGE XML.

    Emits one ``TextRegion`` containing a single ``TextLine``/``TextEquiv``
    per recognized line. The previous implementation had its tag literals
    stripped (appending whitespace-only strings and never interpolating
    ``bbox``/``node.text``), producing no usable XML.

    Args:
        collection: Processed collection (duck-typed pages/traverse).
        confidence_filter: Lines below this confidence are excluded.
        include_metadata: Emit a ``Metadata``/``Created`` block with a timestamp.

    Returns:
        The PAGE (PcGts) document as a string.
    """
    from xml.sax.saxutils import escape  # local: only needed by this exporter

    xml_lines = ['<?xml version="1.0" encoding="UTF-8"?>']
    xml_lines.append(
        '<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15">'
    )
    if include_metadata:
        xml_lines.append("  <Metadata>")
        xml_lines.append(f"    <Created>{datetime.now().isoformat()}</Created>")
        xml_lines.append("  </Metadata>")
    xml_lines.append("  <Page>")
    region_id = 0
    for page in collection.pages:
        for node in page.traverse():
            if getattr(node, "text", None):
                confidence = getattr(node, "confidence", 1.0)
                if confidence >= confidence_filter:
                    # assumes bbox is [x0, y0, x1, y1] — consistent with how
                    # draw.rectangle consumes it elsewhere in this file.
                    bbox = getattr(node, "bbox", None) or [0, 0, 100, 20]
                    region_id += 1
                    points = (
                        f"{bbox[0]},{bbox[1]} {bbox[2]},{bbox[1]} "
                        f"{bbox[2]},{bbox[3]} {bbox[0]},{bbox[3]}"
                    )
                    xml_lines.append(f'    <TextRegion id="region_{region_id}">')
                    xml_lines.append(f'      <Coords points="{points}"/>')
                    xml_lines.append(f'      <TextLine id="line_{region_id}">')
                    xml_lines.append(f'        <TextEquiv conf="{confidence:.4f}">')
                    xml_lines.append(f"          <Unicode>{escape(node.text)}</Unicode>")
                    xml_lines.append("        </TextEquiv>")
                    xml_lines.append("      </TextLine>")
                    xml_lines.append("    </TextRegion>")
    xml_lines.append("  </Page>")
    xml_lines.append("</PcGts>")
    return "\n".join(xml_lines)
def calculate_export_statistics(collection, confidence_filter: float) -> Dict:
    """Compute totals and confidence statistics for an export run.

    Counts every text-bearing node; "filtered" figures and the character count
    cover only nodes at or above *confidence_filter*.
    """
    scores = []
    kept = 0
    kept_chars = 0
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            if not text:
                continue
            confidence = getattr(node, "confidence", 1.0)
            scores.append(confidence)
            if confidence >= confidence_filter:
                kept += 1
                kept_chars += len(text)
    total = len(scores)  # one score per text-bearing node
    return {
        "total_text_elements": total,
        "filtered_text_elements": kept,
        "filter_retention_rate": kept / total if total > 0 else 0,
        "total_characters": kept_chars,
        "average_confidence": sum(scores) / total if total else 0,
        "confidence_range": {
            "min": min(scores) if scores else 0,
            "max": max(scores) if scores else 0,
        },
    }
# Main Gradio Application with MCP Server
def create_htrflow_mcp_server():
    """Assemble the three-tab Gradio app that exposes the HTRflow MCP tools."""
    # Tab 1: run an HTR pipeline on an uploaded image.
    htr_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(
                choices=[
                    "letter_english",
                    "letter_swedish",
                    "spread_english",
                    "spread_swedish",
                ],
                value="letter_english",
                label="Document Type",
            ),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(
                label="Custom Settings (JSON)",
                placeholder="Optional custom pipeline settings",
            ),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )
    # Tab 2: render visual diagnostics from a previous run's processing state.
    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(
                label="Processing State (JSON)",
                placeholder="Paste processing results from HTR tool",
            ),
            gr.Dropdown(
                choices=["overlay", "confidence_heatmap", "text_regions"],
                value="overlay",
                label="Visualization Type",
            ),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(
                type="pil",
                label="Image (optional - will use image from processing state if not provided)",
            ),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )
    # Tab 3: export a previous run's results to text/JSON/ALTO/PAGE.
    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(
                label="Processing State (JSON)",
                placeholder="Paste processing results from HTR tool",
            ),
            gr.CheckboxGroup(
                choices=["txt", "json", "alto", "page"],
                value=["txt"],
                label="Output Formats",
            ),
            gr.Checkbox(value=True, label="Include Metadata"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )
    return gr.TabbedInterface(
        [htr_tab, visualization_tab, export_tab],
        ["HTR Processing", "Results Visualization", "Export Results"],
        title="HTRflow MCP Server",
    )
# Launch MCP Server
if __name__ == "__main__":
    # Build the Gradio app and serve it; mcp_server=True additionally exposes
    # the three tools over the Model Context Protocol endpoint.
    demo = create_htrflow_mcp_server()
    demo.launch(mcp_server=True)