import gradio as gr import yaml import json import base64 import tempfile import os from typing import Dict, List, Optional, Literal from datetime import datetime from PIL import Image, ImageDraw, ImageFont import io import spaces from htrflow.volume.volume import Collection from htrflow.pipeline.pipeline import Pipeline PIPELINE_CONFIGS = { "letter_english": { "steps": [ { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": { "model": "Riksarkivet/yolov9-lines-within-regions-1" }, "generation_settings": {"batch_size": 8}, }, }, { "step": "TextRecognition", "settings": { "model": "TrOCR", "model_settings": {"model": "microsoft/trocr-base-handwritten"}, "generation_settings": {"batch_size": 16}, }, }, {"step": "OrderLines"}, ] }, "letter_swedish": { "steps": [ { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": { "model": "Riksarkivet/yolov9-lines-within-regions-1" }, "generation_settings": {"batch_size": 8}, }, }, { "step": "TextRecognition", "settings": { "model": "TrOCR", "model_settings": { "model": "Riksarkivet/trocr-base-handwritten-hist-swe-2" }, "generation_settings": {"batch_size": 16}, }, }, {"step": "OrderLines"}, ] }, "spread_english": { "steps": [ { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": {"model": "Riksarkivet/yolov9-regions-1"}, "generation_settings": {"batch_size": 4}, }, }, { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": { "model": "Riksarkivet/yolov9-lines-within-regions-1" }, "generation_settings": {"batch_size": 8}, }, }, { "step": "TextRecognition", "settings": { "model": "TrOCR", "model_settings": {"model": "microsoft/trocr-base-handwritten"}, "generation_settings": {"batch_size": 16}, }, }, {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}}, ] }, "spread_swedish": { "steps": [ { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": {"model": "Riksarkivet/yolov9-regions-1"}, "generation_settings": {"batch_size": 4}, }, }, { "step": "Segmentation", "settings": { "model": "yolo", "model_settings": { "model": "Riksarkivet/yolov9-lines-within-regions-1" }, "generation_settings": {"batch_size": 8}, }, }, { "step": "TextRecognition", "settings": { "model": "TrOCR", "model_settings": { "model": "Riksarkivet/trocr-base-handwritten-hist-swe-2" }, "generation_settings": {"batch_size": 16}, }, }, {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}}, ] }, } @spaces.GPU def process_htr( image: Image.Image, document_type: Literal[ "letter_english", "letter_swedish", "spread_english", "spread_swedish" ] = "spread_swedish", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None, ) -> Dict: """ Process handwritten text recognition on uploaded images using HTRflow pipelines. Supports templates for different document types (letters vs spreads) and languages (English vs Swedish). Uses HTRflow's modular pipeline system with configurable segmentation and text recognition models. Args: image (Image.Image): PIL Image object to process document_type (str): Type of document processing template to use confidence_threshold (float): Minimum confidence threshold for text recognition custom_settings (str, optional): JSON string with custom pipeline settings Returns: dict: Processing results including extracted text, metadata, and processing state """ try: if image is None: return {"success": False, "error": "No image provided", "results": None} with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file: image.save(temp_file.name, "PNG") temp_image_path = temp_file.name try: if custom_settings: try: config = json.loads(custom_settings) except json.JSONDecodeError: return { "success": False, "error": "Invalid JSON in custom_settings parameter", "results": None, } else: config = PIPELINE_CONFIGS[document_type] collection = Collection([temp_image_path]) pipeline = Pipeline.from_config(config) processed_collection = pipeline.run(collection) results = extract_processing_results( processed_collection, confidence_threshold ) img_buffer = io.BytesIO() image.save(img_buffer, format="PNG") image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") processing_state = { "collection": serialize_collection(processed_collection), "config": config, "image_base64": image_base64, "image_size": image.size, "document_type": document_type, "confidence_threshold": confidence_threshold, "timestamp": datetime.now().isoformat(), } return { "success": True, "results": results, "processing_state": json.dumps(processing_state), "metadata": { "total_lines": len(results.get("text_lines", [])), "average_confidence": calculate_average_confidence(results), "document_type": document_type, "image_dimensions": image.size, }, } finally: if os.path.exists(temp_image_path): os.unlink(temp_image_path) except Exception as e: return { "success": False, "error": f"HTR processing failed: {str(e)}", "results": None, } def visualize_results( processing_state: str, visualization_type: Literal[ "overlay", "confidence_heatmap", "text_regions" ] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None, ) -> Dict: """ Generate interactive visualizations of HTR processing results. Creates visual representations of text recognition results including bounding box overlays, confidence heatmaps, and region segmentation displays. Supports multiple visualization modes for different analysis needs. Args: processing_state (str): JSON string containing HTR processing results and metadata visualization_type (str): Type of visualization to generate show_confidence (bool): Whether to display confidence scores on visualization highlight_low_confidence (bool): Whether to highlight low-confidence regions image (Image.Image, optional): PIL Image object to use instead of state image Returns: dict: Visualization data including base64-encoded images and metadata """ try: state = json.loads(processing_state) collection = deserialize_collection(state["collection"]) confidence_threshold = state["confidence_threshold"] if image is not None: original_image = image else: image_data = base64.b64decode(state["image_base64"]) original_image = Image.open(io.BytesIO(image_data)) if visualization_type == "overlay": viz_image = create_text_overlay_visualization( original_image, collection, show_confidence, highlight_low_confidence ) elif visualization_type == "confidence_heatmap": viz_image = create_confidence_heatmap( original_image, collection, confidence_threshold ) elif visualization_type == "text_regions": viz_image = create_region_visualization(original_image, collection) img_buffer = io.BytesIO() viz_image.save(img_buffer, format="PNG") img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") viz_metadata = generate_visualization_metadata(collection, visualization_type) return { "success": True, "visualization": { "image_base64": img_base64, "image_format": "PNG", "visualization_type": visualization_type, "dimensions": viz_image.size, }, "metadata": viz_metadata, "interactive_elements": extract_interactive_elements(collection), } except Exception as e: return { "success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None, } def export_results( processing_state: str, output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"], include_metadata: bool = True, confidence_filter: float = 0.0, ) -> Dict: """ Export HTR results to multiple formats including plain text, structured JSON, ALTO XML, and PAGE XML. Supports HTRflow's native export functionality with configurable output formats and filtering options. Maintains document structure and metadata across all export formats. Args: processing_state (str): JSON string containing HTR processing results output_formats (List[str]): List of output formats to generate include_metadata (bool): Whether to include processing metadata in exports confidence_filter (float): Minimum confidence threshold for included text Returns: dict: Export results with content for each requested format """ try: # Parse processing state state = json.loads(processing_state) collection = deserialize_collection(state["collection"]) config = state["config"] # Generate exports for each requested format exports = {} for format_type in output_formats: if format_type == "txt": exports["txt"] = export_plain_text( collection, confidence_filter, include_metadata ) elif format_type == "json": exports["json"] = export_structured_json( collection, confidence_filter, include_metadata ) elif format_type == "alto": exports["alto"] = export_alto_xml( collection, confidence_filter, include_metadata ) elif format_type == "page": exports["page"] = export_page_xml( collection, confidence_filter, include_metadata ) # Calculate export statistics export_stats = calculate_export_statistics(collection, confidence_filter) return { "success": True, "exports": exports, "statistics": export_stats, "export_metadata": { "formats_generated": output_formats, "confidence_filter": confidence_filter, "include_metadata": include_metadata, "timestamp": datetime.now().isoformat(), }, } except Exception as e: return { "success": False, "error": f"Export generation failed: {str(e)}", "exports": None, } # Helper Functions def extract_processing_results( collection: Collection, confidence_threshold: float ) -> Dict: """Extract structured results from processed HTRflow Collection.""" results = { "extracted_text": "", "text_lines": [], "regions": [], "confidence_scores": [], } # Traverse collection hierarchy to extract text and metadata for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: if ( hasattr(node, "confidence") and node.confidence >= confidence_threshold ): results["text_lines"].append( { "text": node.text, "confidence": node.confidence, "bbox": getattr(node, "bbox", None), "node_id": getattr(node, "id", None), } ) results["extracted_text"] += node.text + "\n" results["confidence_scores"].append(node.confidence) return results def serialize_collection(collection: Collection) -> str: """Serialize HTRflow Collection to JSON string for state storage.""" serialized_data = {"pages": [], "metadata": getattr(collection, "metadata", {})} for page in collection.pages: page_data = { "nodes": [], "image_path": getattr(page, "image_path", None), "dimensions": getattr(page, "dimensions", None), } for node in page.traverse(): node_data = { "text": getattr(node, "text", ""), "confidence": getattr(node, "confidence", 1.0), "bbox": getattr(node, "bbox", None), "node_id": getattr(node, "id", None), "node_type": type(node).__name__, } page_data["nodes"].append(node_data) serialized_data["pages"].append(page_data) return json.dumps(serialized_data) def deserialize_collection(serialized_data: str): """Deserialize JSON string back to HTRflow Collection.""" data = json.loads(serialized_data) # Mock collection classes for state reconstruction class MockCollection: def __init__(self, data): self.pages = [] for page_data in data.get("pages", []): page = MockPage(page_data) self.pages.append(page) class MockPage: def __init__(self, page_data): self.nodes = [] for node_data in page_data.get("nodes", []): node = MockNode(node_data) self.nodes.append(node) def traverse(self): return self.nodes class MockNode: def __init__(self, node_data): self.text = node_data.get("text", "") self.confidence = node_data.get("confidence", 1.0) self.bbox = node_data.get("bbox") self.id = node_data.get("node_id") return MockCollection(data) def calculate_average_confidence(results: Dict) -> float: """Calculate average confidence score from processing results.""" confidence_scores = results.get("confidence_scores", []) if not confidence_scores: return 0.0 return sum(confidence_scores) / len(confidence_scores) def create_text_overlay_visualization( image, collection, show_confidence, highlight_low_confidence ): """Create image with text bounding boxes and recognition results overlaid.""" viz_image = image.copy() draw = ImageDraw.Draw(viz_image) # Define visualization styles bbox_color = (0, 255, 0) # Green for normal confidence low_conf_color = (255, 165, 0) # Orange for low confidence text_color = (255, 255, 255) # White text try: font = ImageFont.truetype("arial.ttf", 12) except: font = ImageFont.load_default() # Draw bounding boxes and text for each recognized element for page in collection.pages: for node in page.traverse(): if ( hasattr(node, "bbox") and hasattr(node, "text") and node.bbox and node.text ): bbox = node.bbox confidence = getattr(node, "confidence", 1.0) # Choose color based on confidence if highlight_low_confidence and confidence < 0.7: color = low_conf_color else: color = bbox_color # Draw bounding box draw.rectangle(bbox, outline=color, width=2) # Add confidence score if requested if show_confidence: conf_text = f"{confidence:.2f}" draw.text((bbox[0], bbox[1] - 15), conf_text, fill=color, font=font) return viz_image def create_confidence_heatmap(image, collection, confidence_threshold): """Create confidence heatmap visualization.""" viz_image = image.copy() # Create heatmap overlay based on confidence scores for page in collection.pages: for node in page.traverse(): if hasattr(node, "bbox") and hasattr(node, "confidence") and node.bbox: confidence = node.confidence # Color mapping: red (low) -> yellow (medium) -> green (high) if confidence < 0.5: color = (255, 0, 0, 100) # Red with transparency elif confidence < 0.8: color = (255, 255, 0, 100) # Yellow with transparency else: color = (0, 255, 0, 100) # Green with transparency # Create overlay image for transparency overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0)) overlay_draw = ImageDraw.Draw(overlay) overlay_draw.rectangle(node.bbox, fill=color) viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay) return viz_image.convert("RGB") def create_region_visualization(image, collection): """Create region segmentation visualization.""" viz_image = image.copy() draw = ImageDraw.Draw(viz_image) # Draw different colors for different region types region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)] region_count = 0 for page in collection.pages: for node in page.traverse(): if hasattr(node, "bbox") and node.bbox: color = region_colors[region_count % len(region_colors)] draw.rectangle(node.bbox, outline=color, width=3) region_count += 1 return viz_image def generate_visualization_metadata(collection, visualization_type): """Generate metadata for visualization results.""" total_elements = 0 confidence_stats = [] for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: total_elements += 1 if hasattr(node, "confidence"): confidence_stats.append(node.confidence) return { "total_elements": total_elements, "visualization_type": visualization_type, "confidence_stats": { "min": min(confidence_stats) if confidence_stats else 0, "max": max(confidence_stats) if confidence_stats else 0, "avg": sum(confidence_stats) / len(confidence_stats) if confidence_stats else 0, }, } def extract_interactive_elements(collection): """Extract interactive elements for visualization.""" elements = [] for page in collection.pages: for node in page.traverse(): if ( hasattr(node, "bbox") and hasattr(node, "text") and node.bbox and node.text ): elements.append( { "bbox": node.bbox, "text": node.text, "confidence": getattr(node, "confidence", 1.0), "node_id": getattr(node, "id", None), } ) return elements def export_plain_text( collection, confidence_filter: float, include_metadata: bool ) -> str: """Export recognition results as plain text.""" text_lines = [] if include_metadata: text_lines.append(f"# HTR Export Results") text_lines.append(f"# Confidence Filter: {confidence_filter}") text_lines.append(f"# Export Time: {datetime.now().isoformat()}") text_lines.append("") # Extract text from collection hierarchy for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: confidence = getattr(node, "confidence", 1.0) if confidence >= confidence_filter: text_lines.append(node.text) return "\n".join(text_lines) def export_structured_json( collection, confidence_filter: float, include_metadata: bool ) -> str: """Export results as structured JSON with full hierarchy.""" result = {"document": {"pages": []}} if include_metadata: result["metadata"] = { "confidence_filter": confidence_filter, "export_time": datetime.now().isoformat(), "total_pages": len(collection.pages), } # Build hierarchical structure for page_idx, page in enumerate(collection.pages): page_data = {"page_id": page_idx, "regions": []} for node in page.traverse(): if hasattr(node, "text") and node.text: confidence = getattr(node, "confidence", 1.0) if confidence >= confidence_filter: node_data = { "text": node.text, "confidence": confidence, "bbox": getattr(node, "bbox", None), "node_id": getattr(node, "id", None), } page_data["regions"].append(node_data) result["document"]["pages"].append(page_data) return json.dumps(result, indent=2, ensure_ascii=False) def export_alto_xml( collection, confidence_filter: float, include_metadata: bool ) -> str: """Export results as ALTO XML format.""" # Simplified ALTO XML generation xml_lines = [''] xml_lines.append('') xml_lines.append(" ") if include_metadata: xml_lines.append(f" ") xml_lines.append(f" htr_processed_image") xml_lines.append(f" ") xml_lines.append(" ") xml_lines.append(" ") xml_lines.append(" ") for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: confidence = getattr(node, "confidence", 1.0) if confidence >= confidence_filter: bbox = getattr(node, "bbox", [0, 0, 100, 20]) xml_lines.append( f' ' ) xml_lines.append( f' ' ) xml_lines.append(" ") xml_lines.append(" ") xml_lines.append(" ") xml_lines.append("") return "\n".join(xml_lines) def export_page_xml( collection, confidence_filter: float, include_metadata: bool ) -> str: """Export results as PAGE XML format.""" # Simplified PAGE XML generation xml_lines = [''] xml_lines.append( '' ) if include_metadata: xml_lines.append(" ") xml_lines.append(f" {datetime.now().isoformat()}") xml_lines.append(" ") xml_lines.append(" ") for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: confidence = getattr(node, "confidence", 1.0) if confidence >= confidence_filter: bbox = getattr(node, "bbox", [0, 0, 100, 20]) xml_lines.append(f" ") xml_lines.append( f' ' ) xml_lines.append(f" ") xml_lines.append(f' ') xml_lines.append(f" {node.text}") xml_lines.append(" ") xml_lines.append(" ") xml_lines.append(" ") xml_lines.append(" ") xml_lines.append("") return "\n".join(xml_lines) def calculate_export_statistics(collection, confidence_filter: float) -> Dict: """Calculate statistics for export results.""" total_text_elements = 0 filtered_text_elements = 0 confidence_scores = [] total_characters = 0 for page in collection.pages: for node in page.traverse(): if hasattr(node, "text") and node.text: total_text_elements += 1 confidence = getattr(node, "confidence", 1.0) confidence_scores.append(confidence) if confidence >= confidence_filter: filtered_text_elements += 1 total_characters += len(node.text) return { "total_text_elements": total_text_elements, "filtered_text_elements": filtered_text_elements, "filter_retention_rate": filtered_text_elements / total_text_elements if total_text_elements > 0 else 0, "total_characters": total_characters, "average_confidence": sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0, "confidence_range": { "min": min(confidence_scores) if confidence_scores else 0, "max": max(confidence_scores) if confidence_scores else 0, }, } # Main Gradio Application with MCP Server def create_htrflow_mcp_server(): """Create the complete HTRflow MCP server with all three tools.""" demo = gr.TabbedInterface( [ gr.Interface( fn=process_htr, inputs=[ gr.Image(type="pil", label="Upload Image"), gr.Dropdown( choices=[ "letter_english", "letter_swedish", "spread_english", "spread_swedish", ], value="letter_english", label="Document Type", ), gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"), gr.Textbox( label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings", ), ], outputs=gr.JSON(label="Processing Results"), title="HTR Processing Tool", description="Process handwritten text using configurable HTRflow pipelines", api_name="process_htr", ), gr.Interface( fn=visualize_results, inputs=[ gr.Textbox( label="Processing State (JSON)", placeholder="Paste processing results from HTR tool", ), gr.Dropdown( choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type", ), gr.Checkbox(value=True, label="Show Confidence Scores"), gr.Checkbox(value=True, label="Highlight Low Confidence"), gr.Image( type="pil", label="Image (optional - will use image from processing state if not provided)", ), ], outputs=gr.JSON(label="Visualization Results"), title="Results Visualization Tool", description="Generate interactive visualizations of HTR results", api_name="visualize_results", ), gr.Interface( fn=export_results, inputs=[ gr.Textbox( label="Processing State (JSON)", placeholder="Paste processing results from HTR tool", ), gr.CheckboxGroup( choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats", ), gr.Checkbox(value=True, label="Include Metadata"), gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"), ], outputs=gr.JSON(label="Export Results"), title="Export Tool", description="Export HTR results to multiple formats", api_name="export_results", ), ], ["HTR Processing", "Results Visualization", "Export Results"], title="HTRflow MCP Server", ) return demo # Launch MCP Server if __name__ == "__main__": demo = create_htrflow_mcp_server() demo.launch(mcp_server=True)