"""Gradio application exposing HTRflow handwritten-text-recognition tools over MCP."""

import base64
import contextlib
import io
import json
import logging
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterator, List, Literal, Optional

import gradio as gr
import spaces  # NOTE: kept ahead of the htrflow imports, as in the original file.
from PIL import Image, ImageDraw, ImageFont

from htrflow.volume.volume import Collection
from htrflow.pipeline.pipeline import Pipeline

logger = logging.getLogger(__name__)

_LINE_SEGMENTATION_MODEL = "Riksarkivet/yolov9-lines-within-regions-1"
_REGION_SEGMENTATION_MODEL = "Riksarkivet/yolov9-regions-1"
_TROCR_ENGLISH_MODEL = "microsoft/trocr-base-handwritten"
_TROCR_SWEDISH_MODEL = "Riksarkivet/trocr-base-handwritten-hist-swe-2"


def _segmentation_step(model: str, batch_size: int) -> Dict:
    """Build a YOLO ``Segmentation`` pipeline step for *model*."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _text_recognition_step(model: str) -> Dict:
    """Build a TrOCR ``TextRecognition`` pipeline step for *model*."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model},
            "generation_settings": {"batch_size": 16},
        },
    }


def _letter_pipeline(trocr_model: str) -> Dict:
    """Build the single-page pipeline: line segmentation, recognition, line ordering."""
    return {
        "steps": [
            _segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _text_recognition_step(trocr_model),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(trocr_model: str) -> Dict:
    """Build the two-page pipeline: region + line segmentation, recognition, reading order."""
    return {
        "steps": [
            _segmentation_step(_REGION_SEGMENTATION_MODEL, 4),
            _segmentation_step(_LINE_SEGMENTATION_MODEL, 8),
            _text_recognition_step(trocr_model),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Pipeline configuration per supported document type. Each entry is built from
# fresh dicts, so no nested structure is shared between document types.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline(_TROCR_ENGLISH_MODEL),
    "letter_swedish": _letter_pipeline(_TROCR_SWEDISH_MODEL),
    "spread_english": _spread_pipeline(_TROCR_ENGLISH_MODEL),
    "spread_swedish": _spread_pipeline(_TROCR_SWEDISH_MODEL),
}


def _encode_png(image: Image.Image) -> bytes:
    """Return *image* encoded as PNG bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _write_temp_file(data: bytes, suffix: str) -> str:
    """Write *data* to a new temporary file and return its path.

    The file is removed again if writing fails, so a path is only ever handed
    to the caller for a file that was written completely.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(data)
    except Exception:
        _remove_file(path)
        raise
    return path


def _remove_file(path: str) -> None:
    """Delete *path*, ignoring the case where it is already gone."""
    with contextlib.suppress(FileNotFoundError):
        os.unlink(path)


def _run_pipeline(config: Dict, image_path: str) -> Collection:
    """Run the HTRflow pipeline described by *config* on the image at *image_path*."""
    collection = Collection([image_path])
    pipeline = Pipeline.from_config(config)
    return pipeline.run(collection)


def _to_builtin(value):
    """Unwrap scalar wrappers exposing ``.item()`` so the value can be JSON-encoded."""
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _bbox_to_list(bbox) -> Optional[list]:
    """Convert a bounding box into a plain ``[xmin, ymin, xmax, ymax]`` list.

    NOTE(review): presumably HTRflow boxes expose ``xmin``/``ymin``/``xmax``/
    ``ymax`` attributes -- confirm against the installed htrflow version. Any
    other iterable box is passed through element by element.
    """
    if bbox is None:
        return None
    corners = ("xmin", "ymin", "xmax", "ymax")
    if all(hasattr(bbox, name) for name in corners):
        return [_to_builtin(getattr(bbox, name)) for name in corners]
    return [_to_builtin(value) for value in bbox]


def _iter_text_nodes(collection: Collection) -> Iterator:
    """Yield every node of every page in *collection* that has non-empty ``text``."""
    for page in collection.pages:
        for node in page.traverse():
            if getattr(node, "text", None):
                yield node


@spaces.GPU
def process_htr(
    image: Image.Image,
    document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish",
    confidence_threshold: float = 0.8,
    custom_settings: Optional[str] = None,
) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Args:
        image: The page image to transcribe.
        document_type: Key into PIPELINE_CONFIGS selecting the predefined pipeline.
        confidence_threshold: Minimum confidence a text line needs to be included in the results.
        custom_settings: Optional JSON pipeline configuration that replaces the predefined one.

    Returns:
        A dict with "success"; on success also "results", "processing_state" (a JSON
        string accepted by visualize_results and export_results) and "metadata";
        on failure "error" and "results" set to None.
    """
    try:
        if image is None:
            return {"success": False, "error": "No image provided", "results": None}

        # Whitespace-only input is treated like an empty textbox.
        if custom_settings and custom_settings.strip():
            # NOTE(review): the parsed JSON is handed to Pipeline.from_config
            # unvalidated; confirm which models/settings clients may request.
            try:
                config = json.loads(custom_settings)
            except json.JSONDecodeError:
                return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
        else:
            config = PIPELINE_CONFIGS[document_type]

        # Encode once and reuse the bytes for the temp file and the state blob.
        png_bytes = _encode_png(image)
        temp_image_path = _write_temp_file(png_bytes, suffix=".png")
        try:
            processed_collection = _run_pipeline(config, temp_image_path)
            results = extract_text_results(processed_collection, confidence_threshold)

            processing_state = {
                "collection_data": serialize_collection_data(processed_collection),
                "image_base64": base64.b64encode(png_bytes).decode("utf-8"),
                "image_size": image.size,
                "document_type": document_type,
                "confidence_threshold": confidence_threshold,
                "timestamp": datetime.now().isoformat(),
            }

            return {
                "success": True,
                "results": results,
                "processing_state": json.dumps(processing_state),
                "metadata": {
                    "total_lines": len(results.get("text_lines", [])),
                    "average_confidence": results.get("average_confidence", 0),
                    "document_type": document_type,
                    "image_dimensions": image.size,
                },
            }
        finally:
            _remove_file(temp_image_path)
    except Exception as e:
        # Tool boundary: report the failure to the client instead of raising.
        logger.exception("HTR processing failed")
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}


def visualize_results(
    processing_state: str,
    visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay",
    show_confidence: bool = True,
    highlight_low_confidence: bool = True,
    image: Optional[Image.Image] = None,
) -> Dict:
    """Generate interactive visualizations of HTR processing results.

    Args:
        processing_state: The "processing_state" JSON string returned by process_htr.
        visualization_type: Which rendering to produce.
        show_confidence: Draw the confidence value next to each box (overlay only).
        highlight_low_confidence: Use a warning colour for boxes with confidence below 0.7 (overlay only).
        image: Optional image to draw on; defaults to the image embedded in processing_state.

    Returns:
        A dict with "success"; on success also "visualization" (base64 PNG plus
        details) and "metadata"; on failure "error" and "visualization" set to None.
    """
    try:
        state = json.loads(processing_state)
        collection_data = state["collection_data"]

        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        viz_image = create_visualization(
            original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence
        )
        img_base64 = base64.b64encode(_encode_png(viz_image)).decode("utf-8")

        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
        }
    except Exception as e:
        # Tool boundary: report the failure to the client instead of raising.
        logger.exception("Visualization generation failed")
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}


def _read_exported_files(export_dir: Path) -> List[Dict]:
    """Return name and UTF-8 content of every file found below *export_dir*."""
    exported = []
    for root, _, files in os.walk(export_dir):
        for file in files:
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                exported.append({"filename": file, "content": f.read()})
    return exported


def export_results(
    processing_state: str,
    output_formats: Optional[List[Literal["txt", "json", "alto", "page"]]] = None,
    confidence_filter: float = 0.0,
) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    Args:
        processing_state: The "processing_state" JSON string returned by process_htr.
        output_formats: Serializer names to export with; defaults to ["txt"].
        confidence_filter: Echoed back in the export metadata. It is not applied to the exported content.

    Returns:
        A dict with "success"; on success also "exports" (format name mapped to a
        list of filename/content dicts) and "export_metadata"; on failure "error"
        and "exports" set to None.
    """
    # NOTE: the pipeline is run again here on the embedded image, always with
    # the predefined config for the stored document_type.
    try:
        formats = ["txt"] if output_formats is None else list(output_formats)
        state = json.loads(processing_state)

        image = Image.open(io.BytesIO(base64.b64decode(state["image_base64"])))
        temp_image_path = _write_temp_file(_encode_png(image), suffix=".png")
        temp_dir = None
        try:
            config = PIPELINE_CONFIGS[state["document_type"]]
            processed_collection = _run_pipeline(config, temp_image_path)

            temp_dir = Path(tempfile.mkdtemp())
            exports = {}
            for fmt in formats:
                export_dir = temp_dir / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)
                exports[fmt] = _read_exported_files(export_dir)

            return {
                "success": True,
                "exports": exports,
                "export_metadata": {
                    "formats_generated": formats,
                    "confidence_filter": confidence_filter,
                    "timestamp": datetime.now().isoformat(),
                },
            }
        finally:
            # Clean up on every path, including a failed export or read.
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)
            _remove_file(temp_image_path)
    except Exception as e:
        # Tool boundary: report the failure to the client instead of raising.
        logger.exception("Export generation failed")
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}


def extract_text_results(collection: Collection, confidence_threshold: float) -> Dict:
    """Collect the text lines of *collection* that meet *confidence_threshold*.

    Only nodes with non-empty ``text`` and a ``confidence`` attribute are
    considered; nodes without one are skipped.
    NOTE(review): serialize_collection_data instead treats a missing confidence
    as 1.0 -- confirm which behaviour is intended.

    Returns:
        A dict with "extracted_text" (kept lines, each followed by a newline),
        "text_lines", "confidence_scores" and "average_confidence" (0 when no
        line was kept).
    """
    text_lines = []
    for node in _iter_text_nodes(collection):
        confidence = getattr(node, "confidence", None)
        # Written as ``not (>=)`` so that NaN is rejected like in a plain ``>=`` test.
        if confidence is None or not (confidence >= confidence_threshold):
            continue
        text_lines.append({
            "text": node.text,
            "confidence": _to_builtin(confidence),
            "bbox": _bbox_to_list(getattr(node, "bbox", None)),
        })

    scores = [line["confidence"] for line in text_lines]
    return {
        "extracted_text": "".join(f"{line['text']}\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": scores,
        "average_confidence": sum(scores) / len(scores) if scores else 0,
    }


def serialize_collection_data(collection: Collection) -> Dict:
    """Flatten *collection* into a JSON-serialisable list of text elements.

    Every node with non-empty ``text`` contributes its text, confidence
    (1.0 when the node has no ``confidence`` attribute) and bounding box.
    """
    text_elements = [
        {
            "text": node.text,
            "confidence": _to_builtin(getattr(node, "confidence", 1.0)),
            "bbox": _bbox_to_list(getattr(node, "bbox", None)),
        }
        for node in _iter_text_nodes(collection)
    ]
    return {"text_elements": text_elements}


def _load_label_font():
    """Return Arial at 12 pt, or Pillow's built-in font when Arial cannot be loaded."""
    try:
        return ImageFont.truetype("arial.ttf", 12)
    except OSError:
        return ImageFont.load_default()


def _heatmap_color(confidence: float) -> tuple:
    """Map *confidence* to a translucent red (<0.5), yellow (<0.8) or green fill."""
    if confidence < 0.5:
        return (255, 0, 0, 100)
    if confidence < 0.8:
        return (255, 255, 0, 100)
    return (0, 255, 0, 100)


def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
    """Draw the text elements of *collection_data* onto a copy of *image*.

    Args:
        image: Source image; it is not modified.
        collection_data: Dict with a "text_elements" list as produced by serialize_collection_data.
        visualization_type: "overlay", "confidence_heatmap" or "text_regions".
        show_confidence: Label each overlay box with its confidence.
        highlight_low_confidence: Draw overlay boxes with confidence below 0.7 in orange.

    Returns:
        The annotated image. Heatmaps are always returned in RGB mode.
    """
    is_heatmap = visualization_type == "confidence_heatmap"

    canvas = image.copy()
    if is_heatmap:
        # alpha_composite needs RGBA; converting once avoids a copy per element.
        canvas = canvas.convert("RGBA")
    elif canvas.mode not in ("RGB", "RGBA"):
        # Grayscale/palette inputs would otherwise lose the outline colours.
        canvas = canvas.convert("RGB")

    draw = ImageDraw.Draw(canvas)
    font = _load_label_font()
    region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]

    for index, element in enumerate(collection_data.get("text_elements", [])):
        bbox = element.get("bbox")
        if not bbox:
            continue
        confidence = element.get("confidence")
        if confidence is None:
            confidence = 1.0

        if visualization_type == "overlay":
            is_low = highlight_low_confidence and confidence < 0.7
            color = (255, 165, 0) if is_low else (0, 255, 0)
            draw.rectangle(bbox, outline=color, width=2)
            if show_confidence:
                # Clamp so labels of boxes at the top edge stay inside the image.
                label_y = max(0, bbox[1] - 15)
                draw.text((bbox[0], label_y), f"{confidence:.2f}", fill=color, font=font)
        elif is_heatmap:
            # Composite per element so overlapping boxes blend with each other.
            overlay = Image.new("RGBA", canvas.size, (0, 0, 0, 0))
            ImageDraw.Draw(overlay).rectangle(bbox, fill=_heatmap_color(confidence))
            canvas = Image.alpha_composite(canvas, overlay)
        elif visualization_type == "text_regions":
            # Cycle by position: str hashes are salted per process, which made
            # the colour of a region change from run to run.
            color = region_colors[index % len(region_colors)]
            draw.rectangle(bbox, outline=color, width=3)

    return canvas.convert("RGB") if is_heatmap else canvas


def create_htrflow_mcp_server():
    """Build the tabbed Gradio app with the processing, visualization and export tools."""
    processing_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(choices=list(PIPELINE_CONFIGS), value="letter_english", label="Document Type"),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )
    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(type="pil", label="Image (optional)"),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )
    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
            gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )

    demo = gr.TabbedInterface(
        [processing_tab, visualization_tab, export_tab],
        ["HTR Processing", "Results Visualization", "Export Results"],
        title="HTRflow MCP Server",
    )
    return demo


if __name__ == "__main__":
    demo = create_htrflow_mcp_server()
    demo.launch(mcp_server=True)